提交 c539a57b authored 作者: Tim Cooijmans's avatar Tim Cooijmans

BatchedDot: move to theano/tensor/blas.py

上级 0fde9a49
......@@ -3372,445 +3372,6 @@ def transpose(x, axes=None):
return ret
class BatchedDot(Op):
"""
Computes the batched dot product of two variables:
batched_dot(a, b)[i] = dot(a[i], b[i])
"""
__props__ = ()
def make_node(self, *inputs):
inputs = list(map(as_tensor_variable, inputs))
if len(inputs) != 2:
raise TypeError(
'theano.tensor.BatchedDot: 2 arguments required, %d given ' %
len(inputs))
if inputs[0].ndim not in (2, 3):
raise TypeError(
'theano.tensor.BatchedDot: input 0 (0-indexed) must have ndim'
' of 2 or 3, %d given. Consider calling '
'theano.tensor.batched_dot instead.' % inputs[0].ndim)
if inputs[1].ndim not in (2, 3):
raise TypeError(
'theano.tensor.BatchedDot: input 1 (0-indexed) must have ndim'
'of 2 or 3, %d given. Consider calling '
'theano.tensor.batched_dot instead.' % inputs[1].ndim)
dtype = scal.upcast(*[input.type.dtype for input in inputs])
# upcast inputs to common dtype if needed
upcasted_inputs = [cast(input, dtype) for input in inputs]
broadcastable = ((inputs[0].type.broadcastable[0] or
inputs[1].type.broadcastable[0],) +
inputs[0].type.broadcastable[1:-1] +
inputs[1].type.broadcastable[2:])
return Apply(self, upcasted_inputs, [tensor(dtype, broadcastable)])
def perform(self, node, inp, out):
x, y = inp
z, = out
if x.shape[0] != y.shape[0]:
raise TypeError(
'theano.tensor.BatchedDot: inputs [%s] must have the same size'
'in axis 0, but have sizes [%s].'
% (", ".join(map(str, inp)),
", ".join([str(i.shape[0]) for i in inp])))
shape = self.infer_shape(node, [i.shape for i in inp])[0]
dtype = node.outputs[0].dtype
z0 = z[0] = numpy.empty(shape, dtype=dtype)
for i in xrange(z0.shape[0]):
z0[i] = numpy.dot(x[i], y[i])
def c_support_code(self):
from theano.tensor.blas_headers import blas_header_text
batch_gemm_defn = """
template<typename dtype, typename function>
bool batch_gemm(function gemm, int type_size,
PyArrayObject* xs, PyArrayObject* ys, PyArrayObject* zs) {
npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);
if (Nx[0] != Ny[0]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: batch sizes unequal."
" x.shape is (%d, %d, %d),"
" y.shape is (%d, %d, %d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
return 1;
}
if (Nx[2] != Ny[1]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: summation axis sizes unequal."
" x.shape is (%d, %d, %d),"
" y.shape is (%d, %d, %d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
return 1;
}
/* encode the stride structure of _x,_y,_z into a single integer. */
int unit = 0;
unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);
dtype* x = (dtype*)PyArray_DATA(xs);
dtype* y = (dtype*)PyArray_DATA(ys);
dtype* z = (dtype*)PyArray_DATA(zs);
dtype a = 1.0;
dtype b = 0.0;
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
switch(unit)
{
case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1;
};
x += Sx[0] / type_size;
y += Sy[0] / type_size;
z += Sz[0] / type_size;
}
return 0;
}
"""
return blas_header_text() + batch_gemm_defn
def c_libraries(self):
from theano.tensor.blas import ldflags
return ldflags()
def c_compile_args(self):
from theano.tensor.blas import ldflags
return ldflags(libs=False, flags=True)
def c_lib_dirs(self):
from theano.tensor.blas import ldflags
return ldflags(libs=False, libs_dir=True)
def c_header_dirs(self):
from theano.tensor.blas import ldflags
return ldflags(libs=False, include_dir=True)
def c_code_cleanup(self, node, name, inputs, outputs, sub):
return """
// clean up views
Py_XDECREF(xs); xs = 0;
Py_XDECREF(ys); ys = 0;
Py_XDECREF(zs); zs = 0;
"""
def c_code(self, node, name, inp, out, sub):
_x, _y = inp
_z, = out
fail = sub["fail"]
# generate contiguity condition
def contiguous(var, ndim):
strides = "PyArray_STRIDES(%s)" % var
return " && ".join([
" && ".join("{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " || ".join("{strides}[{i}] == type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
# generate code to allocate output based on runtime input shapes
z_dims = ["PyArray_DIMS(%s)[0]" % _x]
if x_ndim == 3:
z_dims.append("PyArray_DIMS(%s)[1]" % _x)
if y_ndim == 3:
z_dims.append("PyArray_DIMS(%s)[2]" % _y)
assert len(z_dims) == z_ndim
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim)
allocate = """
if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s))
{
npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
Py_XDECREF(%(_z)s);
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)s, dims, PyArray_TYPE(%(_x)s));
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc BatchedDot output");
%(fail)s
}
}
""" % locals()
# code to reallocate inputs contiguously if necessary
contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
_contiguous = contiguous(var, ndim)
contiguate.append("""
if (!(%(_contiguous)s)) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
if (!_copy)
%(fail)s
Py_XDECREF(%(var)s);
%(var)s = _copy;
}
""" % locals())
contiguate = "\n".join(contiguate)
def c_dimshuffle(newname, oldname, shape):
_fail = fail
_shape = ", ".join("1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
for axis in shape)
return """{
npy_intp dims[3] = {%(_shape)s};
PyArray_Dims newshape = {dims, 3};
%(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
if (!%(newname)s)
%(_fail)s
// make sure we didn't accidentally copy
assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
}""" % locals()
# create tensor3 views for any of x, y, z that are not tensor3, so that
# we only need to implement the tensor3-tensor3 batched dot product.
# xs, ys and zs will point to these views, or to the original array if
# it was already tensor3.
# in the latter case, we artificially increase the reference count of
# the original array so that the c_code_cleanup method can decref them
# all indiscriminately.
upcast = []
if x_ndim == 3:
upcast.append("xs = %(_x)s; Py_XINCREF(xs);")
elif x_ndim == 2:
upcast.append(c_dimshuffle("xs", _x, (0, None, 1)))
if y_ndim == 3:
upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
elif y_ndim == 2:
upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
if z_ndim == 3:
upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
else:
upcast.append(c_dimshuffle(
"zs", _z, (0,
None if x_ndim == 2 else 1,
None if y_ndim == 2 else 1)))
upcast = "\n".join(upcast) % locals()
return """
int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
// xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
PyArrayObject *xs = 0, *ys = 0, *zs = 0;
if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
PyErr_Format(PyExc_NotImplementedError,
"rank(x) != %(x_ndim)s. rank(x) is %%d.",
PyArray_NDIM(%(_x)s));
%(fail)s;
}
if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
PyErr_Format(PyExc_NotImplementedError,
"rank(y) != %(y_ndim)s. rank(y) is %%d.",
PyArray_NDIM(%(_y)s));
%(fail)s;
}
if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
PyErr_Format(PyExc_NotImplementedError,
"rank(z) != %(z_ndim)s. rank(z) is %%d.",
PyArray_NDIM(%(_z)s));
%(fail)s;
}
// allocate output
%(allocate)s
// reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)s
// add dims to make sure everything is tensor3
%(upcast)s
// from here on, use xs, ys and zs as they are tensor3 and share memory
// with the original %(_x)s, %(_y)s and %(_z)s arrays.
if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num)
||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
switch (type_num)
{
case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
%(fail)s;
}
break;
case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
%(fail)s;
}
break;
}
""" % locals()
def c_code_cache_version(self):
return None
def grad(self, inp, grads):
x, y = inp
gz, = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
# grad is a vector, so x is a matrix and y is a matrix
if gdim == 1:
xgrad = gz.dimshuffle(0, 'x') * y
ygrad = gz.dimshuffle(0, 'x') * x
# x is a matrix, y is a tensor3, grad is a matrix
elif xdim == 2 and ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1)
# x is a tensor3, y is a matrix, grad is a matrix
elif xdim == 3 and ydim == 2:
xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1)
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
# x is a tensor3, y is a tensor3, grad is a tensor3
elif xdim == ydim == 3:
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = patternbroadcast(xgrad, x.broadcastable)
if ygrad.broadcastable != y.broadcastable:
ygrad = patternbroadcast(ygrad, y.broadcastable)
return xgrad, ygrad
def R_op(self, inputs, eval_points):
# R_op for batched_dot(a, b) evaluted at c for a and d for b is
# simply batched_dot(c, b) + batched_dot(a, d)
assert len(inputs) == 2
assert len(eval_points) == 2
if eval_points[0] is None and eval_points[1] is None:
return [None]
debugger_available = config.compute_test_value != 'off'
if debugger_available:
try:
iv0 = gof.op.get_test_value(inputs[0])
except AttributeError:
gof.op.missing_test_message(
'first input passed to BatchedDot.R_op has no test value')
debugger_available = False
try:
iv1 = gof.op.get_test_value(inputs[1])
except AttributeError:
gof.op.missing_test_message(
'second input passed to BatchedDot.R_op has no test value')
debugger_available = False
if eval_points[0]:
try:
ev0 = gof.op.get_test_value(eval_points[0])
except AttributeError:
gof.op.missing_test_message(
'first eval point passed to BatchedDot.R_op '
'has no test value')
debugger_available = False
if eval_points[1]:
try:
ev1 = gof.op.get_test_value(eval_points[1])
except AttributeError:
gof.op.missing_test_message(
'second eval point passed to BatchedDot.R_op '
'has no test value')
debugger_available = False
if debugger_available:
input_values = [iv0, iv1]
eval_point_values = [ev0, ev1]
for i in xrange(2):
if eval_point_values[i] is not None and \
input_values[i].shape != eval_point_values[i].shape:
raise ValueError(
'input ' + str(i) + ' and eval_point ' + str(i) +
' to BatchedDot.R_op should have the same shape, but '
'their shapes are %s and %s, respectively' % (
str(input_values[i].shape),
str(eval_point_values[i].shape)))
if eval_points[0]:
t1 = self(eval_points[0], inputs[1])
if eval_points[1]:
t2 = self(inputs[0], eval_points[1])
if eval_points[0] and eval_points[1]:
return [t1 + t2]
elif eval_points[0]:
return [t1]
else:
return [t2]
def infer_shape(self, node, shapes):
for shape_ in shapes:
if len(shape_) not in (2, 3):
raise NotImplementedError()
xshp, yshp = shapes
return [xshp[:-1] + yshp[2:]]
def batched_dot(a, b):
"""
Compute the batched dot product of two variables:
......@@ -3846,6 +3407,8 @@ def batched_dot(a, b):
return batched_tensordot(
a, b, [[a.ndim - 1], [numpy.maximum(1, b.ndim - 2)]])
else:
# avoid circular import
from blas import BatchedDot
return BatchedDot()(a, b)
......
......@@ -2179,6 +2179,439 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run')
class BatchedDot(Op):
"""
Computes the batched dot product of two variables:
batched_dot(a, b)[i] = dot(a[i], b[i])
"""
__props__ = ()
def make_node(self, *inputs):
inputs = list(map(T.as_tensor_variable, inputs))
if len(inputs) != 2:
raise TypeError("theano.tensor.blas.BatchedDot: 2 arguments"
" required, %d given " % len(inputs))
if inputs[0].ndim not in (2, 3):
raise TypeError("theano.tensor.blas.BatchedDot: input 0 (0-indexed)"
" must have ndim of 2 or 3, %d given. Consider"
" calling theano.tensor.batched_dot instead."
% inputs[0].ndim)
if inputs[1].ndim not in (2, 3):
raise TypeError("theano.tensor.blas.BatchedDot: input 1 (0-indexed)"
" must have ndim of 2 or 3, %d given. Consider"
" calling theano.tensor.batched_dot instead."
% inputs[1].ndim)
dtype = theano.scalar.upcast(*[input.type.dtype for input in inputs])
# upcast inputs to common dtype if needed
upcasted_inputs = [T.cast(input, dtype) for input in inputs]
broadcastable = ((inputs[0].type.broadcastable[0] or
inputs[1].type.broadcastable[0],) +
inputs[0].type.broadcastable[1:-1] +
inputs[1].type.broadcastable[2:])
return Apply(self, upcasted_inputs, [T.tensor(dtype, broadcastable)])
def perform(self, node, inp, out):
x, y = inp
z, = out
if x.shape[0] != y.shape[0]:
raise TypeError(
"theano.tensor.blas.BatchedDot: inputs [%s] must have the"
" same size in axis 0, but have sizes [%s]." %
(", ".join(map(str, inp)),
", ".join([str(i.shape[0]) for i in inp])))
shape = self.infer_shape(node, [i.shape for i in inp])[0]
dtype = node.outputs[0].dtype
z0 = z[0] = numpy.empty(shape, dtype=dtype)
for i in xrange(z0.shape[0]):
z0[i] = numpy.dot(x[i], y[i])
def c_support_code(self):
batch_gemm_defn = """
template<typename dtype, typename function>
bool batch_gemm(function gemm, int type_size,
PyArrayObject* xs, PyArrayObject* ys, PyArrayObject* zs) {
npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);
if (Nx[0] != Ny[0]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: batch sizes unequal."
" x.shape is (%d, %d, %d),"
" y.shape is (%d, %d, %d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
return 1;
}
if (Nx[2] != Ny[1]) {
PyErr_Format(PyExc_ValueError,
"Shape mismatch: summation axis sizes unequal."
" x.shape is (%d, %d, %d),"
" y.shape is (%d, %d, %d).",
Nx[0], Nx[1], Nx[2],
Ny[0], Ny[1], Ny[2]);
return 1;
}
/* encode the stride structure of _x,_y,_z into a single integer. */
int unit = 0;
unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);
dtype* x = (dtype*)PyArray_DATA(xs);
dtype* y = (dtype*)PyArray_DATA(ys);
dtype* z = (dtype*)PyArray_DATA(zs);
dtype a = 1.0;
dtype b = 0.0;
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];
// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
switch(unit)
{
case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1;
};
x += Sx[0] / type_size;
y += Sy[0] / type_size;
z += Sz[0] / type_size;
}
return 0;
}
"""
return blas_header_text() + batch_gemm_defn
def c_libraries(self):
return ldflags()
def c_compile_args(self):
return ldflags(libs=False, flags=True)
def c_lib_dirs(self):
return ldflags(libs=False, libs_dir=True)
def c_header_dirs(self):
return ldflags(libs=False, include_dir=True)
def c_code_cleanup(self, node, name, inputs, outputs, sub):
return """
// clean up views
Py_XDECREF(xs); xs = 0;
Py_XDECREF(ys); ys = 0;
Py_XDECREF(zs); zs = 0;
"""
def c_code(self, node, name, inp, out, sub):
_x, _y = inp
_z, = out
fail = sub["fail"]
# generate contiguity condition
def contiguous(var, ndim):
strides = "PyArray_STRIDES(%s)" % var
return " && ".join([
" && ".join("{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
.format(strides=strides, i=i) for i in range(ndim)),
"(%s)" % " || ".join("{strides}[{i}] == type_size"
.format(strides=strides, i=i) for i in range(ndim)),
])
x_ndim, y_ndim, z_ndim = node.inputs[0].ndim, node.inputs[1].ndim, node.outputs[0].ndim
# generate code to allocate output based on runtime input shapes
z_dims = ["PyArray_DIMS(%s)[0]" % _x]
if x_ndim == 3:
z_dims.append("PyArray_DIMS(%s)[1]" % _x)
if y_ndim == 3:
z_dims.append("PyArray_DIMS(%s)[2]" % _y)
assert len(z_dims) == z_ndim
z_shape_correct = " && ".join("PyArray_DIMS(%s)[%i] == %s"
% (_z, i, dim) for i, dim in enumerate(z_dims))
z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim)
allocate = """
if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s))
{
npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
Py_XDECREF(%(_z)s);
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)s, dims, PyArray_TYPE(%(_x)s));
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError,
"failed to alloc BatchedDot output");
%(fail)s
}
}
""" % locals()
# code to reallocate inputs contiguously if necessary
contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
_contiguous = contiguous(var, ndim)
contiguate.append("""
if (!(%(_contiguous)s)) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
if (!_copy)
%(fail)s
Py_XDECREF(%(var)s);
%(var)s = _copy;
}
""" % locals())
contiguate = "\n".join(contiguate)
def c_dimshuffle(newname, oldname, shape):
_fail = fail
_shape = ", ".join("1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis)
for axis in shape)
return """{
npy_intp dims[3] = {%(_shape)s};
PyArray_Dims newshape = {dims, 3};
%(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER);
if (!%(newname)s)
%(_fail)s
// make sure we didn't accidentally copy
assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s));
}""" % locals()
# create tensor3 views for any of x, y, z that are not tensor3, so that
# we only need to implement the tensor3-tensor3 batched dot product.
# xs, ys and zs will point to these views, or to the original array if
# it was already tensor3.
# in the latter case, we artificially increase the reference count of
# the original array so that the c_code_cleanup method can decref them
# all indiscriminately.
upcast = []
if x_ndim == 3:
upcast.append("xs = %(_x)s; Py_XINCREF(xs);")
elif x_ndim == 2:
upcast.append(c_dimshuffle("xs", _x, (0, None, 1)))
if y_ndim == 3:
upcast.append("ys = %(_y)s; Py_XINCREF(ys);")
elif y_ndim == 2:
upcast.append(c_dimshuffle("ys", _y, (0, 1, None)))
if z_ndim == 3:
upcast.append("zs = %(_z)s; Py_XINCREF(zs);")
else:
upcast.append(c_dimshuffle(
"zs", _z, (0,
None if x_ndim == 2 else 1,
None if y_ndim == 2 else 1)))
upcast = "\n".join(upcast) % locals()
return """
int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
// xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s
PyArrayObject *xs = 0, *ys = 0, *zs = 0;
if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) {
PyErr_Format(PyExc_NotImplementedError,
"rank(x) != %(x_ndim)s. rank(x) is %%d.",
PyArray_NDIM(%(_x)s));
%(fail)s;
}
if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) {
PyErr_Format(PyExc_NotImplementedError,
"rank(y) != %(y_ndim)s. rank(y) is %%d.",
PyArray_NDIM(%(_y)s));
%(fail)s;
}
if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) {
PyErr_Format(PyExc_NotImplementedError,
"rank(z) != %(z_ndim)s. rank(z) is %%d.",
PyArray_NDIM(%(_z)s));
%(fail)s;
}
// allocate output
%(allocate)s
// reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)s
// add dims to make sure everything is tensor3
%(upcast)s
// from here on, use xs, ys and zs as they are tensor3 and share memory
// with the original %(_x)s, %(_y)s and %(_z)s arrays.
if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(xs)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(ys)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(zs)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num)
||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
switch (type_num)
{
case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size, xs, ys, zs)) {
%(fail)s;
}
break;
case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size, xs, ys, zs)) {
%(fail)s;
}
break;
}
""" % locals()
def c_code_cache_version(self):
return None
def grad(self, inp, grads):
x, y = inp
gz, = grads
xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
# grad is a vector, so x is a matrix and y is a matrix
if gdim == 1:
xgrad = gz.dimshuffle(0, 'x') * y
ygrad = gz.dimshuffle(0, 'x') * x
# x is a matrix, y is a tensor3, grad is a matrix
elif xdim == 2 and ydim == 3:
xgrad = T.batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = x.dimshuffle(0, 1, 'x') * gz.dimshuffle(0, 'x', 1)
# x is a tensor3, y is a matrix, grad is a matrix
elif xdim == 3 and ydim == 2:
xgrad = gz.dimshuffle(0, 1, 'x') * y.dimshuffle(0, 'x', 1)
ygrad = T.batched_dot(x.dimshuffle(0, 2, 1), gz)
# x is a tensor3, y is a tensor3, grad is a tensor3
elif xdim == ydim == 3:
xgrad = T.batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = T.batched_dot(x.dimshuffle(0, 2, 1), gz)
# If x or y contain broadcastable dimensions but only one of
# them know that a matching dimensions is broadcastable, the
# above code don't always return the right broadcast pattern.
# This cause problem down the road. See gh-1461.
if xgrad.broadcastable != x.broadcastable:
xgrad = T.patternbroadcast(xgrad, x.broadcastable)
if ygrad.broadcastable != y.broadcastable:
ygrad = T.patternbroadcast(ygrad, y.broadcastable)
return xgrad, ygrad
def R_op(self, inputs, eval_points):
# R_op for batched_dot(a, b) evaluted at c for a and d for b is
# simply batched_dot(c, b) + batched_dot(a, d)
assert len(inputs) == 2
assert len(eval_points) == 2
if eval_points[0] is None and eval_points[1] is None:
return [None]
debugger_available = config.compute_test_value != 'off'
if debugger_available:
try:
iv0 = theano.gof.op.get_test_value(inputs[0])
except AttributeError:
theano.gof.op.missing_test_message(
'first input passed to BatchedDot.R_op has no test value')
debugger_available = False
try:
iv1 = theano.gof.op.get_test_value(inputs[1])
except AttributeError:
theano.gof.op.missing_test_message(
'second input passed to BatchedDot.R_op has no test value')
debugger_available = False
if eval_points[0]:
try:
ev0 = theano.gof.op.get_test_value(eval_points[0])
except AttributeError:
theano.gof.op.missing_test_message(
'first eval point passed to BatchedDot.R_op '
'has no test value')
debugger_available = False
if eval_points[1]:
try:
ev1 = theano.gof.op.get_test_value(eval_points[1])
except AttributeError:
theano.gof.op.missing_test_message(
'second eval point passed to BatchedDot.R_op '
'has no test value')
debugger_available = False
if debugger_available:
input_values = [iv0, iv1]
eval_point_values = [ev0, ev1]
for i in xrange(2):
if eval_point_values[i] is not None and \
input_values[i].shape != eval_point_values[i].shape:
raise ValueError(
'input ' + str(i) + ' and eval_point ' + str(i) +
' to BatchedDot.R_op should have the same shape, but '
'their shapes are %s and %s, respectively' % (
str(input_values[i].shape),
str(eval_point_values[i].shape)))
if eval_points[0]:
t1 = self(eval_points[0], inputs[1])
if eval_points[1]:
t2 = self(inputs[0], eval_points[1])
if eval_points[0] and eval_points[1]:
return [t1 + t2]
elif eval_points[0]:
return [t1]
else:
return [t2]
def infer_shape(self, node, shapes):
for shape_ in shapes:
if len(shape_) not in (2, 3):
raise NotImplementedError()
xshp, yshp = shapes
return [xshp[:-1] + yshp[2:]]
# from opt import register_specialize, register_canonicalize
# @register_specialize
@local_optimizer([T.sub, T.add])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论