gemm restricted to rank 2, implemented in C. added __hash__ to Op, Result for…

gemm restricted to rank 2, implemented in C. added __hash__ to Op, Result for consistent set orderings
上级 201629ac
......@@ -11,3 +11,4 @@ pull.sh
*.pyc
*.so
*.sw?
core.*
......@@ -627,36 +627,53 @@ class t_gemm(unittest.TestCase):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
cz = z.copy()
tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)])
new_z = f(z,a,x,y,b)
_z = self._gemm(cz, a, x, y, b)
self.failUnless(z is new_z)
#print cz, _z, z, type(cz), type(_z), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(_z, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(cz == z))
def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
cz = z.copy()
tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l)
new_z = f(z,a,x,y,b)
_z = self._gemm(cz, a, x, y, b)
self.failUnless(z is new_z)
#print cz, _z, z, type(cz), type(_z), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(_z, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(cz == z))
cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker)
#cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker)
cmp_linker(copy(z), a, x, y, b, gof.link.PerformLinker)
def test0a(self):
try:
g = gemm([1.], 1., [1.], [1.], 1.)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self):
try:
self.cmp(1., 0., 1.0, 1.0, 1.0)
except ValueError, e:
if e[0] is Gemm.E_rank:
return
self.fail()
def test0(self): self.cmp(1., 0., 1.0, 1.0, 1.0)
def test1(self): self.cmp(2., 0., 1.0, 1.0, 0.0)
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
self.failUnless(e[0] == Gemm.E_bcast)
self.failUnless(e[0] == Gemm.E_rank)
return
self.fail()
def test3(self): self.cmp([2.], 1.,[3,2,1.], [[1],[2],[3.]], 1.0)
def test4(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0,
......
......@@ -26,6 +26,8 @@ def cblas_header_text():
__BEGIN_DECLS
#define MOD %
/*
* Enumerated and derived types
*/
......
......@@ -4,6 +4,7 @@ Contains the Op class, which is the base interface for all operations
compatible with gof's graph manipulation routines.
"""
import utils
from utils import ClsInit, all_bases, all_bases_collect, AbstractFunctionError
import graph
......@@ -36,7 +37,7 @@ class Op(object):
not required) that it creates them.
"""
__slots__ = ['_inputs', '_outputs']
__slots__ = ['_inputs', '_outputs', '_hash_id']
_default_output_idx = 0
......@@ -52,7 +53,29 @@ class Op(object):
def __init__(self, **kwargs):
pass
self._hash_id = utils.hashgen()
#
# Python stdlib compatibility
#
def __cmp__(self, other):
    """Order two objects by the values of id() applied to each."""
    self_id, other_id = id(self), id(other)
    return cmp(self_id, other_id)
def __eq__(self, other):
return self is other #assuming this is faster, equiv to id(self) == id(other)
def __ne__(self, other):
return self is not other #assuming this is faster, equiv to id(self) != id(other)
def __hash__(self):
if not hasattr(self, '_hash_id'):
self._hash_id = utils.hashgen()
return self._hash_id
#
#
#
def get_input(self, i):
return self._inputs[i]
......
......@@ -5,6 +5,7 @@ value that is the input or the output of an Op.
"""
import copy
import utils
from utils import AbstractFunctionError
......@@ -52,14 +53,30 @@ class ResultBase(object):
data_filter
"""
__slots__ = ['_role', '_data', 'state', '_name']
__slots__ = ['_role', '_data', 'state', '_name', '_hash_id']
def __init__(self, role=None, name=None):
    """Initialise a result with no data and a fresh hash id.

    role -- stored as-is in `_role` (defaults to None).
    name -- assigned through `self.name` (defaults to None).
    """
    self._role = role
    # One-element list used as the data slot; starts out holding None.
    self._data = [None]
    self.state = Empty
    # NOTE(review): __slots__ declares `_name`, not `name`, so this
    # presumably goes through a property defined elsewhere -- confirm.
    self.name = name
    # Hash value handed out by utils.hashgen(); returned by __hash__.
    self._hash_id = utils.hashgen()
#
# Python stdlib compatibility
#
def __cmp__(self, other):
    """Order two objects by the values of id() applied to each."""
    self_id, other_id = id(self), id(other)
    return cmp(self_id, other_id)
def __eq__(self, other):
return self is other #assuming this is faster, equiv to id(self) == id(other)
def __ne__(self, other):
return self is not other #assuming this is faster, equiv to id(self) != id(other)
def __hash__(self):
return self._hash_id
#
# role
......
......@@ -4,6 +4,10 @@
import re
def hashgen():
    """Return the next integer from a module-wide counter.

    The counter lives on the function object itself as `hashgen.next`;
    each call advances it by one and returns the new value.
    """
    fresh = hashgen.next + 1
    hashgen.next = fresh
    return fresh
# Last value handed out; the first call therefore returns 1.
hashgen.next = 0
class OmegaError(Exception): pass
......
......@@ -10,6 +10,7 @@ import gof.op
from base_tensor import BaseTensor, BaseTensorOp
from elemwise import Elemwise
import blas # for gemm, dot
class Tensor(BaseTensor):
......@@ -712,12 +713,17 @@ dot = gof.op.constructor(Dot)
class Gemm(_Op):
nin=5
nout=1
E_bcast = 'incompatible broadcastable flags'
E_rank = 'gemm only works for rank 2'
E_scalar = 'gemm requires scalar argument'
def destroy_map(self):
    """Map the output to a one-element list holding the first input."""
    # NOTE(review): presumably tells the graph machinery that the output
    # overwrites input 0 (z) -- confirm against gof's destroy handling.
    first_input_only = [self.inputs[0]]
    return {self.out: first_input_only}
def propagate_broadcastable(self, bz, ba, bx, by, bb):
if len(bz) != len(Dot.broadcastable_rule(bx,by)):
raise ValueError(Gemm.E_bcast, bz, bx, by)
if len(bz) != 2: raise ValueError(Gemm.E_rank, len(bz))
if len(bx) != 2: raise ValueError(Gemm.E_rank, len(bx))
if len(by) != 2: raise ValueError(Gemm.E_rank, len(by))
if len(ba): raise ValueError(Gemm.E_scalar, ba)
if len(bb): raise ValueError(Gemm.E_scalar, bb)
return [bz]
def impl(self, z, a, x, y, b):
assert a.shape == ()
......@@ -746,26 +752,168 @@ class Gemm(_Op):
return z
def grad(self, (z, a, x, y, b), gz):
    """Gradient of gemm; not implemented, always raises NotImplementedError."""
    raise NotImplementedError()
if 0:
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl((_zin, _a, _x, _y, _b), (_z,)):
check_ab = """
{
if ((_a->descr->type_num != PyArray_DOUBLE)
&& (_a->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_b->descr->type_num != PyArray_DOUBLE)
&& (_b->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
def c_support_code(self):
    """Return the text produced by blas.cblas_header_text()."""
    header_text = blas.cblas_header_text()
    return header_text
def c_libraries(self):
    """Return whatever blas.ldflags() reports."""
    link_flags = blas.ldflags()
    return link_flags
def c_var_names(self):
    """Return the C variable names as [names of the five inputs, names of the one output]."""
    input_names = ['_z', '_a', '_x', '_y', '_b']
    output_names = ['_zout']
    return [input_names, output_names]
def c_validate_update(self):
    """Return the C template that makes `_zout` refer to the same object as `_z`.

    The template is filled in with %-style named substitution
    (`%(_zout)s`, `%(_z)s`).

    The old reference held in `_zout` is released only when `_zout` is
    about to be re-pointed at `_z`. Releasing it unconditionally, as was
    done before, drops a reference that is never re-acquired when `_zout`
    already aliases `_z`, leaving the object's reference count one too low.
    """
    return """
    if (%(_zout)s != %(_z)s)
    {
        if (%(_zout)s)
        {
            Py_DECREF(%(_zout)s);
        }
        %(_zout)s = %(_z)s;
        Py_INCREF(%(_zout)s);
    }
    """
def c_validate_update_cleanup(self):
    """Return the cleanup code for c_validate_update: an empty string."""
    no_cleanup = ""
    return no_cleanup
def c_code(self):
return """
int unit = 0;
int type_num = %(_x)s->descr->type_num;
int type_size = %(_x)s->descr->elsize; // in bytes
npy_intp* Nx = %(_x)s->dimensions;
npy_intp* Ny = %(_y)s->dimensions;
npy_intp* Nz = %(_z)s->dimensions;
npy_intp* Sx = %(_x)s->strides;
npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = %(_z)s->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if (%(_x)s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2)
{PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;}
if ((%(_b)s->descr->type_num != PyArray_DOUBLE)
&& (%(_b)s->descr->type_num != PyArray_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT))
%(fail)s;
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
%(fail)s;
if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT))
%(fail)s;
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
%(fail)s;
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
PyErr_SetString(PyExc_ValueError, "Input dimensions do not agree");
%(fail)s;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{
PyErr_SetString(PyExc_ValueError, "gemm cant run on these inputs");
%(fail)s;
}
/*
encode the stride structure of _x,_y,_z into a single integer
*/
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8;
/* create appropriate strides for malformed matrices that are row or column
* vectors
*/
sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1];
sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0];
sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1];
sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0];
sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1];
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
case PyArray_FLOAT:
{
#define REAL float
float a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
float b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s);
switch(unit)
{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: %(fail)s;
};
#undef REAL
}
"""
return blas.gemm_code( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
break;
case PyArray_DOUBLE:
{
#define REAL double
double a = (%(_a)s->descr->type_num == PyArray_FLOAT)
? (REAL)(((float*)%(_a)s->data)[0])
: (REAL)(((double*)%(_a)s->data)[0]);
double b = (%(_b)s->descr->type_num == PyArray_FLOAT) ?
(REAL)(((float*)%(_b)s->data)[0])
: (REAL)(((double*)%(_b)s->data)[0]);
double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s);
switch(unit)
{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: %(fail)s;
};
#undef REAL
}
break;
}
"""
gemm = gof.op.constructor(Gemm)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论