tightened dot, blas

上级 5319c2a2
......@@ -254,9 +254,9 @@ class omega_op(gof.PythonOp):
mod.add_function(instantiate)
module_dir = os.path.expanduser('~/.omega/compiled')
sys.path.insert(0, module_dir)
mod.compile(location = module_dir)
sys.path.insert(0, module_dir)
module = __import__("%s" % module_name) #, {}, {}, [module_name])
sys.path = sys.path[1:]
......@@ -842,34 +842,15 @@ iinv_elemwise = inv_elemwise.inplace_version()
## Dot product ##
class dot(omega_op):
impl = numpy.dot
def grad(x, y, gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def specs(x, y):
# todo: handle all tensors!
shape = (x[2][0], y[2][1])
return (numpy.ndarray, upcast(x[1], y[1]), shape)
def c_headers(self):
return ["<gsl/gsl_cblas.h>"]
def c_libs(self):
return ["cblas", "atlas", "g2c"]
def c_impl((_x, _y), (_z, )):
dtype = _x.spec[1]
if dtype.char == 'f':
gemm = 'cblas_sgemm'
elif dtype.char == 'd':
gemm = 'cblas_dgemm'
else:
raise NotImplementedError
class blas_code :
@staticmethod
def gemm_xyz(a_init, b_init):
mod = '%'
return """
%(dtype)s a = 1.0;
%(dtype)s b = 0.0;
const char * error_string = NULL;
%(dtype)s* x = (%(dtype)s*)PyArray_DATA(_x);
%(dtype)s* y = (%(dtype)s*)PyArray_DATA(_y);
%(dtype)s* z = (%(dtype)s*)PyArray_DATA(_z);
int type_num = _x->descr->type_num;
int type_size = _x->descr->elsize; // in bytes
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
......@@ -878,58 +859,170 @@ class dot(omega_op):
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
int unit = 0;
if (_x->nd != 2) goto _dot_execute_fallback;
if (_y->nd != 2) goto _dot_execute_fallback;
if (_z->nd != 2) goto _dot_execute_fallback;
if ((_x->descr->type_num != PyArray_DOUBLE)
&& (_x->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_y->descr->type_num != PyArray_DOUBLE)
&& (_y->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_x->descr->type_num != _y->descr->type_num)
||(_x->descr->type_num != _z->descr->type_num))
goto _dot_execute_fallback;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead (1)\\n");
return 1;
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
}
if ((Sx[0] < 1) || (Sx[1] < 1)
|| (Sy[0] < 1) || (Sy[1] < 1)
|| (Sz[0] < 1) || (Sz[1] < 1))
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size))
{
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead (2)\\n");
return 1;
//return mat_gemm_general(a, A, B, b, C);
goto _dot_execute_fallback;
}
//TODO: OPTIMIZE for many special cases:
//- gemv
//- ger
//- ddot
//- others?
int unit = 0;
unit |= ((Sx[1] == sizeof(%(dtype)s)) ? 0x0 : (Sx[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == sizeof(%(dtype)s)) ? 0x0 : (Sy[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == sizeof(%(dtype)s)) ? 0x0 : (Sz[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 8;
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
size_t sx_0 = (Nx[0] > 1) ? Sx[0]/sizeof(%(dtype)s) : Nx[1];
size_t sx_1 = (Nx[1] > 1) ? Sx[1]/sizeof(%(dtype)s) : Nx[0];
size_t sy_0 = (Ny[0] > 1) ? Sy[0]/sizeof(%(dtype)s) : Ny[1];
size_t sy_1 = (Ny[1] > 1) ? Sy[1]/sizeof(%(dtype)s) : Ny[0];
size_t sz_0 = (Nz[0] > 1) ? Sz[0]/sizeof(%(dtype)s) : Nz[1];
size_t sz_1 = (Nz[1] > 1) ? Sz[1]/sizeof(%(dtype)s) : Nz[0];
switch(unit)
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case 0x000: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: %(gemm)s(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: %(gemm)s(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: %(gemm)s(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: %(gemm)s(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: %(gemm)s(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: %(gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: %(gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead (3)\\n");
return 1;
};
case PyArray_FLOAT:
{
float a = %(a_init)s;
float b = %(b_init)s;
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
}
break;
case PyArray_DOUBLE:
{
double a = %(a_init)s;
double b = %(b_init)s;
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
}
break;
}
return 0; //success!
_dot_execute_fallback:
PyErr_SetString(PyExc_NotImplementedError,
"dot->execute() fallback");
return -1;
_dot_execute_fail:
if (error_string == NULL)
PyErr_SetString(PyExc_ValueError,
"dot->execute() cant run on these inputs");
return -1;
/* v 1 */
""" % dict(dtype = '_x_dtype', gemm = gemm)
""" % locals()
# currently unused, preferring the fallback method (throwing
# NotImplementedError) for when gemm won't work.
_templated_memaligned_gemm = """
template <typename Ta, typename Tx, typename Ty, typename Tb, typename Tz>
int general_gemm(int zM, int zN, int xN,.
Ta a,
Tx * x, int xm, int xn,
Tx * y, int ym, int yn,
Tb b,
Tz * z, int zm, int zn)
{
for (int i = 0; i < zM; ++i)
{
for (int j = 0; j < zN; ++j)
{
Tz zij = 0.0;
for (int k = 0; k < xN; ++k)
{
zij += x[i*xm+k*xn] * y[k*ym+j*yn];
}
z[i * zm + j * zn] *= b;
z[i * zm + j * zn] += a * zij;
}
}
}
"""
class dot(omega_op):
def impl(x, y):
z = numpy.dot(x,y)
#print z.dtype
return z
def grad(x, y, gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def specs(x, y):
# todo: handle all tensors!
shape = (x[2][0], y[2][1])
return (numpy.ndarray, upcast(x[1], y[1]), shape)
def c_headers(self):
return ["<gsl/gsl_cblas.h>"]
def c_libs(self):
return ["cblas", "atlas", "g2c"]
def c_impl((_x, _y), (_z, )):
return blas_code.gemm_xyz('1.0', '0.0')
class gemm(omega_op, inplace):
......@@ -964,83 +1057,8 @@ class gemm(omega_op, inplace):
return ["<gsl/gsl_cblas.h>"]
def c_libs(self):
return ["cblas", "atlas", "g2c"]
def c_impl((_z, _a, _x, _y, _b), (_zout,)):
dtype = _x.spec[1]
if dtype.char == 'f':
cblas_gemm = 'cblas_sgemm'
elif dtype.char == 'd':
cblas_gemm = 'cblas_dgemm'
else:
raise NotImplementedError
return """
%(dtype)s a = ((%(dtype)s*)PyArray_DATA(_a))[0];
%(dtype)s b = ((%(dtype)s*)PyArray_DATA(_b))[0];
%(dtype)s* x = (%(dtype)s*)PyArray_DATA(_x);
%(dtype)s* y = (%(dtype)s*)PyArray_DATA(_y);
%(dtype)s* z = (%(dtype)s*)PyArray_DATA(_z);
npy_intp* Nx = _x->dimensions;
npy_intp* Ny = _y->dimensions;
npy_intp* Nz = _z->dimensions;
npy_intp* Sx = _x->strides;
npy_intp* Sy = _y->strides;
npy_intp* Sz = _z->strides;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
return 1;
}
if ((Sx[0] < 1) || (Sx[1] < 1)
|| (Sy[0] < 1) || (Sy[1] < 1)
|| (Sz[0] < 1) || (Sz[1] < 1))
{
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
return 1;
//return mat_gemm_general(a, A, B, b, C);
}
//TODO: OPTIMIZE for many special cases:
//- gemv
//- ger
//- ddot
//- others?
int unit = 0;
unit |= ((Sx[1] == sizeof(%(dtype)s)) ? 0x0 : (Sx[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == sizeof(%(dtype)s)) ? 0x0 : (Sy[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == sizeof(%(dtype)s)) ? 0x0 : (Sz[0] == sizeof(%(dtype)s)) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
size_t sx_0 = (Nx[0] > 1) ? Sx[0]/sizeof(%(dtype)s) : Nx[1];
size_t sx_1 = (Nx[1] > 1) ? Sx[1]/sizeof(%(dtype)s) : Nx[0];
size_t sy_0 = (Ny[0] > 1) ? Sy[0]/sizeof(%(dtype)s) : Ny[1];
size_t sy_1 = (Ny[1] > 1) ? Sy[1]/sizeof(%(dtype)s) : Ny[0];
size_t sz_0 = (Nz[0] > 1) ? Sz[0]/sizeof(%(dtype)s) : Nz[1];
size_t sz_1 = (Nz[1] > 1) ? Sz[1]/sizeof(%(dtype)s) : Nz[0];
switch(unit)
{
case 0x000: %(cblas_gemm)s(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: %(cblas_gemm)s(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: %(cblas_gemm)s(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: %(cblas_gemm)s(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: %(cblas_gemm)s(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: %(cblas_gemm)s(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: %(cblas_gemm)s(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: %(cblas_gemm)s(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default:
fprintf(stderr, "Should be calling mat_gemm_general, but quitting instead\\n");
return 1;
};
/* v 1 */
""" % dict(dtype = '_x_dtype', cblas_gemm = cblas_gemm)
def c_impl((_zin, _a, _x, _y, _b), (_z,)):
return blas_code.gemm_xyz(str(_a), str(_b))
## Transposition ##
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论