more init

上级 78b30d32
import numpy
import core
from gof import PatternOptimizer as pattern_opt, OpSubOptimizer as op_sub
import scipy.weave as weave
"""
......@@ -10,50 +12,7 @@ by this file are aimed at the goal of inserting gemm Ops in place of more
fine-grained motifs of iadd, isub, scale, and dot.
"""
_gemm_support_code = """
template< typename T >
struct TMat_t
{
T * __restrict__ d;/**< pointer to element (0,0) */
size_t M; /**< number of rows */
size_t N; /**< number of columns */
size_t m; /**< row stride */
size_t n; /**< column stride */
bool invalid;
/** null */
TMat_t(const PyArrayObject *o) :
d((double*) o->data),
M((o->nd==2) ? o->dimensions[0] : 0),
N((o->nd==2) ? o->dimensions[1] : 0),
m((o->nd==2) ? o->strides[0] / sizeof(double) : 0),
n((o->nd==2) ? o->strides[1] / sizeof(double) : 0),
invalid((o->nd !=2)
|| (o->descr->elsize != sizeof(T)))
{
}
/** unsafe element access */
const T & operator()(size_t i, size_t j) const
{
return d[ i * m + j*n];
}
/** unsafe element access */
T & operator()(size_t i, size_t j)
{
return d[ i * m + j*n];
}
/** safe element access */
const T & at(size_t i, size_t j) const
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
/** safe element access */
T & at(size_t i, size_t j)
{
return d[ assert((i < M) && (j < N)), i * m + j*n];
}
};
typedef TMat_t<double> mat_t;
_general_gemm_code = """
static int mat_gemm_general(double a, const mat_t &A, const mat_t &B, double b, mat_t &C)
{
......@@ -68,90 +27,82 @@ _gemm_support_code = """
}
return 0;
}
"""
static int mat_gemm(double a, const mat_t &A, const mat_t &B, double b, mat_t &C)
_gemm_code_template = """
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
if ((A.M != C.M) || (A.N != B.M) || (B.N != C.N))
{
PyErr_SetString(PyExc_ValueError, "mat_gemm input array size mismatch");
return 1;
}
if ((A.m < 1) || (A.n < 1)
||(B.m < 1) || (B.n < 1)
||(C.m < 1) || (C.n < 1))
{
return mat_gemm_general(a, A, B, b, C);
}
//TODO: OPTIMIZE for many special cases:
//- gemv
//- ger
//- ddot
//- others?
int unit = 0;
unit |= ((A.n == 1) ? 0x0 : (A.m == 1) ? 0x1 : 0x2) << 0;
unit |= ((B.n == 1) ? 0x0 : (B.m == 1) ? 0x1 : 0x2) << 4;
unit |= ((C.n == 1) ? 0x0 : (C.m == 1) ? 0x1 : 0x2) << 8;
/*
fprintf(stderr, "M N %zu %zu %zu %zu %zu %zu\n", A.M, A.N, B.M, B.N, C.M, C.N);
fprintf(stderr, "m n %zu %zu %zu %zu %zu %zu\n", A.m, A.n, B.m, B.n, C.m, C.n);
fprintf(stderr, "unit %i\n", unit);
*/
size_t A_m = (A.M > 1) ? A.m : A.N;
size_t A_n = (A.N > 1) ? A.n : A.M;
size_t B_m = (B.M > 1) ? B.m : B.N;
size_t B_n = (B.N > 1) ? B.n : B.M;
size_t C_m = (C.M > 1) ? C.m : C.N;
size_t C_n = (C.N > 1) ? C.n : C.M;
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_m, b, C.d, C_m); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_m, b, C.d, C_m); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_n, b, C.d, C_m); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_n, b, C.d, C_m); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_m, b, C.d, C_n); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_m, b, C.d, C_n); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_m, B.d, B_n, b, C.d, C_n); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, C.M, C.N, A.N, a, A.d, A_n, B.d, B_n, b, C.d, C_n); break;
default: mat_gemm_general(a, A, B, b, C); break;
};
return 0;
PyErr_SetString(PyExc_ValueError, "mat_gemm input array size mismatch");
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
}
"""
_gemm_code = """
mat_t mx(x_array), my(y_array), mz(z_array);
if (mx.invalid || my.invalid || mz.invalid)
if ((Sx[0] < 1) || (Sx[1] < 1)
|| (Sy[0] < 1) || (Sy[1] < 1)
|| (Sz[0] < 1) || (Sz[1] < 1))
{
fprintf(stderr, "error in train_classifier_new.py, _gemm_code\\n");
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
//return mat_gemm_general(a, A, B, b, C);
}
else
//TODO: OPTIMIZE for many special cases:
//- gemv
//- ger
//- ddot
//- others?
int unit = 0;
unit |= ((Sx[1] == sizeof(%(dtype)s)) ? 0x0 : (Sx[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == sizeof(%(dtype)s)) ? 0x0 : (Sy[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == sizeof(%(dtype)s)) ? 0x0 : (Sz[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
size_t sx_0 = (Nx[0] > 1) ? Sx[0]/sizeof(%(dtype)s) : Nx[1];
size_t sx_1 = (Nx[1] > 1) ? Sx[1]/sizeof(%(dtype)s) : Nx[0];
size_t sy_0 = (Ny[0] > 1) ? Sy[0]/sizeof(%(dtype)s) : Ny[1];
size_t sy_1 = (Ny[1] > 1) ? Sy[1]/sizeof(%(dtype)s) : Ny[0];
size_t sz_0 = (Nz[0] > 1) ? Sz[0]/sizeof(%(dtype)s) : Nz[1];
size_t sz_1 = (Nz[1] > 1) ? Sz[1]/sizeof(%(dtype)s) : Nz[0];
switch(unit)
{
mat_gemm(a[0], mx, my, b[0], mz);
}
case 0x000: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_0); break;
case 0x001: %(gemm)s(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_0); break;
case 0x010: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_0); break;
case 0x011: %(gemm)s(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_0); break;
case 0x100: %(gemm)s(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_0, b[0], z, sz_1); break;
case 0x101: %(gemm)s(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_0, b[0], z, sz_1); break;
case 0x110: %(gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_0, y, sy_1, b[0], z, sz_1); break;
case 0x111: %(gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a[0], x, sx_1, y, sy_1, b[0], z, sz_1); break;
default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
exit(1);
};
/* v 1 */
"""
"""
def _gemm(a, x, y, b, z):
weave.inline(_gemm_code,
_gemm_code = { 'f': _gemm_code_template % { 'gemm':'cblas_sgemm', 'dtype':'float'},
'd': _gemm_code_template % { 'gemm':'cblas_dgemm', 'dtype':'double'}}
def _gemm_rank2(a, x, y, b, z):
weave.inline(_gemm_code[z.dtype.char],
['a', 'x', 'y', 'b', 'z'],
support_code = _gemm_support_code,
headers=['<gsl/gsl_cblas.h>'],
libraries=['cblas','atlas', 'g2c'])
libraries=['cblas','goto', 'g2c'])
#TODO: modify gemm to work with vectors and tensors too!
# (trac ticket 18)
class gemm(core.omega_op, core.inplace):
def impl_unused(z, a,x,y,b):
def _gemm(a, x, y, b, z):
if len(x.shape) == 2 and len(y.shape) == 2:
_gemm_rank2(a, x, y, b, z)
else:
if b == 0.0:
if a == 1.0:
z = numpy.dot(x,y)
z[:] = numpy.dot(x,y)
elif a == -1.0:
z = -numpy.dot(x,y)
z[:] = -numpy.dot(x,y)
else:
z = a * numpy.dot(x,y)
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
......@@ -162,7 +113,19 @@ class gemm(core.omega_op, core.inplace):
else:
z *= b
z += a * numpy.dot(x,y)
return z[:]
# Per-dtype (one, zero) coefficient pairs, stored as 0-d arrays and keyed by
# the numpy dtype character ('f' for float32, 'd' for float64).
_gdot_coefs = dict(
    (numpy.dtype(_name).char,
     (numpy.ones((), dtype=_name), numpy.zeros((), dtype=_name)))
    for _name in ('float32', 'float64'))
def _gdot(x,y):
    """Allocate an output array and fill it by calling ``_gemm`` on x and y.

    The coefficient pair is looked up in ``_gdot_coefs`` by ``x.dtype.char``
    (a ``KeyError`` results when that dtype has no entry).  The output has
    shape ``(x.shape[0], y.shape[1])`` and the dtype of ``x``; its initial
    contents are uninitialized.

    NOTE(review): with the (one, zero) coefficients this presumably yields
    the plain matrix product of x and y -- confirm against ``_gemm``.
    """
    alpha, beta = _gdot_coefs[x.dtype.char]
    out_shape = (x.shape[0], y.shape[1])
    result = numpy.empty(out_shape, dtype=x.dtype)
    _gemm(alpha, x, y, beta, result)
    return result
class gemm(core.omega_op, core.inplace):
def impl(z, a, x, y, b):
    """Delegate to ``_gemm`` and return a full slice of ``z``.

    Note the argument order: this method receives ``z`` first, whereas
    ``_gemm`` takes it last.
    """
    _gemm(a, x, y, b, z)
    whole = z[:]
    return whole
......@@ -170,10 +133,20 @@ class gemm(core.omega_op, core.inplace):
def grad(x,gz):
    """Gradient hook; not implemented for this op.

    Raises:
        NotImplementedError: always.
    """
    # ``NotImplemented`` is a comparison sentinel, not an exception; raising
    # it produces a confusing TypeError.  NotImplementedError is the
    # exception intended here.
    raise NotImplementedError("grad is not implemented for this op")
class gdot(core.omega_op):
    """Op whose implementation is the module-level ``_gdot`` function.

    NOTE(review): how ``impl`` and ``grad`` are invoked is decided by
    ``core.omega_op``, which is not visible here -- confirm there.
    """
    impl = _gdot

    def grad(x,gz):
        """Gradient hook; not implemented for this op.

        Raises:
            NotImplementedError: always.
        """
        # ``NotImplemented`` is not an exception; raising it yields a
        # TypeError instead of the intended "not implemented" signal.
        raise NotImplementedError("grad is not implemented for this op")
#TODO: put more optimizations in here (Trac # 18)
# Rewrite rules built with ``pattern_opt``.
# NOTE(review): each call presumably takes (pattern to match, replacement)
# -- confirm against gof.PatternOptimizer.
optimizations = [
    # isub_elemwise(z, dot(x, y)) paired with gemm(z, -1.0, x, y, 1.0)
    pattern_opt(
        (core.isub_elemwise, 'z', (core.dot,'x','y')),
        (gemm, 'z', -1.0, 'x', 'y', 1.0)),
    # dot(x, y) paired with gdot(x, y)
    pattern_opt(
        (core.dot,'x', 'y'),
        (gdot, 'x', 'y'))
    ]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论