提交 13db39ca authored 作者: bergstrj@iro.umontreal.ca's avatar bergstrj@iro.umontreal.ca

merged, changed gemm to use fortran blas instead of cblas

...@@ -120,6 +120,16 @@ class _test_Broadcast(unittest.TestCase): ...@@ -120,6 +120,16 @@ class _test_Broadcast(unittest.TestCase):
f(xv, yv) f(xv, yv)
assert (xv == yv).all() assert (xv == yv).all()
def test_weird_strides(self):
x = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'x'))
y = modes.build(Tensor('float64', [0, 0, 0, 0, 0], name = 'y'))
e = Broadcast(Add, (x, y)).out
f = gof.CLinker(env([x, y], [e])).make_function(inplace = False)
xv = numpy.random.rand(2, 2, 2, 2, 2)
yv = numpy.random.rand(2, 2, 2, 2, 2).transpose(4, 0, 3, 1, 2)
zv = xv + yv
assert (f(xv, yv) == zv).all()
class _test_CAReduce(unittest.TestCase): class _test_CAReduce(unittest.TestCase):
......
import unittest
import gof
from opt import *
import tensor
from tensor import Tensor
from gof import Env
from elemwise import DimShuffle
import numpy
import scalar_opt
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = Tensor(broadcastable = xbc, dtype = 'float64', name = 'x')
y = Tensor(broadcastable = ybc, dtype = 'float64', name = 'y')
z = Tensor(broadcastable = zbc, dtype = 'float64', name = 'z')
return x, y, z
ds = gof.op.constructor(DimShuffle)
class _test_inplace_opt(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = x + y + z
g = Env([x, y], [e])
assert str(g) == "[Broadcast{Add}(Broadcast{Add}(x, y), z)]"
inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}{0: 0}(Broadcast{Add}{0: 0}(x, y), z)]"
def test_multiple_uses(self):
x, y, z = inputs()
e0 = x + y
e1 = x * y
g = Env([x, y], [e0, e1])
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}(x, y)]"
inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}{0: 0}(x, y), Broadcast{Mul}(x, y)]" \
or str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
def test_user_inplace(self):
x, y, z = inputs()
e0 = x + y
e1 = tensor.mul_inplace(x, y)
g = Env([x, y], [e0, e1])
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
inplace_optimizer.optimize(g)
assert str(g) == "[Broadcast{Add}(x, y), Broadcast{Mul}{0: 0}(x, y)]"
class _test_dimshuffle_lift(unittest.TestCase):
def test_double_transpose(self):
x, y, z = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
g = Env([x], [e])
assert str(g) == "[DimShuffle{10}(DimShuffle{10}(x))]"
lift_dimshuffle.optimize(g)
assert str(g) == "[x]"
def test_merge2(self):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{20x1}(DimShuffle{1x0}(x))]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[DimShuffle{01xx}(x)]", str(g))
def test_elim3(self):
x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.failUnless(str(g) == "[DimShuffle{10}(DimShuffle{20x1}(DimShuffle{0x1}(x)))]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[x]", str(g))
def test_lift(self):
x, y, z = inputs([0]*1, [0]*2, [0]*3)
e = x + y + z
g = Env([x, y, z], [e])
self.failUnless(str(g) == "[Broadcast{Add}(DimShuffle{x01}(Broadcast{Add}(DimShuffle{x0}(x), y)), z)]", str(g))
lift_dimshuffle.optimize(g)
self.failUnless(str(g) == "[Broadcast{Add}(Broadcast{Add}(DimShuffle{xx0}(x), DimShuffle{x01}(y)), z)]", str(g))
class _test_cliques(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
m = y * z
d = tensor.dot(x, m)
d.name = 'd'
e = x + y + d
g = Env([x, y, z], [e])
cliques = find_cliques(g)
assert len(cliques) == 2
(i1, o1), (i2, o2) = cliques
assert str(Env(i1, o1)) == "[Broadcast{Add}(Broadcast{Add}(x, y), d)]"
assert str(Env(i2, o2)) == "[Broadcast{Mul}(y, z)]"
# print g
# for i, o in find_cliques(g):
# print "-->", Env(i, [o])
def test_broadcasting(self):
x, y, z = inputs([0]*1, [0]*2, [0]*3)
e = x + y + z
g = Env([x, y, z], [e])
lift_dimshuffle.optimize(g)
assert len(find_cliques(g, through_broadcast = True)) == 1
assert len(find_cliques(g, through_broadcast = False)) == 2
# print g
# for i, o in find_cliques(g, True):
# print "-->", Env(i, [o])
# class _test_clique_opt(unittest.TestCase):
# def test_straightforward(self):
# x, y, z = inputs()
# e = x ** 2.0 #x * x
# g = Env([x], [e])
# gof.ConstantFinder().optimize(g)
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = scalar_opt.opt2,
# make_composite = False)
# print g
# opt.optimize(g)
# print g
# def test_inplace(self):
# x, y, z = inputs()
# #e = tensor.add_inplace(x, y + z)
# e = x + tensor.add_inplace(y, z)
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward(self):
# x, y, z = inputs()
# e = x + y + z
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
# def test_straightforward2(self):
# x, y, z = inputs()
# m = y * z
# d = tensor.dot(x, m)
# d.name = 'd'
# e = x + y + d
# g = Env([x, y, z], [e])
# opt = CliqueOptimizer(through_broadcast = False,
# scalar_optimizer = None,
# make_composite = True)
# print g
# opt.optimize(g)
# print g
# # print g.outputs[0].owner.c_code(['x', 'y', 'z'], ['e'], dict(fail = "FAIL;", id = 0))
# print gof.OpWiseCLinker(g).make_function()(numpy.ones((5, 5)), numpy.ones((5, 5)), numpy.ones((5, 5)))
if __name__ == '__main__':
unittest.main()
...@@ -27,6 +27,50 @@ class _test_ScalarOps(unittest.TestCase): ...@@ -27,6 +27,50 @@ class _test_ScalarOps(unittest.TestCase):
assert fn(1.0, 2.0) == 1.5 assert fn(1.0, 2.0) == 1.5
class _test_composite(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
e = mul(add(x, y), div(x, y))
C = composite([x, y], [e])
c = C(x, y)
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 1.5
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 1.5
def test_with_constants(self):
x, y, z = inputs()
e = mul(add(70.0, y), div(x, y))
C = composite([x, y], [e])
c = C(x, y)
assert "70.0" in c.c_code(['x', 'y'], ['z'], dict(id = 0))
# print c.c_code(['x', 'y'], ['z'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 36.0
g = env([x, y], [c.out])
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == 36.0
def test_many_outputs(self):
x, y, z = inputs()
e0 = x + y + z
e1 = x + y * z
e2 = x / y
C = composite([x, y, z], [e0, e1, e2])
c = C(x, y, z)
# print c.c_code(['x', 'y', 'z'], ['out0', 'out1', 'out2'], dict(id = 0))
c.perform()
assert c.outputs[0].data == 6.0
assert c.outputs[1].data == 7.0
assert c.outputs[2].data == 0.5
g = env([x, y], c.outputs)
fn = gof.DualLinker(g).make_function()
assert fn(1.0, 2.0) == [6.0, 7.0, 0.5]
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
......
import unittest
from gof import Result, Op, Env, modes
import gof
from scalar import *
from scalar_opt import *
def inputs():
x = Scalar('float64', name = 'x')
y = Scalar('float64', name = 'y')
z = Scalar('float64', name = 'z')
return x, y, z
class _test_opts(unittest.TestCase):
def test_pow_to_sqr(self):
x, y, z = inputs()
e = x ** 2.0
g = Env([x], [e])
assert str(g) == "[Pow(x, 2.0)]"
gof.ConstantFinder().optimize(g)
opt2.optimize(g)
assert str(g) == "[Sqr(x)]"
if __name__ == '__main__':
unittest.main()
...@@ -990,21 +990,21 @@ class t_gemm(unittest.TestCase): ...@@ -990,21 +990,21 @@ class t_gemm(unittest.TestCase):
def cmp(self, z, a, x, y, b): def cmp(self, z, a, x, y, b):
def cmp_linker(z, a, x, y, b, l): def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b] z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
cz = z.copy() z_orig = z.copy()
tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b] tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l) f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l)
new_z = f(z,a,x,y,b) new_z = f(z,a,x,y,b)
_z = self._gemm(cz, a, x, y, b) z_after = self._gemm(z_orig, a, x, y, b)
self.failUnless(z is new_z) self.failUnless(z is new_z)
#print cz, _z, z, type(cz), type(_z), type(z) #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
#_approx_eq.debug = 1 #_approx_eq.debug = 1
self.failUnless(_approx_eq(_z, z)) self.failUnless(_approx_eq(z_after, z))
if a == 0.0 and b == 1.0: if a == 0.0 and b == 1.0:
return return
else: else:
self.failIf(numpy.all(cz == z)) self.failIf(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker) cmp_linker(copy(z), a, x, y, b, gof.cc.OpWiseCLinker)
#cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker) #cmp_linker(copy(z), a, x, y, b, gof.cc.CLinker)
...@@ -1101,5 +1101,49 @@ class t_gemm(unittest.TestCase): ...@@ -1101,5 +1101,49 @@ class t_gemm(unittest.TestCase):
eval_outputs([gemm(Z, 1.0, A, A, 1.0)]) eval_outputs([gemm(Z, 1.0, A, A, 1.0)])
eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)]) eval_outputs([gemm(Z, 1.0, A, A.T, 1.0)])
def test_transposes(self):
# three square matrices which are not contiguous
A = self.rand(4,5)[:,:4]
B = self.rand(4,5)[:,:4]
C = self.rand(4,5)[:,:4]
def t(z,x,y,a=1.0, b=0.0,l=gof.cc.OpWiseCLinker):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [astensor(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)], linker_cls=l)
f(z, a, x, y, b)
self.failUnless(_approx_eq(z_after, z), (z_orig, z_after, z))
f(z.T, a, y.T, x.T, b)
self.failUnless(_approx_eq(z_after, z))
t(C,A,B)
t(C.T, A, B)
t(C, A.T, B)
t(C, A, B.T)
t(C.T, A.T, B)
t(C, A.T, B.T)
t(C.T, A, B.T)
t(C.T, A.T, B.T)
t(C, A[:,:2], B[:2, :])
t(C.T, A[:,:2], B[:2, :])
t(C, A[:2,:].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :])
t(C, A[:2,:].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T)
try:
t(C.T, A[:2,:], B[:, :2].T)
except ValueError, e:
if e[0].find('aligned') >= 0:
return
self.fail()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -11,6 +11,7 @@ fine-grained motifs of iadd, isub, scale, and dot. ...@@ -11,6 +11,7 @@ fine-grained motifs of iadd, isub, scale, and dot.
""" """
def cblas_header_text(): def cblas_header_text():
"""C header for the cblas interface"""
return """ return """
//#include <stddef.h> //#include <stddef.h>
...@@ -589,6 +590,210 @@ def cblas_header_text(): ...@@ -589,6 +590,210 @@ def cblas_header_text():
__END_DECLS __END_DECLS
""" """
def blas_proto():
"""C header for the fortran blas interface"""
return """
extern "C"
{
void xerbla_(char*, void *);
/***********/
/* Level 1 */
/***********/
/* Single Precision */
void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *);
void srotg_(float *,float *,float *,float *);
void srotm_( const int*, float *, const int*, float *, const int*, const float *);
void srotmg_(float *,float *,float *,const float *, float *);
void sswap_( const int*, float *, const int*, float *, const int*);
void scopy_( const int*, const float *, const int*, float *, const int*);
void saxpy_( const int*, const float *, const float *, const int*, float *, const int*);
void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *);
void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *);
void sscal_( const int*, const float *, float *, const int*);
void snrm2_sub_( const int*, const float *, const int*, float *);
void sasum_sub_( const int*, const float *, const int*, float *);
void isamax_sub_( const int*, const float * , const int*, const int*);
/* Double Precision */
void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *);
void drotg_(double *,double *,double *,double *);
void drotm_( const int*, double *, const int*, double *, const int*, const double *);
void drotmg_(double *,double *,double *,const double *, double *);
void dswap_( const int*, double *, const int*, double *, const int*);
void dcopy_( const int*, const double *, const int*, double *, const int*);
void daxpy_( const int*, const double *, const double *, const int*, double *, const int*);
void dswap_( const int*, double *, const int*, double *, const int*);
void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *);
void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *);
void dscal_( const int*, const double *, double *, const int*);
void dnrm2_sub_( const int*, const double *, const int*, double *);
void dasum_sub_( const int*, const double *, const int*, double *);
void idamax_sub_( const int*, const double * , const int*, const int*);
/* Single Complex Precision */
void cswap_( const int*, void *, const int*, void *, const int*);
void ccopy_( const int*, const void *, const int*, void *, const int*);
void caxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void cswap_( const int*, void *, const int*, void *, const int*);
void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void cscal_( const int*, const void *, void *, const int*);
void icamax_sub_( const int*, const void *, const int*, const int*);
void csscal_( const int*, const float *, void *, const int*);
void scnrm2_sub_( const int*, const void *, const int*, float *);
void scasum_sub_( const int*, const void *, const int*, float *);
/* Double Complex Precision */
void zswap_( const int*, void *, const int*, void *, const int*);
void zcopy_( const int*, const void *, const int*, void *, const int*);
void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*);
void zswap_( const int*, void *, const int*, void *, const int*);
void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *);
void zdscal_( const int*, const double *, void *, const int*);
void zscal_( const int*, const void *, void *, const int*);
void dznrm2_sub_( const int*, const void *, const int*, double *);
void dzasum_sub_( const int*, const void *, const int*, double *);
void izamax_sub_( const int*, const void *, const int*, const int*);
/***********/
/* Level 2 */
/***********/
/* Single Precision */
void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*);
void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*);
void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*);
void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*);
void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*);
void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*);
void sspr_(char*, const int*, const float *, const float *, const int*, float *);
void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *);
void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*);
void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*);
void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*);
void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*);
void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*);
void dspr_(char*, const int*, const double *, const double *, const int*, double *);
void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *);
void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*);
void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void chpr_(char*, const int*, const float *, const void *, const int*, void *);
void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *);
/* Double Complex Precision */
void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*);
void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*);
void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*);
void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*);
void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*);
void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*);
void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*);
void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*);
void zhpr_(char*, const int*, const double *, const void *, const int*, void *);
void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *);
/***********/
/* Level 3 */
/***********/
/* Single Precision */
void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Precision */
void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
/* Single Complex Precision */
void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*);
void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*);
void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*);
/* Double Complex Precision */
void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*);
void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*);
void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*);
}
"""
def _constant(f): def _constant(f):
"""Return a function that always returns its first call value """Return a function that always returns its first call value
""" """
...@@ -603,12 +808,22 @@ def ldflags(): ...@@ -603,12 +808,22 @@ def ldflags():
"""Return a list of libraries against which an Op's object file should be """Return a list of libraries against which an Op's object file should be
linked to benefit from a BLAS implementation. linked to benefit from a BLAS implementation.
Default: ['cblas','blas'], but environment variable OMEGA_BLAS_LDFLAGS overrides this. Default: ['blas'], but environment variable THEANO_BLAS_LDFLAGS overrides this.
""" """
if os.getenv('OMEGA_BLAS_LDFLAGS'): if os.getenv('THEANO_BLAS_LDFLAGS'):
return os.getenv('OMEGA_BLAS_LDFLAGS').split() tokens = os.getenv('THEANO_BLAS_LDFLAGS').split()
for t in tokens:
try:
t0, t1, t2 = t[0:3]
assert t0 == '-'
except e:
raise ValueError('invalid token in THEANO_BLAS_LDFLAGS', t)
if t1 == 'L':
raise ValueError('library dir not allowed in THEANO_BLAS_LDFLAGS', t)
rval = [token[2:] for token in tokens]
return rval
else: else:
return ['cblas', 'blas'] return ['blas']
def gemm_code(check_ab, a_init, b_init): def gemm_code(check_ab, a_init, b_init):
mod = '%' mod = '%'
......
...@@ -397,9 +397,6 @@ class CAReduce(Op): ...@@ -397,9 +397,6 @@ class CAReduce(Op):
if dimensions_to_reduce is None: if dimensions_to_reduce is None:
dimensions_to_reduce = range(len(inputs[0].broadcastable)) dimensions_to_reduce = range(len(inputs[0].broadcastable))
self.nin = 1
self.nout = 1
self.inputs = inputs self.inputs = inputs
self.outputs = [Tensor(dtype = inputs[0].dtype, self.outputs = [Tensor(dtype = inputs[0].dtype,
broadcastable = [x for i, x in enumerate(inputs[0].broadcastable) if i not in dimensions_to_reduce])] broadcastable = [x for i, x in enumerate(inputs[0].broadcastable) if i not in dimensions_to_reduce])]
......
...@@ -409,7 +409,7 @@ class CLinker(Linker): ...@@ -409,7 +409,7 @@ class CLinker(Linker):
elif result in self.orphans: elif result in self.orphans:
self.orphans.remove(result) self.orphans.remove(result)
continue continue
except AbstractFunctionError: except (AbstractFunctionError, NotImplementedError):
pass pass
# policy = [[what to declare in the struct, what to do at construction, what to do at destruction], # policy = [[what to declare in the struct, what to do at construction, what to do at destruction],
# [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]] # [what to declare in each run, what to do at the beginning of each run, what to do at the end of each run]]
......
...@@ -35,22 +35,23 @@ class Computed : """Memory has been allocated, contents are the owner's output." ...@@ -35,22 +35,23 @@ class Computed : """Memory has been allocated, contents are the owner's output."
############################ ############################
class Result(object): class Result(object):
"""Base class for storing L{Op} inputs and outputs """
Base class for storing L{Op} inputs and outputs
Attributes: Attributes:
_role - None or (owner, index) #or BrokenLink - _role - None or (owner, index) #or BrokenLink
_data - anything - _data - anything
state - one of (Empty, Allocated, Computed) - state - one of (Empty, Allocated, Computed)
name - string - name - string
Properties: Properties:
role - (rw) - role - (rw)
owner - (ro) - owner - (ro)
index - (ro) - index - (ro)
data - (rw) : calls data_filter when setting - data - (rw) : calls data_filter when setting
Abstract Methods: Abstract Methods:
data_filter - data_filter
""" """
__slots__ = ['_role', '_data', 'state', '_name', '_hash_id'] __slots__ = ['_role', '_data', 'state', '_name', '_hash_id']
...@@ -241,6 +242,13 @@ class Result(object): ...@@ -241,6 +242,13 @@ class Result(object):
def c_libraries(self): def c_libraries(self):
""" """
Return a list of libraries to link against to manipulate this L{Result}. Return a list of libraries to link against to manipulate this L{Result}.
For example: return ['gsl', 'gslcblas', 'm', 'fftw3', 'g2c'].
The compiler will search the directories specified by the environment
variable LD_LIBRARY_PATH. No option is provided for an Op to provide an
extra library directory because this would change the linking path for
other Ops in a potentially disasterous way.
""" """
raise AbstractFunctionError() raise AbstractFunctionError()
......
from gof import opt from gof import opt, Env
from elemwise import Broadcast import gof
from elemwise import Broadcast, DimShuffle
from gof.python25 import any, all
import scalar
class InplaceOptimizer(opt.OpSpecificOptimizer): class InplaceOptimizer(opt.OpSpecificOptimizer):
...@@ -26,30 +29,236 @@ class InplaceOptimizer(opt.OpSpecificOptimizer): ...@@ -26,30 +29,236 @@ class InplaceOptimizer(opt.OpSpecificOptimizer):
inplace_optimizer = InplaceOptimizer() inplace_optimizer = InplaceOptimizer()
# class ElemwisePatternOptimizer(opt.Optimizer):
# def __init__(self, scalar_opt): class DimShuffleLifter(opt.Optimizer):
# self. """
"Lifts" DimShuffle through Broadcast operations and merges
consecutive DimShuffles. Basically, applies the following
transformations on the whole graph:
DimShuffle(Broadcast(x, y)) => Broadcast(DimShuffle(x), DimShuffle(y))
DimShuffle(DimShuffle(x)) => DimShuffle(x)
# def find_elemwise_cliques(env, cross_broadcast = False): After this transform, clusters of Broadcast operations are
void of DimShuffle operations.
"""
def apply(self, env):
seen = set()
def merge(ord1, ord2):
return [x == 'x' and 'x' or ord1[x] for x in ord2]
def lift(r):
if r in seen:
return
seen.add(r)
op = r.owner
if op is None \
or op in env.inputs \
or op in env.orphans():
return
if isinstance(op, DimShuffle):
in_op = op.inputs[0].owner
if isinstance(in_op, DimShuffle):
new_order = [x == 'x' and 'x' or in_op.new_order[x] for x in op.new_order]
if new_order == range(len(new_order)):
repl = in_op.inputs[0]
else:
repl = DimShuffle(in_op.inputs[0], new_order).out
env.replace(r, repl)
lift(repl)
return
elif isinstance(in_op, Broadcast):
repl = Broadcast(in_op.scalar_opclass,
[DimShuffle(input, op.new_order).out for input in in_op.inputs],
in_op.inplace_pattern).out
env.replace(r, repl)
r = repl
op = r.owner
for next_r in op.inputs:
lift(next_r)
for output in env.outputs:
lift(output)
lift_dimshuffle = DimShuffleLifter()
def find_cliques(env, through_broadcast = False):
# def synchronize(env1, env2, equiv, transform): def seek_from(r):
op = r.owner
if r in env.inputs \
or r in env.orphans() \
or op is None \
or not isinstance(op, Broadcast) \
or len(op.outputs) > 1:
# todo: handle multiple-output broadcast ops
# (needs to update the clique's outputs)
return None
# class Synchronize(Listener, Constraint): ret = set()
if not through_broadcast:
if any(any(bc) and not all(bc)
for bc in zip(*[input.broadcastable for input in op.inputs])):
ret.update(op.inputs)
return ret
for input in op.inputs:
res = seek_from(input)
if res is None:
ret.add(input)
else:
ret.update(res)
# def on_import(self, op1): return ret
# if op1 not in equiv:
# equiv[op1] = transform(op1) cliques = []
def find_cliques_helper(r):
if r in env.inputs or r in env.orphans():
return
clique_inputs = seek_from(r)
if clique_inputs is None:
op = r.owner
if op is not None:
for input in op.inputs:
find_cliques_helper(input)
else:
cliques.append((clique_inputs, [r]))
for input in clique_inputs:
find_cliques_helper(input)
for output in env.outputs:
find_cliques_helper(output)
# todo: merge the cliques if possible
return cliques
class CliqueOptimizer(opt.Optimizer):
def __init__(self, through_broadcast = False, scalar_optimizer = None, make_composite = False):
self.through_broadcast = through_broadcast
self.scalar_optimizer = scalar_optimizer
self.make_composite = make_composite
def apply(self, env):
if self.scalar_optimizer is None and not self.make_composite:
# there's nothing to do with the cliques...
return
cliques = find_cliques(env, self.through_broadcast)
opt = self.scalar_optimizer
def build_scalar_clique(r, env, equiv):
if r in equiv:
return equiv[r]
op = r.owner
if r in env.inputs or r in env.orphans():
s = scalar.Scalar(dtype = r.dtype)
_r = r
if isinstance(r.owner, DimShuffle) and all(x == 'x' for x in r.owner.new_order):
_r = r.owner.inputs[0]
if (getattr(r, 'constant', False) or getattr(_r, 'constant', False)) \
and _r.broadcastable == ():
s.data = _r.data
s.constant = True
equiv[r] = s
return s
s_op = op.scalar_opclass(*[build_scalar_clique(input, env, equiv) for input in op.inputs])
equiv[op] = s_op
for output, s_output in zip(op.outputs, s_op.outputs):
equiv[output] = s_output
return equiv[r]
for c_in, c_out in cliques:
equiv = dict()
g = Env(c_in, c_out)
for output in c_out:
build_scalar_clique(output, g, equiv)
s_g = Env([equiv[r] for r in g.inputs],
[equiv[r] for r in g.outputs])
if opt is not None:
equiv2 = dict()
for k, v in equiv.items():
equiv2[v] = k
def transform(op, equiv):
return Broadcast(op.__class__, [equiv[input] for input in op.inputs])
s_g.add_feature(sync_to(env, equiv2, transform))
opt.optimize(s_g)
if self.make_composite:
def follow_inplace(r):
op = r.owner
if op is None or r in g.inputs or r in g.orphans():
return None
assert isinstance(op, Broadcast)
destroyed = op.destroy_map().get(r, None)
if destroyed is None:
return None
else:
r2 = destroyed[0]
ret = follow_inplace(r2)
if ret is None:
return r2
else:
return ret
inplace_pattern = {}
for i, output in enumerate(g.outputs):
destroyed = follow_inplace(output)
if destroyed is not None and destroyed in g.inputs:
inplace_pattern[i] = g.inputs.index(destroyed)
C = scalar.composite(s_g.inputs, s_g.outputs)
ec = Broadcast(C, g.inputs, inplace_pattern = inplace_pattern)
env.replace_all(dict((o, eco) for o, eco in zip(c_out, ec.outputs)))
def sync_to(target, equiv, transform):
class Synchronize(gof.Listener, gof.Constraint):
def __init__(self, source):
self.source = source
self.target = target
self.equiv = equiv
self.transform = transform
self.inconsistencies = []
def on_import(self, op1):
if op1 not in self.equiv:
op2 = self.transform(op1, self.equiv)
self.equiv[op1] = op2
for o1, o2 in zip(op1.outputs, op2.outputs):
self.equiv[o1] = o2
def on_prune(self, op1):
if op1 in self.equiv:
op2 = self.equiv[op1]
del self.equiv[op1]
for o1, o2 in zip(op1.outputs, op2.outputs):
del self.equiv[o1]
def on_rewire(self, clients1, r1, new_r1):
if (new_r1, r1) in self.inconsistencies:
self.inconsistencies.remove((new_r1, r1))
return
if not self.source.clients(r1):
try:
target.replace(self.equiv[r1], self.equiv[new_r1])
except:
self.inconsistencies.append((r1, new_r1))
def validate(self):
if self.inconsistencies:
raise InconsistencyError("Could not synchronize when replacing the following pairs: %s" % self.inconsistencies)
return True
return Synchronize
# def on_prune(self, op1):
# if op1 in equiv:
# del equiv[op1]
......
...@@ -5,7 +5,8 @@ import math ...@@ -5,7 +5,8 @@ import math
from copy import copy from copy import copy
import inspect import inspect
from gof import Result, GuardedOp, utils import gof
from gof import Result, GuardedOp, Env, utils
def as_scalar(x, name = None): def as_scalar(x, name = None):
...@@ -20,6 +21,11 @@ def as_scalar(x, name = None): ...@@ -20,6 +21,11 @@ def as_scalar(x, name = None):
if isinstance(x, Scalar): if isinstance(x, Scalar):
return x return x
def constant(x):
res = as_scalar(x)
res.constant = True
return res
class Scalar(Result): class Scalar(Result):
...@@ -29,6 +35,8 @@ class Scalar(Result): ...@@ -29,6 +35,8 @@ class Scalar(Result):
self.dtype_specs() self.dtype_specs()
def __get_constant(self): def __get_constant(self):
if not hasattr(self, '_constant'):
return False
return self._constant return self._constant
def __set_constant(self, value): def __set_constant(self, value):
...@@ -37,7 +45,10 @@ class Scalar(Result): ...@@ -37,7 +45,10 @@ class Scalar(Result):
self._constant = value self._constant = value
constant = property(__get_constant, __set_constant) constant = property(__get_constant, __set_constant)
def desc(self):
return (self.dtype, self.data)
def filter(self, data): def filter(self, data):
py_type = self.dtype_specs()[0] py_type = self.dtype_specs()[0]
return py_type(data) return py_type(data)
...@@ -58,6 +69,11 @@ class Scalar(Result): ...@@ -58,6 +69,11 @@ class Scalar(Result):
except KeyError: except KeyError:
raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype)) raise TypeError("Unsupported dtype for %s: %s" % (self.__class__.__name__, self.dtype))
def c_literal(self):
if 'complex' in self.dtype:
raise NotImplementedError("No literal for complex values.")
return str(self.data)
def c_declare(self, name, sub): def c_declare(self, name, sub):
return """ return """
%(dtype)s %(name)s; %(dtype)s %(name)s;
...@@ -184,7 +200,7 @@ class ScalarMixedOp(GuardedOp): ...@@ -184,7 +200,7 @@ class ScalarMixedOp(GuardedOp):
inputs = [as_scalar(input) for input in inputs] inputs = [as_scalar(input) for input in inputs]
i_dtypes = [getattr(input, 'dtype', None) for input in inputs] i_dtypes = [getattr(input, 'dtype', None) for input in inputs]
o_dtypes = utils.from_return_values(self.propagate_dtypes(*i_dtypes)) o_dtypes = self.propagate_dtypes(*i_dtypes)
self.inputs = inputs self.inputs = inputs
self.outputs = [Scalar(dtype) for dtype in o_dtypes] self.outputs = [Scalar(dtype) for dtype in o_dtypes]
...@@ -217,7 +233,7 @@ class PureScalarOp(ScalarMixedOp): ...@@ -217,7 +233,7 @@ class PureScalarOp(ScalarMixedOp):
for dtype in i_dtypes: for dtype in i_dtypes:
if dtype is None: if dtype is None:
raise TypeError("Expected a Scalar.") raise TypeError("Expected a Scalar.")
return self.cast_method(*i_dtypes) return [self.cast_method(*i_dtypes)] * self.nout
class UnaryScalarOp(PureScalarOp): class UnaryScalarOp(PureScalarOp):
...@@ -383,4 +399,111 @@ modes.make_constructors(globals()) ...@@ -383,4 +399,111 @@ modes.make_constructors(globals())
def composite(inputs, outputs):
"""
Usage: composite(inputs, outputs)
Produces an Op class which represents the computations
between the provided inputs and outputs as a single
operation.
The operations between inputs and outputs (as given by
Env(inputs, outputs).ops()) must all be instances of
PureScalarOp.
Examples:
x, y = Scalar(), Scalar()
SquareDiff = composite([x, y], [(x - y)**2])
TimesTen = composite([x], [x * 10.0])
Neighbors = composite([x], [x - 1, x + 1])
"""
env = Env(inputs, outputs).clone()
gof.opt.ConstantFinder().apply(env)
inputs, outputs = env.inputs, env.outputs
for op in env.ops():
if not isinstance(op, PureScalarOp):
raise ValueError("The input env to composite must be exclusively composed of PureScalarOp instances.")
subd = dict(zip(inputs,
["%%(i%i)s"%i for i in range(len(inputs))]) +
zip(outputs,
["%%(o%i)s"%i for i in range(len(outputs))]))
for orphan in env.orphans():
if orphan.constant:
subd[orphan] = orphan.c_literal()
else:
raise ValueError("All orphans in the input env to composite must be constant.")
_c_code = "{\n"
i = 0
j = 0
for op in env.toposort():
j += 1
for output in op.outputs:
if output not in subd:
i += 1
name = "V%%(id)s_tmp%i" % i
subd[output] = name
# the c code is not robust to any other dtypes than those of the specified inputs
# a solution would be to require Composite.c_code to fill in the dtypes using
# a proper upcast
_c_code += "%s %s;\n" % (output.dtype_specs()[1], name)
_c_code += op.c_code([subd[input] for input in op.inputs],
[subd[output] for output in op.outputs],
dict(fail = "%(fail)s",
id = "%%(id)s_%i" % j))
_c_code += "\n"
_c_code += "}\n"
def compose_impl(r):
# this is not optimal at all eg in add(*1 -> mul(x, y), *1)
# it will calculate *1 twice
# it also doesn't follow env.toposort but that's (presumably)
# still correct since we only have pure scalar ops
if r in env.inputs:
idx = env.inputs.index(r)
return lambda inputs: inputs[idx]
elif r in env.orphans():
return lambda inputs: r.data
op = r.owner
producers = [compose_impl(input) for input in op.inputs]
return lambda inputs: op.impl(*[p(inputs) for p in producers])
_impls = [compose_impl(r) for r in env.outputs]
class Composite(PureScalarOp):
nin = len(inputs)
nout = len(outputs)
# todo: propagate_dtypes?
def perform(self):
inputs = [input.data for input in self.inputs]
for output, impl in zip(self.outputs, _impls):
output.data = impl(inputs)
def impl(self, *inputs):
for r, input in zip(self.inputs, inputs):
r.data = input
self.perform()
return utils.to_return_values([output.data for output in self.outputs])
def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite")
def c_code(self, inames, onames, sub):
d = dict(zip(["i%i"%i for i in range(len(inames))],
inames) +
zip(["o%i"%i for i in range(len(onames))],
onames),
**sub)
return _c_code % d
return Composite
from scalar import *
from gof import PatternOptimizer
c2 = constant(2.0)
opt1 = PatternOptimizer((Mul, 'x', 'x'), (Sqr, 'x'))
opt2 = PatternOptimizer((Pow, 'x', c2), (Sqr, 'x'))
...@@ -153,7 +153,7 @@ class _Op(BaseTensorOp): ...@@ -153,7 +153,7 @@ class _Op(BaseTensorOp):
return self.c_impl(self.inputs, self.outputs) % sub return self.c_impl(self.inputs, self.outputs) % sub
def c_impl(self, inputs, outputs): def c_impl(self, inputs, outputs):
raise AbstractFunctionError() raise AbstractFunctionError("No c_impl for %s" % self.__class__.__name__)
class _Unary: class _Unary:
nin = 1 nin = 1
...@@ -420,24 +420,22 @@ class Gemm(_Op): ...@@ -420,24 +420,22 @@ class Gemm(_Op):
raise NotImplementedError() raise NotImplementedError()
def c_support_code(self): def c_support_code(self):
return blas.cblas_header_text() #return blas.cblas_header_text()
mod_str = """
#ifndef MOD
#define MOD %
#endif
"""
return blas.blas_proto() + mod_str
def c_headers(self):
return ['<iostream>']
def c_libraries(self): def c_libraries(self):
return blas.ldflags() return blas.ldflags()
def c_var_names(self): #def c_var_names(self):
return [['_z', '_a', '_x', '_y', '_b'], ['_zout']] # return [['_z', '_a', '_x', '_y', '_b'], ['_zout']]
def c_validate_update(self, (_z, _a, _x, _y, _b), (_zout, ), sub): def c_validate_update(self, *args):
return """ return ""
if (%(_zout)s != %(_z)s) def c_validate_update_cleanup(self, *args):
{
if (%(_zout)s)
{
Py_DECREF(%(_zout)s);
}
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}
""" % locals()
def c_validate_update_cleanup(self, ignore, _ignore, __ignore):
return "" return ""
def c_code(self, (_z, _a, _x, _y, _b), (_zout, ), sub): def c_code(self, (_z, _a, _x, _y, _b), (_zout, ), sub):
return """ return """
...@@ -454,14 +452,22 @@ class Gemm(_Op): ...@@ -454,14 +452,22 @@ class Gemm(_Op):
npy_intp* Sy = %(_y)s->strides; npy_intp* Sy = %(_y)s->strides;
npy_intp* Sz = %(_z)s->strides; npy_intp* Sz = %(_z)s->strides;
size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; //strides for x, y, z in dimensions 0, 1
int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1;
if (%(_x)s->nd != 2) if (%(_zout)s != %(_z)s)
{PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;} {
if (%(_y)s->nd != 2) if (%(_zout)s)
{PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;} {
if (%(_z)s->nd != 2) Py_DECREF(%(_zout)s);
{PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;} }
%(_zout)s = %(_z)s;
Py_INCREF(%(_zout)s);
}
if (%(_x)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); %(fail)s;}
if (%(_y)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); %(fail)s;}
if (%(_z)s->nd != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(z) != 2"); %(fail)s;}
if ((%(_a)s->descr->type_num != PyArray_DOUBLE) if ((%(_a)s->descr->type_num != PyArray_DOUBLE)
&& (%(_a)s->descr->type_num != PyArray_FLOAT)) && (%(_a)s->descr->type_num != PyArray_FLOAT))
...@@ -473,19 +479,19 @@ class Gemm(_Op): ...@@ -473,19 +479,19 @@ class Gemm(_Op):
if ((%(_x)s->descr->type_num != PyArray_DOUBLE) if ((%(_x)s->descr->type_num != PyArray_DOUBLE)
&& (%(_x)s->descr->type_num != PyArray_FLOAT)) && (%(_x)s->descr->type_num != PyArray_FLOAT))
%(fail)s; {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE) if ((%(_y)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT)) && (%(_y)s->descr->type_num != PyArray_FLOAT))
%(fail)s; {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((%(_y)s->descr->type_num != PyArray_DOUBLE) if ((%(_z)s->descr->type_num != PyArray_DOUBLE)
&& (%(_y)s->descr->type_num != PyArray_FLOAT)) && (%(_z)s->descr->type_num != PyArray_FLOAT))
%(fail)s; {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num) if ((%(_x)s->descr->type_num != %(_y)s->descr->type_num)
||(%(_x)s->descr->type_num != %(_z)s->descr->type_num)) ||(%(_x)s->descr->type_num != %(_z)s->descr->type_num))
%(fail)s; { PyErr_SetString(PyExc_NotImplementedError, "type(z), type(y), type(z) are not all the same"); %(fail)s; }
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{ {
...@@ -496,17 +502,15 @@ class Gemm(_Op): ...@@ -496,17 +502,15 @@ class Gemm(_Op):
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size) || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size)) || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size))
{ {
PyErr_SetString(PyExc_ValueError, "gemm cant run on these inputs"); PyErr_SetString(PyExc_ValueError, "stride is not multiple of element size"); %(fail)s;
%(fail)s;
} }
/* /*
encode the stride structure of _x,_y,_z into a single integer encode the stride structure of _x,_y,_z into a single integer
*/ */
unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0; unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 8;
unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4; unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4;
unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8; unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 0;
/* create appropriate strides for malformed matrices that are row or column /* create appropriate strides for malformed matrices that are row or column
* vectors * vectors
...@@ -533,18 +537,21 @@ class Gemm(_Op): ...@@ -533,18 +537,21 @@ class Gemm(_Op):
float* x = (float*)PyArray_DATA(%(_x)s); float* x = (float*)PyArray_DATA(%(_x)s);
float* y = (float*)PyArray_DATA(%(_y)s); float* y = (float*)PyArray_DATA(%(_y)s);
float* z = (float*)PyArray_DATA(%(_z)s); float* z = (float*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit) switch(unit)
{ {
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: %(fail)s; default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
}; };
#undef REAL #undef REAL
} }
...@@ -562,17 +569,21 @@ class Gemm(_Op): ...@@ -562,17 +569,21 @@ class Gemm(_Op):
double* x = (double*)PyArray_DATA(%(_x)s); double* x = (double*)PyArray_DATA(%(_x)s);
double* y = (double*)PyArray_DATA(%(_y)s); double* y = (double*)PyArray_DATA(%(_y)s);
double* z = (double*)PyArray_DATA(%(_z)s); double* z = (double*)PyArray_DATA(%(_z)s);
char N = 'N';
char T = 'T';
int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1];
//std::cerr << (unit/256) MOD 16 << (unit / 16) MOD 16 << unit MOD 16<< '\\n';
switch(unit) switch(unit)
{ {
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break;
case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break;
case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break;
case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break;
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break;
default: %(fail)s; default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s;
}; };
#undef REAL #undef REAL
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论