gemm restricted to rank 2, implemented in C. added __hash__ to Op, Result for…

gemm restricted to rank 2, implemented in C. added __hash__ to Op, Result for consistent set orderings
上级 201629ac
...@@ -11,3 +11,4 @@ pull.sh ...@@ -11,3 +11,4 @@ pull.sh
*.pyc *.pyc
*.so *.so
*.sw? *.sw?
core.*
...@@ -627,11 +627,12 @@ class t_gemm(unittest.TestCase): ...@@ -627,11 +627,12 @@ class t_gemm(unittest.TestCase):
return numpy.random.rand(*args) return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b): def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b] z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
cz = z.copy() cz = z.copy()
tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b] tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)]) f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l)
new_z = f(z,a,x,y,b) new_z = f(z,a,x,y,b)
_z = self._gemm(cz, a, x, y, b) _z = self._gemm(cz, a, x, y, b)
...@@ -644,19 +645,35 @@ class t_gemm(unittest.TestCase): ...@@ -644,19 +645,35 @@ class t_gemm(unittest.TestCase):
else: else:
self.failIf(numpy.all(cz == z)) self.failIf(numpy.all(cz == z))
cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker)
#cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker)
cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker)
def test0a(self):
try:
g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self): self.cmp(1., 0., 1.0, 1.0, 1.0)
def test1(self): self.cmp(2., 0., 1.0, 1.0, 0.0)
def test2(self): def test2(self):
try: try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0) self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e: except ValueError, e:
self.failUnless(e[0] == Gemm.E_bcast) self.failUnless(e[0] == Gemm.E_rank)
return return
self.fail() self.fail()
def test3(self): self.cmp([2.], 1.,[3,2,1.], [[1],[2],[3.]], 1.0) def test4(self):
def test4(self): self.cmp(self.rand(3,4), 1.0, self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0, def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0) self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0, def test6(self): self.cmp(self.rand(3,4), 1.0,
......
...@@ -26,6 +26,8 @@ def cblas_header_text(): ...@@ -26,6 +26,8 @@ def cblas_header_text():
__BEGIN_DECLS __BEGIN_DECLS
#define MOD %
/* /*
* Enumerated and derived types * Enumerated and derived types
*/ */
......
...@@ -4,6 +4,7 @@ Contains the Op class, which is the base interface for all operations ...@@ -4,6 +4,7 @@ Contains the Op class, which is the base interface for all operations
compatible with gof's graph manipulation routines. compatible with gof's graph manipulation routines.
""" """
import utils
from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError
import graph import graph
...@@ -36,7 +37,7 @@ class Op(object): ...@@ -36,7 +37,7 @@ class Op(object):
not required) that it creates them. not required) that it creates them.
""" """
__slots__ = ['_inputs', '_outputs'] __slots__ = ['_inputs', '_outputs', '_hash_id']
_default_output_idx = 0 _default_output_idx = 0
...@@ -52,7 +53,29 @@ class Op(object): ...@@ -52,7 +53,29 @@ class Op(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
pass self._hash_id = utils.hashgen()
#
# Python stdlib compatibility
#
# Three-way comparison based purely on object identity: the result is
# cmp() applied to the two id() values, so the ordering says nothing about
# the contents of either object.
def __cmp__(self, other):
return cmp(id(self), id(other))
# Equality is identity: an instance is equal only to itself.
def __eq__(self, other):
return self is other #assuming this is faster, equiv to id(self) == id(other)
# Inequality is the exact negation of the identity test used by __eq__.
def __ne__(self, other):
return self is not other #assuming this is faster, equiv to id(self) != id(other)
def __hash__(self):
    """Return this instance's hash id, creating it on first use.

    The id is drawn from ``utils.hashgen()`` and cached in ``_hash_id``, so
    repeated calls on the same instance return the same value.
    """
    try:
        return self._hash_id
    except AttributeError:
        # No id has been stored on this instance yet: draw one and keep it.
        fresh = self._hash_id = utils.hashgen()
        return fresh
#
#
#
def get_input(self, i): def get_input(self, i):
return self._inputs[i] return self._inputs[i]
......
...@@ -5,6 +5,7 @@ value that is the input or the output of an Op. ...@@ -5,6 +5,7 @@ value that is the input or the output of an Op.
""" """
import copy import copy
import utils
from utils import AbstractFunctionError from utils import AbstractFunctionError
...@@ -52,14 +53,30 @@ class ResultBase(object): ...@@ -52,14 +53,30 @@ class ResultBase(object):
data_filter data_filter
""" """
__slots__ = ['_role', '_data', 'state', '_name'] __slots__ = ['_role', '_data', 'state', '_name', '_hash_id']
def __init__(self, role=None, name=None): def __init__(self, role=None, name=None):
self._role = role self._role = role
self._data = [None] self._data = [None]
self.state = Empty self.state = Empty
self.name = name self.name = name
self._hash_id = utils.hashgen()
#
# Python stdlib compatibility
#
# Three-way comparison based purely on object identity: the result is
# cmp() applied to the two id() values, so the ordering says nothing about
# the data held by either object.
def __cmp__(self, other):
return cmp(id(self), id(other))
# Equality is identity: an instance is equal only to itself.
def __eq__(self, other):
return self is other #assuming this is faster, equiv to id(self) == id(other)
# Inequality is the exact negation of the identity test used by __eq__.
def __ne__(self, other):
return self is not other #assuming this is faster, equiv to id(self) != id(other)
def __hash__(self):
    """Return this instance's hash id, creating it on first use.

    ``__init__`` normally stores an id drawn from ``utils.hashgen()`` in
    ``_hash_id``.  If that attribute is missing (for example because a
    subclass never ran this class's ``__init__``), an id is drawn lazily
    instead of letting ``hash()`` fail with ``AttributeError``.
    """
    if not hasattr(self, '_hash_id'):
        self._hash_id = utils.hashgen()
    return self._hash_id
# #
# role # role
......
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
import re import re
def hashgen():
    """Hand out the next integer from a module-wide counter.

    The running count is stored on the function object itself as
    ``hashgen.next``.  Each call increases it by one and returns the new
    value, so the first call returns 1.
    """
    issued = hashgen.next + 1
    hashgen.next = issued
    return issued


# Counter state; 0 means that no value has been handed out yet.
hashgen.next = 0
class OmegaError(Exception): pass class OmegaError(Exception): pass
......
...@@ -10,6 +10,7 @@ import gof.op ...@@ -10,6 +10,7 @@ import gof.op
from base_tensor import BaseTensor, BaseTensorOp from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise from elemwise import Elemwise
import blas # for gemm, dot
class Tensor(BaseTensor): class Tensor(BaseTensor):
...@@ -712,12 +713,17 @@ dot = gof.op.constructor(Dot) ...@@ -712,12 +713,17 @@ dot = gof.op.constructor(Dot)
class Gemm(_Op): class Gemm(_Op):
nin=5 nin=5
nout=1 nout=1
E_bcast = 'incompatible broadcastable flags' E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
def destroy_map(self): def destroy_map(self):
return {self.out:[self.inputs[0]]} return {self.out:[self.inputs[0]]}
def propagate_broadcastable(self, bz, ba, bx, by, bb): def propagate_broadcastable(self, bz, ba, bx, by, bb):
if len(bz) != len(Dot.broadcastable_rule(bx,by)): if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
raise ValueError(Gemm.E_bcast, bz, bx, by) if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
return [bz] return [bz]
def impl(self, z, a, x, y, b): def impl(self, z, a, x, y, b):
assert a.shape == () assert a.shape == ()
...@@ -746,26 +752,168 @@ class Gemm(_Op): ...@@ -746,26 +752,168 @@ class Gemm(_Op):
return z return z
def grad(self, (z, a, x, y, b), gz): def grad(self, (z, a, x, y, b), gz):
raise NotImplementedError() raise NotImplementedError()
if 0:
def c_support_code(self): def c_support_code(self):
return blas.cblas_header_text() return blas.cblas_header_text()
def c_libs(self): def c_libraries(self):
return blas.ldflags() return blas.ldflags()
def c_impl((_zin, _a, _x, _y, _b), (_z,)): def c_var_names(self):
check_ab = """ return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def c_validate_update(self):
    """Return the C template that makes the output alias the input ``z``.

    The template is filled in with the C names of ``_zout`` and ``_z``.
    When the output does not already point at ``z``, any previous output
    object is released, the output is pointed at ``z`` and a new reference
    is taken.  When the output already is ``z`` nothing is touched, so the
    reference count stays balanced (releasing unconditionally would drop
    one reference per call without ever re-acquiring it).
    """
    return """
    if (%(_zout)s != %(_z)s)
    {
        if (%(_zout)s)
        {
            Py_DECREF(%(_zout)s);
        }
        %(_zout)s = %(_z)s;
        Py_INCREF(%(_zout)s);
    }
    """
# Returns an empty string, i.e. this op supplies no C code for the cleanup
# step that pairs with c_validate_update.
def c_validate_update_cleanup(self):
return ""
def c_code(self):
    """Return the C template that runs gemm through CBLAS on rank-2 arrays.

    The template is filled in with the C names of ``_z``, ``_a``, ``_x``,
    ``_y``, ``_b`` and with the ``fail`` statement.  The generated code:

    * rejects operands unless x, y and z have rank 2, a and b are float or
      double, and x, y and z are float or double and all of the same type;
    * rejects shapes that do not agree and strides that are not positive
      multiples of the element size;
    * encodes which axis of x, y and z has unit stride into ``unit`` and
      uses it to choose the storage order and transpose flags passed to
      ``cblas_sgemm`` / ``cblas_dgemm``, which write the result into z.

    Every exit through ``fail`` first sets a Python exception, so a
    rejected input is always reported with a message.  The modulo operator
    is spelled ``MOD`` so that the template contains no bare percent sign.
    """
    return """
    int unit = 0;

    int type_num = %(_x)s->descr->type_num;
    int type_size = %(_x)s->descr->elsize; // in bytes

    npy_intp* Nx = %(_x)s->dimensions;
    npy_intp* Ny = %(_y)s->dimensions;
    npy_intp* Nz = %(_z)s->dimensions;

    npy_intp* Sx = %(_x)s->strides;
    npy_intp* Sy = %(_y)s->strides;
    npy_intp* Sz = %(_z)s->strides;

    size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;

    if (%(_x)s->nd != 2)
        {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
    if (%(_y)s->nd != 2)
        {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
    if (%(_z)s->nd != 2)
        {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}

    if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
        && (%(_a)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}

    if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
        && (%(_b)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}

    if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
        && (%(_x)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}

    if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
        && (%(_y)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}

    if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
        && (%(_z)s->descr->type_num != PyArray_FLOAT))
        {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}

    if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
        ||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
        {PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s;}

    if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
    {
        PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
        %(fail)s;
    }
    if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
       || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
       || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
    {
        PyErr_SetString(PyExc_ValueError, "gemm cant run on these inputs");
        %(fail)s;
    }

    /*
    encode the stride structure of _x,_y,_z into a single integer
    */
    unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
    unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
    unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;

    /* create appropriate strides for malformed matrices that are row or column
     * vectors
     */
    sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
    sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
    sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
    sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
    sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
    sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];

    switch (type_num)
    {
        case PyArray_FLOAT:
        {
            #define REAL float
            float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
            ? (REAL)(((float*)%(_a)s->data)[0])
            : (REAL)(((double*)%(_a)s->data)[0]);
            float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
            (REAL)(((float*)%(_b)s->data)[0])
            : (REAL)(((double*)%(_b)s->data)[0]);

            float* x = (float*)PyArray_DATA(%(_x)s);
            float* y = (float*)PyArray_DATA(%(_y)s);
            float* z = (float*)PyArray_DATA(%(_z)s);

            switch(unit)
            {
                case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
                case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
                case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
                case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
                case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
                case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
                case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
                case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
                default: PyErr_SetString(PyExc_ValueError, "gemm needs a unit stride in x, y and z"); %(fail)s;
            };
            #undef REAL
        }
        break;
        case PyArray_DOUBLE:
        {
            #define REAL double
            double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
            ? (REAL)(((float*)%(_a)s->data)[0])
            : (REAL)(((double*)%(_a)s->data)[0]);
            double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
            (REAL)(((float*)%(_b)s->data)[0])
            : (REAL)(((double*)%(_b)s->data)[0]);

            double* x = (double*)PyArray_DATA(%(_x)s);
            double* y = (double*)PyArray_DATA(%(_y)s);
            double* z = (double*)PyArray_DATA(%(_z)s);

            switch(unit)
            {
                case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
                case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
                case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
                case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
                case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
                case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
                case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
                case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
                default: PyErr_SetString(PyExc_ValueError, "gemm needs a unit stride in x, y and z"); %(fail)s;
            };
            #undef REAL
        }
        break;
    }

    """
return blas.gemm_code( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
gemm = gof.op.constructor(Gemm) gemm = gof.op.constructor(Gemm)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论