提交 71ada6d8 作者：Olivier Breuleux

merge

import numpy
import core import core
from gof import PatternOptimizer as pattern_opt, OpSubOptimizer as op_sub
import scipy.weave as weave import scipy.weave as weave
""" """
...@@ -10,50 +12,7 @@ by this file are aimed at the goal of inserting gemm Ops in place of more ...@@ -10,50 +12,7 @@ by this file are aimed at the goal of inserting gemm Ops in place of more
fine-grained motifs of iadd, isub, scale, and dot. fine-grained motifs of iadd, isub, scale, and dot.
""" """
_gemm_support_code = """ _general_gemm_code = """
template< typename T >
struct TMat_t
{
T * __restrict__ d;/**< pointer to element (0,0) */
size_t M; /**< number of rows */
size_t N; /**< number of columns */
size_t m; /**< row stride */
size_t n; /**< column stride */
bool invalid;
/** null */
TMat_t(const PyArrayObject *o) :
d((double*) o->data),
M((o->nd==2) ? o->dimensions[0] : 0),
N((o->nd==2) ? o->dimensions[1] : 0),
m((o->nd==2) ? o->strides[0] / sizeof(double) : 0),
n((o->nd==2) ? o->strides[1] / sizeof(double) : 0),
invalid((o->nd !=2)
|| (o->descr->elsize != sizeof(T)))
{
}
/** unsafe element access */
const T & operator()(size_t i, size_t j) const
{
return d[ i * m + j*n];
}
/** unsafe element access */
T & operator()(size_t i, size_t j)
{
return d[ i * m + j*n];
}
/** safe element access */
const T & at(size_t i, size_t j) const
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
/** safe element access */
T & at(size_t i, size_t j)
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
};
typedef TMat_t<double> mat_t;
static int mat_gemm_general(double a, const mat_t &A, const mat_t &B, double b, mat_t &C) static int mat_gemm_general(double a, const mat_t &A, const mat_t &B, double b, mat_t &C)
{ {
...@@ -68,19 +27,22 @@ _gemm_support_code = """ ...@@ -68,19 +27,22 @@ _gemm_support_code = """
} }
return 0; return 0;
} }
"""
static int mat_gemm(double a, const mat_t &A, const mat_t &B, double b, mat_t &C) _gemm_code_template = """
{ if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
if ((A.M != C.M) || (A.N != B.M) || (B.N != C.N))
{ {
PyErr_SetString(PyExc_ValueError, "mat_gemm input array size mismatch"); PyErr_SetString(PyExc_ValueError, "mat_gemm input array size mismatch");
return 1; fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
} }
if ((A.m < 1) || (A.n < 1) if ((Sx[0] < 1) || (Sx[1] < 1)
||(B.m < 1) || (B.n < 1) || (Sy[0] < 1) || (Sy[1] < 1)
||(C.m < 1) || (C.n < 1)) || (Sz[0] < 1) || (Sz[1] < 1))
{ {
return mat_gemm_general(a, A, B, b, C); fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
//return mat_gemm_general(a, A, B, b, C);
} }
//TODO: OPTIMIZE for many special cases: //TODO: OPTIMIZE for many special cases:
...@@ -90,68 +52,57 @@ _gemm_support_code = """ ...@@ -90,68 +52,57 @@ _gemm_support_code = """
//- others? //- others?
int unit = 0; int unit = 0;
unit |= ((A.n == 1) ? 0x0 : (A.m == 1) ? 0x1 : 0x2) << 0; unit |= ((Sx[1] == sizeof(%(dtype)s)) ? 0x0 : (Sx[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 0;
unit |= ((B.n == 1) ? 0x0 : (B.m == 1) ? 0x1 : 0x2) << 4; unit |= ((Sy[1] == sizeof(%(dtype)s)) ? 0x0 : (Sy[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 4;
unit |= ((C.n == 1) ? 0x0 : (C.m == 1) ? 0x1 : 0x2) << 8; unit |= ((Sz[1] == sizeof(%(dtype)s)) ? 0x0 : (Sz[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 8;
/*
fprintf(stderr, "M N %zu %zu %zu %zu %zu %zu\n", A.M, A.N, B.M, B.N, C.M, C.N);
fprintf(stderr, "m n %zu %zu %zu %zu %zu %zu\n", A.m, A.n, B.m, B.n, C.m, C.n);
fprintf(stderr, "unit %i\n", unit);
*/
size_t A_m = (A.M > 1) ? A.m : A.N; /* create appropriate strides for malformed matrices that are row or column
size_t A_n = (A.N > 1) ? A.n : A.M; * vectors
size_t B_m = (B.M > 1) ? B.m : B.N; */
size_t B_n = (B.N > 1) ? B.n : B.M; size_t sx_0 = (Nx[0] > 1) ? Sx[0]/sizeof(%(dtype)s) : Nx[1];
size_t C_m = (C.M > 1) ? C.m : C.N; size_t sx_1 = (Nx[1] > 1) ? Sx[1]/sizeof(%(dtype)s) : Nx[0];
size_t C_n = (C.N > 1) ? C.n : C.M; size_t sy_0 = (Ny[0] > 1) ? Sy[0]/sizeof(%(dtype)s) : Ny[1];
size_t sy_1 = (Ny[1] > 1) ? Sy[1]/sizeof(%(dtype)s) : Ny[0];
size_t sz_0 = (Nz[0] > 1) ? Sz[0]/sizeof(%(dtype)s) : Nz[1];
size_t sz_1 = (Nz[1] > 1) ? Sz[1]/sizeof(%(dtype)s) : Nz[0];
switch(unit) switch(unit)
{ {
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_m, b, C.d, C_m); break; case 0x000: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_m, b, C.d, C_m); break; case 0x001: %(gemm)s(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_n, b, C.d, C_m); break; case 0x010: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_n, b, C.d, C_m); break; case 0x011: %(gemm)s(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_m, b, C.d, C_n); break; case 0x100: %(gemm)s(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_m, b, C.d, C_n); break; case 0x101: %(gemm)s(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_n, b, C.d, C_n); break; case 0x110: %(gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_n, b, C.d, C_n); break; case 0x111: %(gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_1); break;
default: mat_gemm_general(a, A, B, b, C); break; default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
}; };
return 0; /* v 1 */
} """
"""
_gemm_code = """
mat_t mx(x_array), my(y_array), mz(z_array);
if (mx.invalid || my.invalid || mz.invalid)
{
fprintf(stderr, "error in train_classifier_new.py, _gemm_code\\n");
}
else
{
mat_gemm(a[0], mx, my, b[0], mz);
}
""" _gemm_code = { 'f': _gemm_code_template % { 'gemm':'cblas_sgemm', 'dtype':'float'},
def _gemm(a, x, y, b, z): 'd': _gemm_code_template % { 'gemm':'cblas_dgemm', 'dtype':'double'}}
weave.inline(_gemm_code,
def _gemm_rank2(a, x, y, b, z):
weave.inline(_gemm_code[z.dtype.char],
['a', 'x', 'y', 'b', 'z'], ['a', 'x', 'y', 'b', 'z'],
support_code = _gemm_support_code,
headers=['<gsl/gsl_cblas.h>'], headers=['<gsl/gsl_cblas.h>'],
libraries=['cblas','atlas', 'g2c']) libraries=['cblas','goto', 'g2c'])
#TODO: modify gemm to work with vectors and tensors too! def _gemm(a, x, y, b, z):
# (trac ticket 18) if len(x.shape) == 2 and len(y.shape) == 2:
class gemm(core.omega_op, core.inplace): _gemm_rank2(a, x, y, b, z)
def impl_unused(z, a,x,y,b): else:
if b == 0.0: if b == 0.0:
if a == 1.0: if a == 1.0:
z = numpy.dot(x,y) z[:] = numpy.dot(x,y)
elif a == -1.0: elif a == -1.0:
z = -numpy.dot(x,y) z[:] = -numpy.dot(x,y)
else: else:
z = a * numpy.dot(x,y) z[:] = a * numpy.dot(x,y)
elif b == 1.0: elif b == 1.0:
if a == 1.0: if a == 1.0:
z += numpy.dot(x,y) z += numpy.dot(x,y)
...@@ -162,7 +113,19 @@ class gemm(core.omega_op, core.inplace): ...@@ -162,7 +113,19 @@ class gemm(core.omega_op, core.inplace):
else: else:
z *= b z *= b
z += a * numpy.dot(x,y) z += a * numpy.dot(x,y)
return z[:]
_gdot_coefs = { 'f':
(numpy.ones((),dtype='float32'),numpy.zeros((),dtype='float32')),
'd': (numpy.ones(()),numpy.zeros(()))}
def _gdot(x,y):
a,b = _gdot_coefs[x.dtype.char]
z = numpy.ndarray((x.shape[0],y.shape[1]),dtype=x.dtype)
_gemm(a, x, y, b, z)
return z
class gemm(core.omega_op, core.inplace):
def impl(z, a, x, y, b): def impl(z, a, x, y, b):
_gemm(a, x, y, b, z) _gemm(a, x, y, b, z)
return z[:] return z[:]
...@@ -170,10 +133,20 @@ class gemm(core.omega_op, core.inplace): ...@@ -170,10 +133,20 @@ class gemm(core.omega_op, core.inplace):
def grad(x, gz):
    """Gradient of gemm; not implemented yet.

    :raises NotImplementedError: always.
    """
    # NotImplemented is a comparison sentinel, not an exception class;
    # raising it produces a confusing TypeError instead.
    raise NotImplementedError
class gdot(core.omega_op):
    """Dot-product op whose implementation is ``_gdot``."""

    impl = _gdot

    def grad(x, gz):
        """Gradient of gdot; not implemented yet.

        :raises NotImplementedError: always.
        """
        # NotImplemented is a comparison sentinel, not an exception class;
        # raising it produces a confusing TypeError instead.
        raise NotImplementedError
#TODO: put more optimizations in here (Trac # 18)
optimizations = [ optimizations = [
pattern_opt( pattern_opt(
(C.isub_elemwise, 'z', (C.dot,'x','y')), (core.isub_elemwise, 'z', (core.dot,'x','y')),
(gemm, 'z', -1.0, 'x', 'y', 1.0)) (gemm, 'z', -1.0, 'x', 'y', 1.0)),
pattern_opt(
(core.dot,'x', 'y'),
(gdot, 'x', 'y'))
] ]
...@@ -34,19 +34,44 @@ class profile_linker: ...@@ -34,19 +34,44 @@ class profile_linker:
self.order = env.toposort() self.order = env.toposort()
self.thunks = [op._perform for op in self.order] self.thunks = [op._perform for op in self.order]
self.n_calls = 0 self.n_calls = 0
self.n_thunks = 0
self.times = [0.0 for op in self.order] self.times = [0.0 for op in self.order]
def __call__(self): #TODO: popen2("dot -Tpng | display") and actually make the graph window pop up
def print_for_dot(self):
    """Print the op graph as a Graphviz DOT digraph on standard output.

    One edge ``producer -> consumer`` is emitted for every input of every
    op in ``self.order`` that has an owner.  Nodes are named after the op's
    class name followed by the object's id, which keeps distinct instances
    of the same op class apart.
    """
    def node_name(op):
        return op.__class__.__name__ + str(abs(id(op)))

    # DOT only accepts double-quoted strings, so the size attribute must not
    # use single quotes.
    lines = ['digraph unix { size = "6,6"; '
             'node [color = lightblue2; style = filled];']
    for op in self.order:
        for arg in op.inputs:
            if arg.owner:
                lines.append("%s -> %s;" % (node_name(arg.owner),
                                            node_name(op)))
    # Close the graph so the output is a complete DOT document.
    lines.append("}")
    print("\n".join(lines))
def slow_call(self):
    """Execute every thunk once, timing each one individually.

    The wall-clock duration of thunk ``k`` is added to ``self.times[k]``;
    ``self.n_thunks`` grows by one per thunk run and ``self.n_calls`` by one
    per invocation of this method.
    """
    for slot, run in enumerate(self.thunks):
        began = time.time()
        run()
        elapsed = time.time() - began
        self.times[slot] += elapsed
        self.n_thunks += 1
    self.n_calls += 1
def fast_call(self):
    """Execute every thunk once, timing only the loop as a whole.

    The total elapsed wall-clock time is added to ``self.times[0]``, so
    per-op figures are not meaningful after using this method.
    ``self.n_thunks`` grows by the number of thunks and ``self.n_calls``
    by one.
    """
    start_time = time.time()
    for th in self.thunks:
        th()
    # Stop the clock before the bookkeeping so only thunk time is measured.
    elapsed = time.time() - start_time
    self.n_thunks += len(self.thunks)
    self.n_calls += 1
    # With no ops there is no slot to charge the time to; skipping avoids an
    # IndexError on an empty program.
    if self.times:
        self.times[0] += elapsed
__call__ = slow_call
def dump(self): def dump(self, proportion=True):
"""Print statistics accumulated so far."""
total_time = sum(self.times) total_time = sum(self.times)
print self.n_calls, 'calls took', total_time, 'seconds' print self.n_calls, 'calls took', total_time, 'seconds to evaluate',
print self.n_thunks, 'thunks'
if 0:
print 'Proportion of CPU per op' print 'Proportion of CPU per op'
for op, t in zip(self.order, self.times): for op, t in zip(self.order, self.times):
s_op = str(op).split()[0][1:] s_op = str(op).split()[0][1:]
...@@ -58,7 +83,10 @@ class profile_linker: ...@@ -58,7 +83,10 @@ class profile_linker:
s_op = str(op).split()[0][1:] s_op = str(op).split()[0][1:]
dct[s_op] = dct.get(s_op, 0.0) + t dct[s_op] = dct.get(s_op, 0.0) + t
for t, s_op in reversed(sorted([(t,op) for op, t in dct.items()])): for t, s_op in reversed(sorted([(t,op) for op, t in dct.items()])):
if proportion:
print " %-35s %4.5f"% (s_op, t/total_time) print " %-35s %4.5f"% (s_op, t/total_time)
else:
print " %-35s %4.5f"% (s_op, t)
......
...@@ -752,10 +752,10 @@ iscale = scale.inplace_version() ...@@ -752,10 +752,10 @@ iscale = scale.inplace_version()
class sqr(elemwise): class sqr(elemwise):
def grad(x, gz):
return scale(mul_elemwise(x, gz), 2.0)
def impl(x): def impl(x):
return x * x return x * x
def grad(x, gz):
return scale(mul_elemwise(x, gz), 2.0)
def c_foreach((x, ), (z, )): def c_foreach((x, ), (z, )):
"z = x * x;" "z = x * x;"
...@@ -775,9 +775,9 @@ isqr.impl = lambda x: x.__imul__(x) ...@@ -775,9 +775,9 @@ isqr.impl = lambda x: x.__imul__(x)
class sqrt(elemwise): class sqrt(elemwise):
impl = numpy.sqrt
def grad(x, gz): def grad(x, gz):
return scale(div(gz, sqrt(x)), 0.5) return scale(div(gz, sqrt(x)), 0.5)
impl = numpy.sqrt
def c_foreach((x, ), (z, )): def c_foreach((x, ), (z, )):
"z = pow(x, 0.5);" "z = pow(x, 0.5);"
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论