提交 5cb8526d authored 作者: James Bergstra's avatar James Bergstra

added _dot22 and some gemm optimizations

上级 7691290e
......@@ -10,10 +10,246 @@ import theano.sandbox
import theano.sandbox.wraplinker
from theano.compile import module, Mode
from theano.sandbox.wraplinker import ProfileMode
from theano import gof, Op, Apply
from theano.tensor import blas, opt
# numpy: aa_numpy.py
# c : aa.cc
class _Dot22(Op):
"""Compute matrix-matrix product.
This is a specialization of the more general Dot()
"""
def make_node(self, x, y):
assert x.type in T.float_matrix_types
assert y.type == x.type
bz = [x.type.broadcastable[0], y.type.broadcastable[1]]
outputs = [T.tensor(x.type.dtype, bz)]
return Apply(self, [x,y], outputs)
def perform(self, node, (x, y), (z, )):
try:
z[0] = numpy.asarray(numpy.dot(x, y))
except ValueError, e:
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def __str__(self):
return "_dot22"
def c_support_code(self):
#return blas.cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas.blas_proto() + mod_str
def c_headers(self):
return ['<iostream>']
def c_libraries(self):
return blas.ldflags()
def c_code(self, node, name, (_x, _y), (_z, ), sub):
return """
int unit = 0;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = 0; //%(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = 0;//%(_z)s->strides;
//strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if ((NULL == %(_z)s)
|| (%(_z)s->dimensions[0] != %(_x)s->dimensions[0])
|| (%(_z)s->dimensions[1] != %(_y)s->dimensions[1]))
{
if (NULL != %(_z)s) Py_XDECREF(%(_z)s);
npy_intp dims[2];
dims[0] = %(_x)s->dimensions[0];
dims[1] = %(_y)s->dimensions[1];
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, type_num_%(_x)s);
if(!%(_z)s) {
PyErr_SetString(PyExc_MemoryError, "failed to alloc dot22 output");
%(fail)s
}
}
Nz = %(_z)s->dimensions;
Sz = %(_z)s->strides;
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_z)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
float a = 1.0;
float b = 0.0;
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
case PyArray_DOUBLE:
{
double a = 1.0;
double b = 0.0;
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit)
{
case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
};
#undef REAL
}
break;
}
""" % dict(locals(), **sub)
_dot22 = _Dot22()
@gof.local_optimizer([T.dot])
def local_dot_to_dot22(node):
if node.op == T.dot:
return [_dot22(*node.inputs)]
else:
return False
T.opt.register_specialize(local_dot_to_dot22)
@gof.local_optimizer([T.sub])
def local_sub_to_gemm(node):
if node.op == T.sub:
subleft, subright = node.inputs
if subright.owner and (subright.owner.op == _dot22):
dotleft, dotright = subright.owner.inputs
return [T.gemm(subleft, -1.0, dotleft, dotright, 1.0)]
if subright.owner and (subright.owner.op == T.mul):
mulleft, mulright = subright.owner.inputs
#TODO: we actually want to get any scalar here, not necessrily a constant
mulleft_const = opt.local_mul_canonizer.get_constant(mulleft)
if mulleft_const is not None:
mulleft_const = mulleft_const.flatten()[0]
#subleft - (mulleft_const * ?)
if mulright.owner and (mulright.owner.op == T.add):
#subleft - (mulleft_const * (? + ?))
addleft, addright = mulright.owner.inputs
if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0]):
#subleft - (mulleft_const * (? + ?.T))
raise NotImplementedError()
if addright.owner and addright.owner.op == T.DimShuffle([False,False], [1,0], inplace=True):
#subleft - (mulleft_const * (? + ?.T))
transposed = addright.owner.inputs[0]
if transposed.owner and transposed.owner.op == _dot22:
x, y = transposed.owner.inputs
#subleft - (mulleft_const * (addleft + dot(x, y).T))
if addleft.owner and addleft.owner.op == _dot22:
u, v = addleft.owner.inputs
#subleft - (mulleft_const * (dot(u,v) + dot(x, y).T))
return [T.gemm(
T.gemm(subleft, -mulleft_const, y.T, x.T, 1.0),
-mulleft_const, u, v, 1.0)]
if mulright.owner and (mulright.owner.op == _dot22):
dotleft, dotright = mulright.owner.inputs
#TODO: we actually want to get any scalar here, not necessrily a constant
mulleft_const = opt.local_mul_canonizer.get_constant(mulleft)
if mulleft_const:
return [T.gemm(subleft, -mulleft_const, dotleft, dotright, 1.0)]
if mulleft.owner and (mulleft.owner.op == _dot22):
dotleft, dotright = mulleft.owner.inputs
#TODO: we actually want to get any scalar here, not necessrily a constant
mulright_const = opt.local_mul_canonizer.get_constant(mulright)
if mulright_const:
return [T.gemm(subleft, -mulright_const, dotleft, dotright, 1.0)]
return False
T.opt.register_specialize(local_sub_to_gemm)
if 0:
class Opt(object):
merge = theano.gof.MergeOptimizer()
......
......@@ -1089,45 +1089,11 @@ pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
# View Operations
##########################
if 0:
class _TransposeInplace(Op):
view_map = {0: [0]}
def make_node(self, input):
return Apply(self, [input], [tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def grad(self, (x,), (gz,)):
return transpose(gz),
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "TransposeView"
_transpose_inplace = _TransposeInplace()
def _old_transpose(x, **kwargs):
"""WRITEME"""
return _transpose_inplace(tensor_copy(x), **kwargs)
def transpose(x, **kwargs):
dims = range(x.ndim-1, -1, -1)
return DimShuffle(x.broadcastable, dims, inplace=True)(tensor_copy(x))
class Subtensor(Op):
"""Return a subtensor view
......@@ -1786,6 +1752,7 @@ class Dot(Op):
# The error raised by numpy has no shape information, we mean to add that
e.args = e.args + (x.shape, y.shape)
raise
def grad(self, (x, y), (gz,)):
if gz.type.ndim == 0:
return gz * y, gz * x
......
......@@ -625,6 +625,37 @@ def local_pow_specialize(node):
return False
register_specialize(local_pow_specialize)
if 0: #TODO: replace this with a c version of any InplaceDimShuffle
class _TransposeInplace(T.Op):
view_map = {0: [0]}
def make_node(self, input):
return T.Apply(self, [input],
[T.tensor(dtype = input.type.dtype,
broadcastable = reversed(input.type.broadcastable))])
def perform(self, node, (x, ), (z, )):
z[0] = x.T
def c_code(self, node, name, (x, ), (z, ), sub):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
%(z)s = transposed;
""" % locals()
def __str__(self):
return "_TransposeInplace"
_transpose_inplace = _TransposeInplace()
@gof.local_optimizer([T.DimShuffle([False,False],[1,0],inplace=True)])
def local_dimshuffle_transposeinplace(node):
if node.op == T.DimShuffle([False,False],[1,0],inplace=True):
return [_transpose_inplace(node.inputs[0])]
return False
register_specialize(local_dimshuffle_transposeinplace)
register_canonicalize(local_mul_canonizer, name = 'local_mul_canonizer')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论