提交 4ce8a480 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: first stab at C implementation

上级 cccef96c
......@@ -3420,6 +3420,253 @@ class BatchedDot(Op):
for i in xrange(z[0].shape[0]):
z[0][i] = numpy.dot(x[i], y[i])
def c_support_code(self):
from theano.tensor.blas_headers import blas_header_text
return blas_header_text()
def c_libraries(self):
from theano.tensor.blas import ldflags
return ldflags()
def c_compile_args(self):
from theano.tensor.blas import ldflags
return ldflags(libs=False, flags=True)
def c_lib_dirs(self):
from theano.tensor.blas import ldflags
return ldflags(libs=False, libs_dir=True)
def c_header_dirs(self):
from theano.tensor.blas import ldflags
return ldflags(libs=False, include_dir=True)
def c_code(self, node, name, inp, out, sub):
_x, _y = inp
_zout, = out
fail = sub["fail"]
return """
int unit = 0;
int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
npy_intp* Nx = PyArray_DIMS(%(_x)s);
npy_intp* Ny = PyArray_DIMS(%(_y)s);
npy_intp* Nz = 0;
npy_intp* Sx = PyArray_STRIDES(%(_x)s);
npy_intp* Sy = PyArray_STRIDES(%(_y)s);
npy_intp* Sz = 0;
// strides for x, y, z in dimensions 1, 2
int sx_1 = 0, sx_2 = 0, sy_1 = 0, sy_2 = 0, sz_1 = 0, sz_2 = 0;
if (PyArray_NDIM(%(_x)s) != 3) {
PyErr_Format(PyExc_NotImplementedError,
"rank(x) != 3. rank(x) is %%d.", PyArray_NDIM(%(_x)s));
%(fail)s;
}
if (PyArray_NDIM(%(_y)s) != 3) {
PyErr_Format(PyExc_NotImplementedError,
"rank(y) != 3. rank(y) is %%d.", PyArray_NDIM(%(_y)s));
%(fail)s;
}
if (%(_zout)s && PyArray_NDIM(%(_zout)s) != 3) {
PyErr_Format(PyExc_NotImplementedError,
"rank(z) != 3. rank(z) is %%d.", PyArray_NDIM(%(_zout)s));
%(fail)s;
}
if (Nx[0] != Ny[0]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: batch sizes unequal."
" x.shape is (%%d, %%d, %%d),"
" y.shape is (%%d, %%d, %%d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
%(fail)s;
}
if (Nx[2] != Ny[1])
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch: summation axis sizes unequal."
" x.shape is (%%d, %%d, %%d),"
" y.shape is (%%d, %%d, %%d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
%(fail)s;
}
if ((NULL == %(_zout)s)
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_x)s)[0])
|| (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_x)s)[1])
|| (PyArray_DIMS(%(_zout)s)[2] != PyArray_DIMS(%(_y)s)[2]))
{
npy_intp dims[3] = {
PyArray_DIMS(%(_x)s)[0],
PyArray_DIMS(%(_x)s)[1],
PyArray_DIMS(%(_y)s)[2],
};
if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(3, Nz,
PyArray_TYPE(%(_x)s));
//fprintf(stderr, "BatchedDot Allocating %%i %%i %%i\\n", Nz[0], Nz[1], Nz[2]);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc batched_dot22 output");
%(fail)s
}
}
Nz = PyArray_DIMS(%(_zout)s);
Sz = PyArray_STRIDES(%(_zout)s);
if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(%(_zout)s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_zout)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_zout)s)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
/*
If some matrices are not contiguous on either dimensions,
or have invalid strides, copy their content into a contiguous one
*/
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[2] < 2) ||
(Sx[0] %% type_size) || (Sx[1] %% type_size) || (Sx[2] %% type_size) ||
((Sx[0] != type_size) && (Sx[1] != type_size) && (Sx[2] != type_size)))
{
PyArrayObject * _x_copy = (PyArrayObject *) PyArray_Copy(%(_x)s);
if (!_x_copy)
%(fail)s
Py_XDECREF(%(_x)s);
%(_x)s = _x_copy;
Sx = PyArray_STRIDES(%(_x)s);
}
if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[2] < 2) ||
(Sy[0] %% type_size) || (Sy[1] %% type_size) || (Sy[2] %% type_size) ||
((Sy[0] != type_size) && (Sy[1] != type_size) && (Sy[2] != type_size)))
{
PyArrayObject * _y_copy = (PyArrayObject *) PyArray_Copy(%(_y)s);
if (!_y_copy)
%(fail)s
Py_XDECREF(%(_y)s);
%(_y)s = _y_copy;
Sy = PyArray_STRIDES(%(_y)s);
}
if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[2] < 2) ||
(Sz[0] %% type_size) || (Sz[1] %% type_size) || (Sz[2] %% type_size) ||
((Sz[0] != type_size) && (Sz[1] != type_size) && (Sz[2] != type_size)))
{
PyArrayObject * _z_copy = (PyArrayObject *) PyArray_Copy(%(_zout)s);
if (!_z_copy)
%(fail)s
Py_XDECREF(%(_zout)s);
%(_zout)s = _z_copy;
Sz = PyArray_STRIDES(%(_zout)s);
}
/*
encode the stride structure of _x,_y,_zout into a single integer.
Note we don't care about axis 0 since we loop over it outside the gemm call.
*/
unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);
switch (type_num)
{
case NPY_FLOAT:
{
float a = 1.0;
float b = 0.0;
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_zout)s);
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: sgemm_(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: sgemm_(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: sgemm_(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: sgemm_(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: sgemm_(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: sgemm_(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: sgemm_(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
x += Sz[0]; y += Sy[0]; z += Sz[0];
}
}
break;
case NPY_DOUBLE:
{
double a = 1.0;
double b = 0.0;
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_zout)s);
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: dgemm_(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: dgemm_(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: dgemm_(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: dgemm_(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: dgemm_(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: dgemm_(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: dgemm_(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError,
"some matrix has no unit stride");
%(fail)s;
};
x += Sz[0]; y += Sy[0]; z += Sz[0];
}
}
break;
}
""" % locals()
def c_code_cache_version(self):
return None
def grad(self, inp, grads):
x, y = inp
gz, = grads
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论