提交 9a3f6681 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: generalize matrix-matrix code to matrix-vector, vector-matrix, vector-vector products

上级 4ce8a480
...@@ -3440,43 +3440,152 @@ class BatchedDot(Op): ...@@ -3440,43 +3440,152 @@ class BatchedDot(Op):
from theano.tensor.blas import ldflags from theano.tensor.blas import ldflags
return ldflags(libs=False, include_dir=True) return ldflags(libs=False, include_dir=True)
def c_code_cleanup(self, node, name, inputs, outputs, sub):
return """
// clean up views
Py_XDECREF(xs); xs = 0;
Py_XDECREF(ys); ys = 0;
Py_XDECREF(zs); zs = 0;
"""
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
_x, _y = inp _x, _y = inp
_zout, = out _z, = out
fail = sub["fail"] fail = sub["fail"]
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
# generate code to allocate output based on runtime input shapes
z_dims = ["PyArray_DIMS(%s)[0]" % _x]
if x_ndim == 3: z_dims.append("PyArray_DIMS(%s)[1]" % _x)
if y_ndim == 3: z_dims.append("PyArray_DIMS(%s)[2]" % _y)
assert len(z_dims) == z_ndim
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims)
allocate = """
if ((NULL == %(_z)s) || !(%(z_shape_correct)s))
{
npy_intp dims[3] = {%(z_shape)s};
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)s, dims, PyArray_TYPE(%(_x)s));
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc BatchedDot output");
%(fail)s
}
}
""" % locals()
# generate code to reallocate inputs/output contiguously if necessary
contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim), (_z, z_ndim)]:
strides = "PyArray_STRIDES(%s)" % var
not_contiguous = " || ".join([
" || ".join("{strides}[{i}] < 1 || {strides}[{i}] % type_size"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " && ".join("{strides}[{i}] != type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
contiguate.append("""
if (%(not_contiguous)s) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
if (!_copy)
%(fail)s
Py_XDECREF(%(var)s);
%(var)s = _copy;
}
""" % locals())
contiguate = "\n".join(contiguate)
def c_dimshuffle(newname, oldname, shape):
_fail = fail
_shape = ", ".join("1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
for axis in shape)
return """{
npy_intp dims[3] = {%(_shape)s};
PyArray_Dims newshape = {.ptr = dims, .len = 3};
%(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_KEEPORDER);
if (!%(newname)s)
%(_fail)s
// make sure we didn't accidentally copy
assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
}""" % locals()
# create tensor3 views for any of x, y, z that are not tensor3, so that
# we only need to implement the tensor3-tensor3 batched dot product.
# xs, ys and zs will point to these views, or to the original array if
# it was already tensor3.
# in the latter case, we artificially increase the reference count of
# the original array so that the c_code_cleanup method can decref them
# all indiscriminately.
upcast = []
if x_ndim == 3:
upcast.append("xs = %(_x)s; Py_XINCREF(xs);")
elif x_ndim == 2:
upcast.append(c_dimshuffle("xs", _x, (0, None, 1)))
if y_ndim == 3:
upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
elif y_ndim == 2:
upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
# upcast of z depends on shapes of both inputs
if x_ndim == 3 and y_ndim == 3:
upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
else:
upcast.append(c_dimshuffle(
"zs", _z, (0,
None if x_ndim == 2 else 1,
None if y_ndim == 2 else 1)))
upcast = "\n".join(upcast) % locals()
return """ return """
int unit = 0; int unit = 0;
int type_num = PyArray_DESCR(%(_x)s)->type_num; int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
npy_intp* Nx = PyArray_DIMS(%(_x)s); // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
npy_intp* Ny = PyArray_DIMS(%(_y)s); PyArrayObject *xs = 0, *ys = 0, *zs = 0;
npy_intp* Nz = 0; npy_intp *Nx = 0, *Ny = 0, *Nz = 0;
npy_intp *Sx = 0, *Sy = 0, *Sz = 0;
npy_intp* Sx = PyArray_STRIDES(%(_x)s);
npy_intp* Sy = PyArray_STRIDES(%(_y)s);
npy_intp* Sz = 0;
// strides for x, y, z in dimensions 1, 2 // strides for x, y, z in dimensions 1, 2
int sx_1 = 0, sx_2 = 0, sy_1 = 0, sy_2 = 0, sz_1 = 0, sz_2 = 0; int sx_1 = 0, sx_2 = 0, sy_1 = 0, sy_2 = 0, sz_1 = 0, sz_2 = 0;
if (PyArray_NDIM(%(_x)s) != 3) { if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"rank(x) != 3. rank(x) is %%d.", PyArray_NDIM(%(_x)s)); "rank(x) != %(x_ndim)s. rank(x) is %%d.",
PyArray_NDIM(%(_x)s));
%(fail)s; %(fail)s;
} }
if (PyArray_NDIM(%(_y)s) != 3) { if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"rank(y) != 3. rank(y) is %%d.", PyArray_NDIM(%(_y)s)); "rank(y) != %(y_ndim)s. rank(y) is %%d.",
PyArray_NDIM(%(_y)s));
%(fail)s; %(fail)s;
} }
if (%(_zout)s && PyArray_NDIM(%(_zout)s) != 3) { if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"rank(z) != 3. rank(z) is %%d.", PyArray_NDIM(%(_zout)s)); "rank(z) != %(z_ndim)s. rank(z) is %%d.",
PyArray_NDIM(%(_z)s));
%(fail)s; %(fail)s;
} }
// allocate output
%(allocate)s
// reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)s
// add dims to make sure everything is tensor3
%(upcast)s
// from here on, use xs, ys and zs as they are tensor3 and share memory
// with the original %(_x)s, %(_y)s and %(_z)s arrays.
Nx = PyArray_DIMS(xs); Sx = PyArray_STRIDES(xs);
Ny = PyArray_DIMS(ys); Sy = PyArray_STRIDES(ys);
Nz = PyArray_DIMS(zs); Sz = PyArray_STRIDES(zs);
if (Nx[0] != Ny[0]) { if (Nx[0] != Ny[0]) {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Shape mismatch: batch sizes unequal." "Shape mismatch: batch sizes unequal."
...@@ -3486,8 +3595,8 @@ class BatchedDot(Op): ...@@ -3486,8 +3595,8 @@ class BatchedDot(Op):
Ny[0], Ny[1], Ny[2]); Ny[0], Ny[1], Ny[2]);
%(fail)s; %(fail)s;
} }
if (Nx[2] != Ny[1])
{ if (Nx[2] != Ny[1]) {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Shape mismatch: summation axis sizes unequal." "Shape mismatch: summation axis sizes unequal."
" x.shape is (%%d, %%d, %%d)," " x.shape is (%%d, %%d, %%d),"
...@@ -3497,91 +3606,23 @@ class BatchedDot(Op): ...@@ -3497,91 +3606,23 @@ class BatchedDot(Op):
%(fail)s; %(fail)s;
} }
if ((NULL == %(_zout)s) if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
|| (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_x)s)[0]) && (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
|| (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_x)s)[1])
|| (PyArray_DIMS(%(_zout)s)[2] != PyArray_DIMS(%(_y)s)[2]))
{
npy_intp dims[3] = {
PyArray_DIMS(%(_x)s)[0],
PyArray_DIMS(%(_x)s)[1],
PyArray_DIMS(%(_y)s)[2],
};
if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s);
%(_zout)s = (PyArrayObject*)PyArray_SimpleNew(3, Nz,
PyArray_TYPE(%(_x)s));
//fprintf(stderr, "BatchedDot Allocating %%i %%i %%i\\n", Nz[0], Nz[1], Nz[2]);
if(!%(_zout)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc batched_dot22 output");
%(fail)s
}
}
Nz = PyArray_DIMS(%(_zout)s);
Sz = PyArray_STRIDES(%(_zout)s);
if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT)) && (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(%(_zout)s)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_zout)s)->type_num != NPY_FLOAT)) && (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num) if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(%(_y)s)->type_num)
||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_zout)s)->type_num)) ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
/* /* encode the stride structure of _x,_y,_z into a single integer. */
If some matrices are not contiguous on either dimensions,
or have invalid strides, copy their content into a contiguous one
*/
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[2] < 2) ||
(Sx[0] %% type_size) || (Sx[1] %% type_size) || (Sx[2] %% type_size) ||
((Sx[0] != type_size) && (Sx[1] != type_size) && (Sx[2] != type_size)))
{
PyArrayObject * _x_copy = (PyArrayObject *) PyArray_Copy(%(_x)s);
if (!_x_copy)
%(fail)s
Py_XDECREF(%(_x)s);
%(_x)s = _x_copy;
Sx = PyArray_STRIDES(%(_x)s);
}
if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[2] < 2) ||
(Sy[0] %% type_size) || (Sy[1] %% type_size) || (Sy[2] %% type_size) ||
((Sy[0] != type_size) && (Sy[1] != type_size) && (Sy[2] != type_size)))
{
PyArrayObject * _y_copy = (PyArrayObject *) PyArray_Copy(%(_y)s);
if (!_y_copy)
%(fail)s
Py_XDECREF(%(_y)s);
%(_y)s = _y_copy;
Sy = PyArray_STRIDES(%(_y)s);
}
if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[2] < 2) ||
(Sz[0] %% type_size) || (Sz[1] %% type_size) || (Sz[2] %% type_size) ||
((Sz[0] != type_size) && (Sz[1] != type_size) && (Sz[2] != type_size)))
{
PyArrayObject * _z_copy = (PyArrayObject *) PyArray_Copy(%(_zout)s);
if (!_z_copy)
%(fail)s
Py_XDECREF(%(_zout)s);
%(_zout)s = _z_copy;
Sz = PyArray_STRIDES(%(_zout)s);
}
/*
encode the stride structure of _x,_y,_zout into a single integer.
Note we don't care about axis 0 since we loop over it outside the gemm call.
*/
unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8; unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4; unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0; unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;
...@@ -3606,9 +3647,9 @@ class BatchedDot(Op): ...@@ -3606,9 +3647,9 @@ class BatchedDot(Op):
{ {
float a = 1.0; float a = 1.0;
float b = 0.0; float b = 0.0;
float* x = (float*)PyArray_DATA(%(_x)s); float* x = (float*)PyArray_DATA(xs);
float* y = (float*)PyArray_DATA(%(_y)s); float* y = (float*)PyArray_DATA(ys);
float* z = (float*)PyArray_DATA(%(_zout)s); float* z = (float*)PyArray_DATA(zs);
char N = 'N'; char N = 'N';
char T = 'T'; char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2]; int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
...@@ -3634,9 +3675,9 @@ class BatchedDot(Op): ...@@ -3634,9 +3675,9 @@ class BatchedDot(Op):
{ {
double a = 1.0; double a = 1.0;
double b = 0.0; double b = 0.0;
double* x = (double*)PyArray_DATA(%(_x)s); double* x = (double*)PyArray_DATA(xs);
double* y = (double*)PyArray_DATA(%(_y)s); double* y = (double*)PyArray_DATA(ys);
double* z = (double*)PyArray_DATA(%(_zout)s); double* z = (double*)PyArray_DATA(zs);
char N = 'N'; char N = 'N';
char T = 'T'; char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2]; int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论