提交 71ada6d8 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

import numpy
import core
from gof import PatternOptimizer as pattern_opt, OpSubOptimizer as op_sub
import scipy.weave as weave
"""
......@@ -10,50 +12,7 @@ by this file are aimed at the goal of inserting gemm Ops in place of more
fine-grained motifs of iadd, isub, scale, and dot.
"""
_gemm_support_code = """
template< typename T >
struct TMat_t
{
T * __restrict__ d;/**< pointer to element (0,0) */
size_t M; /**< number of rows */
size_t N; /**< number of columns */
size_t m; /**< row stride */
size_t n; /**< column stride */
bool invalid;
/** null */
TMat_t(const PyArrayObject *o) :
d((double*) o->data),
M((o->nd==2) ? o->dimensions[0] : 0),
N((o->nd==2) ? o->dimensions[1] : 0),
m((o->nd==2) ? o->strides[0] / sizeof(double) : 0),
n((o->nd==2) ? o->strides[1] / sizeof(double) : 0),
invalid((o->nd !=2)
|| (o->descr->elsize != sizeof(T)))
{
}
/** unsafe element access */
const T & operator()(size_t i, size_t j) const
{
return d[ i * m + j*n];
}
/** unsafe element access */
T & operator()(size_t i, size_t j)
{
return d[ i * m + j*n];
}
/** safe element access */
const T & at(size_t i, size_t j) const
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
/** safe element access */
T & at(size_t i, size_t j)
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
};
typedef TMat_t<double> mat_t;
_general_gemm_code = """
static int mat_gemm_general(double a, const mat_t &A, const mat_t &B, double b, mat_t &C)
{
......@@ -68,19 +27,22 @@ _gemm_support_code = """
}
return 0;
}
"""
static int mat_gemm(double a, const mat_t &A, const mat_t &B, double b, mat_t &C)
{
if ((A.M != C.M) || (A.N != B.M) || (B.N != C.N))
_gemm_code_template = """
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "mat_gemm input array size mismatch");
return 1;
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
}
if ((A.m < 1) || (A.n < 1)
||(B.m < 1) || (B.n < 1)
||(C.m < 1) || (C.n < 1))
if ((Sx[0] < 1) || (Sx[1] < 1)
|| (Sy[0] < 1) || (Sy[1] < 1)
|| (Sz[0] < 1) || (Sz[1] < 1))
{
return mat_gemm_general(a, A, B, b, C);
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
//return mat_gemm_general(a, A, B, b, C);
}
//TODO: OPTIMIZE for many special cases:
......@@ -90,68 +52,57 @@ _gemm_support_code = """
//- others?
int unit = 0;
unit |= ((A.n == 1) ? 0x0 : (A.m == 1) ? 0x1 : 0x2) << 0;
unit |= ((B.n == 1) ? 0x0 : (B.m == 1) ? 0x1 : 0x2) << 4;
unit |= ((C.n == 1) ? 0x0 : (C.m == 1) ? 0x1 : 0x2) << 8;
/*
fprintf(stderr, "M N %zu %zu %zu %zu %zu %zu\n", A.M, A.N, B.M, B.N, C.M, C.N);
fprintf(stderr, "m n %zu %zu %zu %zu %zu %zu\n", A.m, A.n, B.m, B.n, C.m, C.n);
fprintf(stderr, "unit %i\n", unit);
*/
unit |= ((Sx[1] == sizeof(%(dtype)s)) ? 0x0 : (Sx[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == sizeof(%(dtype)s)) ? 0x0 : (Sy[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == sizeof(%(dtype)s)) ? 0x0 : (Sz[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 8;
size_t A_m = (A.M > 1) ? A.m : A.N;
size_t A_n = (A.N > 1) ? A.n : A.M;
size_t B_m = (B.M > 1) ? B.m : B.N;
size_t B_n = (B.N > 1) ? B.n : B.M;
size_t C_m = (C.M > 1) ? C.m : C.N;
size_t C_n = (C.N > 1) ? C.n : C.M;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
size_t sx_0 = (Nx[0] > 1) ? Sx[0]/sizeof(%(dtype)s) : Nx[1];
size_t sx_1 = (Nx[1] > 1) ? Sx[1]/sizeof(%(dtype)s) : Nx[0];
size_t sy_0 = (Ny[0] > 1) ? Sy[0]/sizeof(%(dtype)s) : Ny[1];
size_t sy_1 = (Ny[1] > 1) ? Sy[1]/sizeof(%(dtype)s) : Ny[0];
size_t sz_0 = (Nz[0] > 1) ? Sz[0]/sizeof(%(dtype)s) : Nz[1];
size_t sz_1 = (Nz[1] > 1) ? Sz[1]/sizeof(%(dtype)s) : Nz[0];
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_m, b, C.d, C_m); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_m, b, C.d, C_m); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_n, b, C.d, C_m); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_n, b, C.d, C_m); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_m, b, C.d, C_n); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_m, b, C.d, C_n); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_n, b, C.d, C_n); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_n, b, C.d, C_n); break;
default: mat_gemm_general(a, A, B, b, C); break;
case 0x000: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_0); break;
case 0x001: %(gemm)s(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_0); break;
case 0x010: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_0); break;
case 0x011: %(gemm)s(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_0); break;
case 0x100: %(gemm)s(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_1); break;
case 0x101: %(gemm)s(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_1); break;
case 0x110: %(gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_1); break;
case 0x111: %(gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_1); break;
default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
};
return 0;
}
"""
_gemm_code = """
mat_t mx(x_array), my(y_array), mz(z_array);
if (mx.invalid || my.invalid || mz.invalid)
{
fprintf(stderr, "error in train_classifier_new.py, _gemm_code\\n");
}
else
{
mat_gemm(a[0], mx, my, b[0], mz);
}
/* v 1 */
"""
"""
def _gemm(a, x, y, b, z):
weave.inline(_gemm_code,
_gemm_code = { 'f': _gemm_code_template % { 'gemm':'cblas_sgemm', 'dtype':'float'},
'd': _gemm_code_template % { 'gemm':'cblas_dgemm', 'dtype':'double'}}
def _gemm_rank2(a, x, y, b, z):
weave.inline(_gemm_code[z.dtype.char],
['a', 'x', 'y', 'b', 'z'],
support_code = _gemm_support_code,
headers=['<gsl/gsl_cblas.h>'],
libraries=['cblas','atlas', 'g2c'])
libraries=['cblas','goto', 'g2c'])
#TODO: modify gemm to work with vectors and tensors too!
# (trac ticket 18)
class gemm(core.omega_op, core.inplace):
def impl_unused(z, a,x,y,b):
def _gemm(a, x, y, b, z):
if len(x.shape) == 2 and len(y.shape) == 2:
_gemm_rank2(a, x, y, b, z)
else:
if b == 0.0:
if a == 1.0:
z = numpy.dot(x,y)
z[:] = numpy.dot(x,y)
elif a == -1.0:
z = -numpy.dot(x,y)
z[:] = -numpy.dot(x,y)
else:
z = a * numpy.dot(x,y)
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
......@@ -162,7 +113,19 @@ class gemm(core.omega_op, core.inplace):
else:
z *= b
z += a * numpy.dot(x,y)
return z[:]
_gdot_coefs = { 'f':
(numpy.ones((),dtype='float32'),numpy.zeros((),dtype='float32')),
'd': (numpy.ones(()),numpy.zeros(()))}
def _gdot(x,y):
a,b = _gdot_coefs[x.dtype.char]
z = numpy.ndarray((x.shape[0],y.shape[1]),dtype=x.dtype)
_gemm(a, x, y, b, z)
return z
class gemm(core.omega_op, core.inplace):
def impl(z, a, x, y, b):
_gemm(a, x, y, b, z)
return z[:]
......@@ -170,10 +133,20 @@ class gemm(core.omega_op, core.inplace):
def grad(x,gz):
raise NotImplemented
class gdot(core.omega_op):
impl = _gdot
#TODO: put more optimizations in here
def grad(x,gz):
raise NotImplemented
#TODO: put more optimizations in here (Trac # 18)
optimizations = [
pattern_opt(
(C.isub_elemwise, 'z', (C.dot,'x','y')),
(gemm, 'z', -1.0, 'x', 'y', 1.0))
(core.isub_elemwise, 'z', (core.dot,'x','y')),
(gemm, 'z', -1.0, 'x', 'y', 1.0)),
pattern_opt(
(core.dot,'x', 'y'),
(gdot, 'x', 'y'))
]
......@@ -34,19 +34,44 @@ class profile_linker:
self.order = env.toposort()
self.thunks = [op._perform for op in self.order]
self.n_calls = 0
self.n_thunks = 0
self.times = [0.0 for op in self.order]
def __call__(self):
#TODO: popen2("dot -Tpng | display") and actually make the graph window pop up
def print_for_dot(self):
print "digraph unix { size = '6,6'; node [color = lightblue2; style = filled];"
for op in self.order:
for input in op.inputs:
if input.owner:
print input.owner.__class__.__name__ + str(abs(id(input.owner))), " -> ", op.__class__.__name__ + str(abs(id(op))), ";"
def slow_call(self):
"""Run the program, timing each thunk. """
for i, thunk in enumerate(self.thunks):
start_time = time.time()
thunk()
self.times[i] += time.time() - start_time
self.n_thunks += 1
self.n_calls += 1
def fast_call(self):
"""Run the program, but only time the entire loop."""
start_time = time.time()
for th in self.thunks:
th()
self.n_thunks += len(self.thunks)
self.n_calls += 1
self.times[0] += time.time() - start_time
__call__ = slow_call
def dump(self):
def dump(self, proportion=True):
"""Print statistics accumulated so far."""
total_time = sum(self.times)
print self.n_calls, 'calls took', total_time, 'seconds'
print self.n_calls, 'calls took', total_time, 'seconds to evaluate',
print self.n_thunks, 'thunks'
if 0:
print 'Proportion of CPU per op'
for op, t in zip(self.order, self.times):
s_op = str(op).split()[0][1:]
......@@ -58,7 +83,10 @@ class profile_linker:
s_op = str(op).split()[0][1:]
dct[s_op] = dct.get(s_op, 0.0) + t
for t, s_op in reversed(sorted([(t,op) for op, t in dct.items()])):
if proportion:
print " %-35s %4.5f"% (s_op, t/total_time)
else:
print " %-35s %4.5f"% (s_op, t)
......
......@@ -752,10 +752,10 @@ iscale = scale.inplace_version()
class sqr(elemwise):
def grad(x, gz):
return scale(mul_elemwise(x, gz), 2.0)
def impl(x):
return x * x
def grad(x, gz):
return scale(mul_elemwise(x, gz), 2.0)
def c_foreach((x, ), (z, )):
"z = x * x;"
......@@ -775,9 +775,9 @@ isqr.impl = lambda x: x.__imul__(x)
class sqrt(elemwise):
impl = numpy.sqrt
def grad(x, gz):
return scale(div(gz, sqrt(x)), 0.5)
impl = numpy.sqrt
def c_foreach((x, ), (z, )):
"z = pow(x, 0.5);"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论