提交 19de7f5d authored 作者: bergstrj@iro.umontreal.ca's avatar bergstrj@iro.umontreal.ca

merged

import unittest
from core import build_mode, pop_mode
from ops import *
class _testCase_add_build_mode(unittest.TestCase):
def setUp(self):
build_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
class _testCase_dot(unittest.TestCase):
def setUp(self):
build_eval_mode()
numpy.random.seed(44)
def tearDown(self):
pop_mode()
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp_dot(self,x,y):
if 0:
def spec(x):
x = numpy.asarray(x)
return type(x), x.dtype, x.shape
zspec = dot.specs(spec(x), spec(y))
nz = numpy.dot(x,y)
self.failUnless(zspec == spec(nz))
self.failUnless(_approx_eq(dot(x,y), numpy.dot(x,y)))
def cmp_dot_comp(self, x,y):
x = numpy.asarray(x)
y = numpy.asarray(y)
z = dot(x,y)
p = compile.single(z)
if len(x.shape):
x[:] = numpy.random.rand(*x.shape)
else:
x.fill(numpy.random.rand(*x.shape))
if len(y.shape):
y[:] = numpy.random.rand(*y.shape)
else:
y.fill(numpy.random.rand(*y.shape))
p() # recalculate z
self.failUnless(_approx_eq(z, numpy.dot(x,y)))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_0d_0d_(self): self.cmp_dot_comp(1.1, 2.2)
def test_dot_0d_1d_(self): self.cmp_dot_comp(1.1, self.rand(5))
def test_dot_0d_2d_(self): self.cmp_dot_comp(3.0, self.rand(6,7))
def test_dot_0d_3d_(self): self.cmp_dot_comp(3.0, self.rand(8,6,7))
def test_dot_1d_0d_(self): self.cmp_dot_comp(self.rand(5), 1.1 )
def test_dot_1d_1d_(self): self.cmp_dot_comp(self.rand(5), self.rand(5))
def test_dot_1d_2d_(self): self.cmp_dot_comp(self.rand(6), self.rand(6,7))
def test_dot_1d_3d_(self): self.cmp_dot_comp(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d_(self): self.cmp_dot_comp(self.rand(5,6), 1.0)
def test_dot_2d_1d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6))
def test_dot_2d_2d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d_(self): self.cmp_dot_comp(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d_(self): self.cmp_dot_comp(self.rand(4,5,6), 1.0)
def test_dot_3d_1d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d_(self): self.cmp_dot_comp(self.rand(4,5,6), self.rand(8,6,7))
def test_dot_fail_1_1(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_1_2(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_1_3(self):
x = numpy.random.rand(5)
y = numpy.random.rand(6,4,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_1(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_2(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_2_3(self):
x = numpy.random.rand(5,4)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_1(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_2(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
def test_dot_fail_3_3(self):
x = numpy.random.rand(5,4,3)
y = numpy.random.rand(6,7,8)
try:
z = dot(x,y)
except ValueError, e:
self.failUnless(str(e) == 'objects are not aligned', e)
return
self.fail()
class gemm(omega_op):
def destroy_map(self):
return {self.out:[self.inputs[0]]}
def impl(z, a, x, y, b):
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
return z[:]
def grad(z, a, x, y, b, gz):
raise NotImplemented
def refresh(self, alloc = False):
z,a,x,y,b = self.inputs
self.out.shape = z.shape
self.out.dtype = z.dtype
if alloc:
self.out.data = z.data
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl((_zin, _a, _x, _y, _b), (_z,)):
check_ab = """
{
if ((_a->descr->type_num != PyArray_DOUBLE)
&& (_a->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_b->descr->type_num != PyArray_DOUBLE)
&& (_b->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
}
"""
return blas.gemm_code( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
if __name__ == '__main__':
unittest.main()
...@@ -53,6 +53,20 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001 ...@@ -53,6 +53,20 @@ def verify_grad(testcase, op_cls, pt, n_tests=1, rng=numpy.random, eps=0.0000001
verify_grad.E_grad = 'gradient error exceeded tolerance' verify_grad.E_grad = 'gradient error exceeded tolerance'
#useful mostly for unit tests
def _approx_eq(a,b,eps=1.0e-9):
a = numpy.asarray(a)
b = numpy.asarray(b)
if a.shape != b.shape:
if _approx_eq.debug:
print a.shape, b.shape
return False
if numpy.max(numpy.abs(a-b)) >= eps:
if _approx_eq.debug:
print a, b
return False
return True
_approx_eq.debug = 0
def check_eq(self, node_in, node_out, arg_in, arg_out): def check_eq(self, node_in, node_out, arg_in, arg_out):
fn = Function([node_in], [node_out]) fn = Function([node_in], [node_out])
...@@ -97,16 +111,16 @@ class T_argmax(unittest.TestCase): ...@@ -97,16 +111,16 @@ class T_argmax(unittest.TestCase):
n = astensor(numpy.random.rand(2,3)) n = astensor(numpy.random.rand(2,3))
try: try:
eval_outputs(argmax(n,axis=3)) eval_outputs(argmax(n,axis=3))
self.fail()
except ValueError, e: except ValueError, e:
return return
self.fail()
def test2_invalid_neg(self): def test2_invalid_neg(self):
n = astensor(numpy.random.rand(2,3)) n = astensor(numpy.random.rand(2,3))
try: try:
eval_outputs(argmax(n,axis=-3)) eval_outputs(argmax(n,axis=-3))
self.fail()
except ValueError, e: except ValueError, e:
return return
self.fail()
def test2_valid_neg(self): def test2_valid_neg(self):
n = astensor(numpy.random.rand(2,3)) n = astensor(numpy.random.rand(2,3))
v,i = eval_outputs(argmax(n,axis=-1)) v,i = eval_outputs(argmax(n,axis=-1))
...@@ -178,19 +192,20 @@ class T_subtensor(unittest.TestCase): ...@@ -178,19 +192,20 @@ class T_subtensor(unittest.TestCase):
n = astensor(numpy.ones(())) n = astensor(numpy.ones(()))
try: try:
t = n[0] t = n[0]
self.fail()
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid) self.failUnless(e[0] is Subtensor.e_invalid)
return
self.fail()
def test1_err_bounds(self): def test1_err_bounds(self):
n = astensor(numpy.ones(3)) n = astensor(numpy.ones(3))
t = n[7] t = n[7]
self.failUnless(t.owner.__class__ is Subtensor) self.failUnless(t.owner.__class__ is Subtensor)
try: try:
tval = eval_outputs([t]) tval = eval_outputs([t])
self.fail()
except Exception, e: except Exception, e:
if e[0] != 'index out of bounds': if e[0] != 'index out of bounds':
raise raise
self.fail()
def test1_ok_range_finite(self): def test1_ok_range_finite(self):
n = astensor(numpy.ones(3)*5) n = astensor(numpy.ones(3)*5)
t = n[0:2] t = n[0:2]
...@@ -209,9 +224,9 @@ class T_subtensor(unittest.TestCase): ...@@ -209,9 +224,9 @@ class T_subtensor(unittest.TestCase):
n = astensor(numpy.ones(1)) n = astensor(numpy.ones(1))
try: try:
t = n[0,0] t = n[0,0]
self.fail()
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is Subtensor.e_invalid) self.failUnless(e[0] is Subtensor.e_invalid)
self.fail()
def test1_ok_elem(self): def test1_ok_elem(self):
n = astensor(numpy.ones(1)*5) n = astensor(numpy.ones(1)*5)
t = n[0] t = n[0]
...@@ -244,9 +259,9 @@ class T_subtensor(unittest.TestCase): ...@@ -244,9 +259,9 @@ class T_subtensor(unittest.TestCase):
self.failUnless(t.owner.__class__ is Subtensor) self.failUnless(t.owner.__class__ is Subtensor)
try: try:
tval = eval_outputs([t]) tval = eval_outputs([t])
self.fail()
except IndexError, e: except IndexError, e:
return return
self.fail()
def test2_err_bounds1(self): def test2_err_bounds1(self):
n = astensor(numpy.ones((2,3))*5) n = astensor(numpy.ones((2,3))*5)
t = n[4:5,2] t = n[4:5,2]
...@@ -356,9 +371,10 @@ class T_abs(unittest.TestCase): ...@@ -356,9 +371,10 @@ class T_abs(unittest.TestCase):
def test_badgrad(self): def test_badgrad(self):
try: try:
verify_grad(self, T_abs.AbsBadGrad, [numpy.ones(())]) verify_grad(self, T_abs.AbsBadGrad, [numpy.ones(())])
self.fail()
except Exception, e: except Exception, e:
self.failUnless(str(e) == verify_grad.E_grad, str(e)) self.failUnless(str(e) == verify_grad.E_grad, str(e))
return
self.fail()
class T_fill(unittest.TestCase): class T_fill(unittest.TestCase):
def test0(self): def test0(self):
...@@ -425,9 +441,10 @@ class T_mul(unittest.TestCase): ...@@ -425,9 +441,10 @@ class T_mul(unittest.TestCase):
try: try:
check_eq2(self, [a,b], MulElemwise(a,b).out, check_eq2(self, [a,b], MulElemwise(a,b).out,
[numpy.ones(3), numpy.ones(4)], 1.0) [numpy.ones(3), numpy.ones(4)], 1.0)
self.fail()
except ValueError, e: except ValueError, e:
self.failUnless(e[0] is tensor._assert_same_shapes.E_shape) self.failUnless(e[0] is tensor._assert_same_shapes.E_shape)
return
self.fail()
class T_div(unittest.TestCase): class T_div(unittest.TestCase):
def setUp(self): def setUp(self):
...@@ -501,5 +518,124 @@ class _testCase_matinv(unittest.TestCase): ...@@ -501,5 +518,124 @@ class _testCase_matinv(unittest.TestCase):
"""Matrix reciprocal by gradient descent""" """Matrix reciprocal by gradient descent"""
self.assertEqual(('6.10141615619', '0.00703816291711'), self.mat_reciprocal(3)) self.assertEqual(('6.10141615619', '0.00703816291711'), self.mat_reciprocal(3))
class t_dot(unittest.TestCase):
def setUp(self):
numpy.random.seed(44)
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp_dot(self,x,y):
#x, y are matrices or numbers
def spec(x):
x = numpy.asarray(x)
return type(x), x.dtype, x.shape
nz = numpy.dot(x,y)
tz = eval_outputs([dot(tinit(x), tinit(y))])
self.failUnless(tz.dtype == nz.dtype)
self.failUnless(tz.shape == nz.shape)
self.failUnless(_approx_eq(nz, tz))
def test_dot_0d_0d(self): self.cmp_dot(1.1, 2.2)
def test_dot_0d_1d(self): self.cmp_dot(1.1, self.rand(5))
def test_dot_0d_2d(self): self.cmp_dot(3.0, self.rand(6,7))
def test_dot_0d_3d(self): self.cmp_dot(3.0, self.rand(8,6,7))
def test_dot_1d_0d(self): self.cmp_dot(self.rand(5), 1.1 )
def test_dot_1d_1d(self): self.cmp_dot(self.rand(5), self.rand(5))
def test_dot_1d_2d(self): self.cmp_dot(self.rand(6), self.rand(6,7))
def test_dot_1d_3d(self): self.cmp_dot(self.rand(6), self.rand(8,6,7))
def test_dot_2d_0d(self): self.cmp_dot(self.rand(5,6), 1.0)
def test_dot_2d_1d(self): self.cmp_dot(self.rand(5,6), self.rand(6))
def test_dot_2d_2d(self): self.cmp_dot(self.rand(5,6), self.rand(6,7))
def test_dot_2d_3d(self): self.cmp_dot(self.rand(5,6), self.rand(8,6,7))
def test_dot_3d_0d(self): self.cmp_dot(self.rand(4,5,6), 1.0)
def test_dot_3d_1d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6))
def test_dot_3d_2d(self): self.cmp_dot(self.rand(4,5,6), self.rand(6,7))
def test_dot_3d_3d(self): self.cmp_dot(self.rand(4,5,6), self.rand(8,6,7))
def not_aligned(self, x, y):
z = dot(x,y)
try:
tz = eval_outputs([z])
except ValueError, e:
self.failUnless(e[0] == 'objects are not aligned', e)
return
self.fail()
def test_align_1_1(self): self.not_aligned(self.rand(5), self.rand(6))
def test_align_1_2(self): self.not_aligned(self.rand(5), self.rand(6,4))
def test_align_1_3(self): self.not_aligned(self.rand(5), self.rand(6,4,7))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6))
def test_align_2_1(self): self.not_aligned(self.rand(5,4), self.rand(6,7))
def test_align_2_3(self): self.not_aligned(self.rand(5,4), self.rand(6,7,8))
def test_align_3_1(self): self.not_aligned(self.rand(5,4,3), self.rand(6))
def test_align_3_2(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7))
def test_align_3_3(self): self.not_aligned(self.rand(5,4,3), self.rand(6,7,8))
class t_gemm(unittest.TestCase):
def setUp(self):
numpy.random.seed(44)
_approx_eq.debug = 0
@staticmethod
def _gemm(z,a,x,y,b):
assert a.shape == ()
assert b.shape == ()
return b * z + a * numpy.dot(x,y)
@staticmethod
def rand(*args):
return numpy.random.rand(*args)
def cmp(self, z, a, x, y, b):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b]
cz = z.copy()
tz,ta,tx,ty,tb = [tinit(p) for p in z,a,x,y,b]
f = Function([tz,ta,tx,ty,tb], [gemm(tz,ta,tx,ty,tb)])
new_z = f(z,a,x,y,b)
_z = self._gemm(cz, a, x, y, b)
self.failUnless(z is new_z)
#print cz, _z, z, type(cz), type(_z), type(z)
#_approx_eq.debug = 1
self.failUnless(_approx_eq(_z, z))
if a == 0.0 and b == 1.0:
return
else:
self.failIf(numpy.all(cz == z))
def test0(self): self.cmp(1., 0., 1.0, 1.0, 1.0)
def test1(self): self.cmp(2., 0., 1.0, 1.0, 0.0)
def test2(self):
try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0)
except ValueError, e:
self.failUnless(e[0] == Gemm.E_bcast)
return
self.fail()
def test3(self): self.cmp([2.], 1.,[3,2,1.], [[1],[2],[3.]], 1.0)
def test4(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test6(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0)
def test8(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.6)
def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0)
def test10(self):
_approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0)
def test12(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
差异被折叠。
...@@ -4,10 +4,9 @@ def test_root_dir(): ...@@ -4,10 +4,9 @@ def test_root_dir():
suite = None suite = None
filenames = os.listdir('.') filenames = os.listdir('.')
for filename in filenames: for filename in filenames:
if filename[-3:] == '.py': if filename[-3:] == '.py' and filename[0:5] == '_test':
modname = filename[:-3]
if modname in ['__init__', 'autotest']: continue
#print >>sys.stderr, 'Loading', modname #print >>sys.stderr, 'Loading', modname
modname = filename[0:-3]
tests = unittest.TestLoader().loadTestsFromModule(__import__(modname)) tests = unittest.TestLoader().loadTestsFromModule(__import__(modname))
if tests.countTestCases() > 0: if tests.countTestCases() > 0:
print >>sys.stderr, 'Testing', modname print >>sys.stderr, 'Testing', modname
......
import os, sys import os, sys
from gof import PatternOptimizer as pattern_opt, OpSubOptimizer as op_sub
import scipy.weave as weave import scipy.weave as weave
""" """
......
from core import *
import gof
from gof import PatternOptimizer as pattern_opt, OpSubOptimizer as op_sub
""" """
This variable is used in compile.prog as the optimizer for all programs built This variable is used in compile.prog as the optimizer for all programs built
using either compile.single, compile.to_func, and compile.prog. using either compile.single, compile.to_func, and compile.prog.
"""
def optimizer(lst): if 0:
begin = gof.SeqOptimizer([]) def optimizer(lst):
end = gof.SeqOptimizer([gof.DummyRemover]) begin = gof.SeqOptimizer([])
seq_opt = gof.SeqOptimizer(begin + lst + end) end = gof.SeqOptimizer([gof.DummyRemover])
return gof.PythonOpt(gof.MergeOptMerge(seq_opt)) seq_opt = gof.SeqOptimizer(begin + lst + end)
return gof.PythonOpt(gof.MergeOptMerge(seq_opt))
if 0: if 0:
optimizer_begin = gof.SeqOptimizer([opt for name, opt in [ optimizer_begin = gof.SeqOptimizer([opt for name, opt in [
...@@ -34,3 +29,4 @@ if 0: ...@@ -34,3 +29,4 @@ if 0:
(iadd_elemwise, 'y', 'x'))]]]) (iadd_elemwise, 'y', 'x'))]]])
# ['remove_copies', gof.OpRemover(array_copy)], # ['remove_copies', gof.OpRemover(array_copy)],
# [None, gof.DummyRemover] # has to be at the end # [None, gof.DummyRemover] # has to be at the end
"""
...@@ -3,13 +3,11 @@ import unittest ...@@ -3,13 +3,11 @@ import unittest
import numpy import numpy
from scipy import sparse from scipy import sparse
import gof.lib import gof
import core
import grad
# Wrapper type # Wrapper type
class SparseR(core.ResultBase): class SparseR(gof.ResultBase):
""" """
Attribute: Attribute:
format - a subclass of sparse.spmatrix indicating self.data.__class__ format - a subclass of sparse.spmatrix indicating self.data.__class__
......
...@@ -690,6 +690,101 @@ pow = _scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l) ...@@ -690,6 +690,101 @@ pow = _scalar_switch(pow_elemwise, pow_scalar_r, pow_scalar_l)
pow_inplace = _scalar_switch(pow_elemwise_inplace, pow_scalar_r_inplace) pow_inplace = _scalar_switch(pow_elemwise_inplace, pow_scalar_r_inplace)
#########################
# Linalg : Dot
#########################
class Dot(_Op):
nin=2
nout=1
@staticmethod
def broadcastable_rule(bx,by):
if len(bx) == 0: # x is a scalar
rval = by
else:
if len(by) >= 2: #y is a matrix or tensor
rval = bx[:-1] + by[:-2] + by[-1:]
elif len(by)==1: #y is vector
rval = bx[:-1]
else: #y is a scalar
rval = bx
return rval
def propagate_broadcastable(self, bx, by):
return [self.broadcastable_rule(bx,by)]
def impl(self, x, y):
return numpy.dot(x, y)
def grad(self, (x, y), gz):
return dot(gz, y.T), dot(x.T, gz)
if 0:
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl(self, (_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0')
dot = _constructor(Dot)
class Gemm(_Op):
nin=5
nout=1
E_bcast = 'incompatible broadcastable flags'
def destroy_map(self):
return {self.out:[self.inputs[0]]}
def propagate_broadcastable(self, bz, ba, bx, by, bb):
if len(bz) != len(Dot.broadcastable_rule(bx,by)):
raise ValueError(Gemm.E_bcast, bz, bx, by)
return [bz]
def impl(self, z, a, x, y, b):
assert a.shape == ()
assert b.shape == ()
if z.shape == ():
z.itemset(z*a + b*numpy.dot(x,y))
return z
else:
if b == 0.0:
if a == 1.0:
z[:] = numpy.dot(x,y)
elif a == -1.0:
z[:] = -numpy.dot(x,y)
else:
z[:] = a * numpy.dot(x,y)
elif b == 1.0:
if a == 1.0:
z += numpy.dot(x,y)
elif a == -1.0:
z -= numpy.dot(x,y)
else:
z += a * numpy.dot(x,y)
else:
z *= b
z += a * numpy.dot(x,y)
return z
def grad(self, (z, a, x, y, b), gz):
raise NotImplementedError()
if 0:
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl((_zin, _a, _x, _y, _b), (_z,)):
check_ab = """
{
if ((_a->descr->type_num != PyArray_DOUBLE)
&& (_a->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
if ((_b->descr->type_num != PyArray_DOUBLE)
&& (_b->descr->type_num != PyArray_FLOAT))
goto _dot_execute_fallback;
}
"""
return blas.gemm_code( check_ab,
'(_a->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_a->data)[0]) : (REAL)(((double*)_a->data)[0])',
'(_b->descr->type_num == PyArray_FLOAT) ? (REAL)(((float*)_b->data)[0]) : (REAL)(((double*)_b->data)[0])')
gemm = _constructor(Gemm)
if 0: if 0:
########################## ##########################
# Comparisons # Comparisons
......
from gof import Op, utils, Destroyer, Viewer
import gof.op
from tensor import *
###########################
#### Binary Operations ####
###########################
#########
## Dot ##
#########
class Dot(TensorOp):
@staticmethod
def _output_shape(xshape, yshape):
# This describes the logic to calculate numpy.dot(x, y).shape
# given x.shape and y.shape
if len(xshape) == 0: # x is a scalar
shape = yshape
else:
if len(yshape) >= 2: #y is a matrix or tensor
assert xshape[-1] == yshape[-2]
shape = tuple(xshape[:-1]+ yshape[:-2]+yshape[-1:])
elif len(yshape)==1: #y is vector
assert xshape[-1] == yshape[-1]
shape = tuple(xshape[:-1])
else: #y is a scalar
shape = xshape
return shape
def impl(self, x, y):
return numpy.dot(x, y)
def grad(self, (x, y), gz):
return dot(gz, transpose(y)), dot(transpose(x), gz)
def propagate_broadcastable(self, x, y):
assert len(x) == 2 and len(x) == len(y)
return [(x[0], y[1])]
def c_support_code(self):
return blas.cblas_header_text()
def c_libs(self):
return blas.ldflags()
def c_impl(self, (_x, _y), (_z, )):
return blas.gemm_code('', '1.0', '0.0')
############
## Others ##
############
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论