提交 78bbb561 作者: Tim Cooijmans

BatchedDot: reduce duplication

上级 1ac274a9
...@@ -3422,7 +3422,85 @@ class BatchedDot(Op): ...@@ -3422,7 +3422,85 @@ class BatchedDot(Op):
def c_support_code(self): def c_support_code(self):
from theano.tensor.blas_headers import blas_header_text from theano.tensor.blas_headers import blas_header_text
return blas_header_text() batch_gemm_defn = """
template<typename dtype, typename function>
bool batch_gemm(function gemm, int type_size,
PyArrayObject* xs, PyArrayObject* ys, PyArrayObject* zs) {
npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);
if (Nx[0] != Ny[0]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: batch sizes unequal."
" x.shape is (%d, %d, %d),"
" y.shape is (%d, %d, %d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
return 1;
}
if (Nx[2] != Ny[1]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: summation axis sizes unequal."
" x.shape is (%d, %d, %d),"
" y.shape is (%d, %d, %d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
return 1;
}
/* encode the stride structure of _x,_y,_z into a single integer. */
int unit = 0;
unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);
dtype* x = (dtype*)PyArray_DATA(xs);
dtype* y = (dtype*)PyArray_DATA(ys);
dtype* z = (dtype*)PyArray_DATA(zs);
dtype a = 1.0;
dtype b = 0.0;
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
switch(unit)
{
case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1;
};
x += Sx[0]; y += Sy[0]; z += Sz[0];
}
return 0;
}
"""
return blas_header_text() + batch_gemm_defn
def c_libraries(self): def c_libraries(self):
from theano.tensor.blas import ldflags from theano.tensor.blas import ldflags
...@@ -3541,18 +3619,11 @@ class BatchedDot(Op): ...@@ -3541,18 +3619,11 @@ class BatchedDot(Op):
upcast = "\n".join(upcast) % locals() upcast = "\n".join(upcast) % locals()
return """ return """
int unit = 0;
int type_num = PyArray_DESCR(%(_x)s)->type_num; int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
// xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
PyArrayObject *xs = 0, *ys = 0, *zs = 0; PyArrayObject *xs = 0, *ys = 0, *zs = 0;
npy_intp *Nx = 0, *Ny = 0, *Nz = 0;
npy_intp *Sx = 0, *Sy = 0, *Sz = 0;
// strides for x, y, z in dimensions 1, 2
int sx_1 = 0, sx_2 = 0, sy_1 = 0, sy_2 = 0, sz_1 = 0, sz_2 = 0;
if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) { if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
...@@ -3582,30 +3653,6 @@ class BatchedDot(Op): ...@@ -3582,30 +3653,6 @@ class BatchedDot(Op):
// from here on, use xs, ys and zs as they are tensor3 and share memory // from here on, use xs, ys and zs as they are tensor3 and share memory
// with the original %(_x)s, %(_y)s and %(_z)s arrays. // with the original %(_x)s, %(_y)s and %(_z)s arrays.
Nx = PyArray_DIMS(xs); Sx = PyArray_STRIDES(xs);
Ny = PyArray_DIMS(ys); Sy = PyArray_STRIDES(ys);
Nz = PyArray_DIMS(zs); Sz = PyArray_STRIDES(zs);
if (Nx[0] != Ny[0]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: batch sizes unequal."
" x.shape is (%%d, %%d, %%d),"
" y.shape is (%%d, %%d, %%d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
%(fail)s;
}
if (Nx[2] != Ny[1]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: summation axis sizes unequal."
" x.shape is (%%d, %%d, %%d),"
" y.shape is (%%d, %%d, %%d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
%(fail)s;
}
if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(xs)->type_num != NPY_FLOAT)) && (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
...@@ -3622,84 +3669,16 @@ class BatchedDot(Op): ...@@ -3622,84 +3669,16 @@ class BatchedDot(Op):
||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num)) ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
/* encode the stride structure of _x,_y,_z into a single integer. */
unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);
switch (type_num) switch (type_num)
{ {
case NPY_FLOAT: case NPY_FLOAT:
{ if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
float a = 1.0; %(fail)s;
float b = 0.0;
float* x = (float*)PyArray_DATA(xs);
float* y = (float*)PyArray_DATA(ys);
float* z = (float*)PyArray_DATA(zs);
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: sgemm_(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: sgemm_(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: sgemm_(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: sgemm_(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: sgemm_(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: sgemm_(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: sgemm_(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
x += Sx[0]; y += Sy[0]; z += Sz[0];
}
} }
break; break;
case NPY_DOUBLE: case NPY_DOUBLE:
{ if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
double a = 1.0; %(fail)s;
double b = 0.0;
double* x = (double*)PyArray_DATA(xs);
double* y = (double*)PyArray_DATA(ys);
double* z = (double*)PyArray_DATA(zs);
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: dgemm_(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: dgemm_(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: dgemm_(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: dgemm_(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: dgemm_(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: dgemm_(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: dgemm_(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError,
"some matrix has no unit stride");
%(fail)s;
};
x += Sx[0]; y += Sy[0]; z += Sz[0];
}
} }
break; break;
} }
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论