提交 76f2757a authored 作者: Frederic's avatar Frederic

pep8

上级 0c0059e9
#from nose.plugins.skip import SkipTest #from nose.plugins.skip import SkipTest
#import traceback #import traceback
import itertools, sys import itertools
import sys
import theano.tensor as T import theano.tensor as T
from theano import tensor from theano import tensor
from theano.gof.python25 import product as itertools_product from theano.gof.python25 import product as itertools_product
...@@ -40,21 +41,27 @@ mode_blas_opt = theano.compile.get_default_mode().including( ...@@ -40,21 +41,27 @@ mode_blas_opt = theano.compile.get_default_mode().including(
'BlasOpt', 'specialize', 'InplaceBlasOpt') 'BlasOpt', 'specialize', 'InplaceBlasOpt')
mode_blas_opt = mode_blas_opt.excluding('c_blas') mode_blas_opt = mode_blas_opt.excluding('c_blas')
def test_dot_eq(): def test_dot_eq():
assert T.Dot() == T.Dot() assert T.Dot() == T.Dot()
class t_gemm(TestCase): class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it is supposed to.""" """This test suite is supposed to establish that gemm works as it
is supposed to.
"""
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
_approx_eq.debug = 0 _approx_eq.debug = 0
Gemm.debug = False Gemm.debug = False
@staticmethod @staticmethod
def _gemm(z,a,x,y,b): def _gemm(z, a, x, y, b):
assert a.shape == () assert a.shape == ()
assert b.shape == () assert b.shape == ()
return b * z + a * numpy.dot(x,y) return b * z + a * numpy.dot(x, y)
@staticmethod @staticmethod
def rand(*args): def rand(*args):
return numpy.random.rand(*args) return numpy.random.rand(*args)
...@@ -66,13 +73,17 @@ class t_gemm(TestCase): ...@@ -66,13 +73,17 @@ class t_gemm(TestCase):
x = numpy.asarray(x_, dtype=dtype) x = numpy.asarray(x_, dtype=dtype)
y = numpy.asarray(y_, dtype=dtype) y = numpy.asarray(y_, dtype=dtype)
b = numpy.asarray(b_, dtype=dtype) b = numpy.asarray(b_, dtype=dtype)
def cmp_linker(z, a, x, y, b, l): def cmp_linker(z, a, x, y, b, l):
z,a,x,y,b = [numpy.asarray(p) for p in z,a,x,y,b] z, a, x, y, b = [numpy.asarray(p) for p in z, a, x, y, b]
z_orig = z.copy() z_orig = z.copy()
tz,ta,tx,ty,tb = [as_tensor_variable(p).type() for p in z,a,x,y,b] tz, ta, tx, ty, tb = [as_tensor_variable(p).type()
for p in z, a, x, y, b]
f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode=compile.Mode(optimizer = None, linker = l)) f = inplace_func([tz, ta, tx, ty, tb],
new_z = f(z,a,x,y,b) gemm_inplace(tz, ta, tx, ty, tb),
mode=compile.Mode(optimizer=None, linker=l))
new_z = f(z, a, x, y, b)
z_after = self._gemm(z_orig, a, x, y, b) z_after = self._gemm(z_orig, a, x, y, b)
#print z_orig, z_after, z, type(z_orig), type(z_after), type(z) #print z_orig, z_after, z, type(z_orig), type(z_after), type(z)
...@@ -81,14 +92,15 @@ class t_gemm(TestCase): ...@@ -81,14 +92,15 @@ class t_gemm(TestCase):
if a == 0.0 and b == 1.0: if a == 0.0 and b == 1.0:
return return
elif z_orig.size == 0: elif z_orig.size == 0:
self.assertTrue(z.size==0) self.assertTrue(z.size == 0)
else: else:
self.assertFalse(numpy.all(z_orig == z)) self.assertFalse(numpy.all(z_orig == z))
cmp_linker(copy(z), a, x, y, b, 'c|py') cmp_linker(copy(z), a, x, y, b, 'c|py')
cmp_linker(copy(z), a, x, y, b, 'py') cmp_linker(copy(z), a, x, y, b, 'py')
if config.blas.ldflags and not dtype.startswith("complex"): if config.blas.ldflags and not dtype.startswith("complex"):
# If blas.ldflags is equal to '', the C code will not be generated # If blas.ldflags is equal to '', the C code will not
# be generated
cmp_linker(copy(z), a, x, y, b, 'c') cmp_linker(copy(z), a, x, y, b, 'c')
def test0a(self): def test0a(self):
...@@ -110,99 +122,123 @@ class t_gemm(TestCase): ...@@ -110,99 +122,123 @@ class t_gemm(TestCase):
def test2(self): def test2(self):
try: try:
self.cmp(2., 1.0, [3,2,1.], [[1],[2],[3.]], 1.0) self.cmp(2., 1.0, [3, 2, 1.], [[1], [2], [3.]], 1.0)
except TypeError, e: except TypeError, e:
self.assertTrue(e[0] == Gemm.E_rank) self.assertTrue(e[0] == Gemm.E_rank)
return return
self.fail() self.fail()
def test4(self): def test4(self):
self.cmp(self.rand(3,4), 1.0, self.rand(3,5), self.rand(5,4), 0.0) self.cmp(self.rand(3, 4), 1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
def test5(self): self.cmp(self.rand(3,4), 1.0,
self.rand(3,5), self.rand(5,4), 1.0) def test5(self):
def test6(self): self.cmp(self.rand(3,4), 1.0, self.cmp(self.rand(3, 4), 1.0,
self.rand(3,5), self.rand(5,4), -1.0) self.rand(3, 5), self.rand(5, 4), 1.0)
def test7(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), 0.0) def test6(self):
def test8(self): self.cmp(self.rand(3,4), 0.0, self.cmp(self.rand(3, 4), 1.0,
self.rand(3,5), self.rand(5,4), 0.6) self.rand(3, 5), self.rand(5, 4), -1.0)
def test9(self): self.cmp(self.rand(3,4), 0.0,
self.rand(3,5), self.rand(5,4), -1.0) def test7(self):
self.cmp(self.rand(3, 4), 0.0,
self.rand(3, 5), self.rand(5, 4), 0.0)
def test8(self):
self.cmp(self.rand(3, 4), 0.0,
self.rand(3, 5), self.rand(5, 4), 0.6)
def test9(self):
self.cmp(self.rand(3, 4), 0.0,
self.rand(3, 5), self.rand(5, 4), -1.0)
def test10(self): def test10(self):
_approx_eq.debug = 1 _approx_eq.debug = 1
self.cmp(self.rand(3,4), -1.0, self.rand(3,5), self.rand(5,4), 0.0) self.cmp(self.rand(3, 4), -1.0, self.rand(3, 5), self.rand(5, 4), 0.0)
def test11(self): self.cmp(self.rand(3,4), -1.0,
self.rand(3,5), self.rand(5,4), 1.0) def test11(self):
def test12(self): self.cmp(self.rand(3,4), -1.0, self.cmp(self.rand(3, 4), -1.0,
self.rand(3,5), self.rand(5,4), -1.0) self.rand(3, 5), self.rand(5, 4), 1.0)
def test12(self):
self.cmp(self.rand(3, 4), -1.0,
self.rand(3, 5), self.rand(5, 4), -1.0)
def test_shape_0(self): def test_shape_0(self):
self.cmp(self.rand(0,4), -1.0, self.rand(0,5), self.rand(5,4), -1.0) self.cmp(self.rand(0, 4), -1.0, self.rand(0, 5), self.rand(5, 4), -1.0)
self.cmp(self.rand(3,0), -1.0, self.rand(3,5), self.rand(5,0), -1.0) self.cmp(self.rand(3, 0), -1.0, self.rand(3, 5), self.rand(5, 0), -1.0)
self.cmp(self.rand(3,4), -1.0, self.rand(3,0), self.rand(0,4), -1.0) self.cmp(self.rand(3, 4), -1.0, self.rand(3, 0), self.rand(0, 4), -1.0)
self.cmp(self.rand(0,0), -1.0, self.rand(0,5), self.rand(5,0), -1.0) self.cmp(self.rand(0, 0), -1.0, self.rand(0, 5), self.rand(5, 0), -1.0)
self.cmp(self.rand(0,0), -1.0, self.rand(0,0), self.rand(0,0), -1.0) self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0)
def test_factorised_scalar(self): def test_factorised_scalar(self):
a=T.dmatrix() a = T.dmatrix()
b=T.dmatrix() b = T.dmatrix()
c=T.dmatrix() c = T.dmatrix()
s=theano.shared(numpy.zeros((5,5))) s = theano.shared(numpy.zeros((5, 5)))
lr1=T.constant(0.01).astype('float64') lr1 = T.constant(0.01).astype('float64')
lr2=T.constant(2).astype('float64') lr2 = T.constant(2).astype('float64')
l2_reg=T.constant(0.0001).astype('float64') l2_reg = T.constant(0.0001).astype('float64')
#test constant merge with gemm #test constant merge with gemm
f = theano.function([a,b],updates={s:lr1*T.dot(a,b)+l2_reg*lr2*s},mode=mode_not_fast_compile).maker.env.toposort() f = theano.function([a, b], updates={s: lr1 * T.dot(a, b) +
l2_reg * lr2 * s},
mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 2e-06)] #[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 2e-06)]
assert len(f)==1 assert len(f) == 1
assert f[0].op==gemm_inplace assert f[0].op == gemm_inplace
#test factored scalar with merge #test factored scalar with merge
f = theano.function([a,b],updates={s:lr1*(T.dot(a,b)-l2_reg*s)},mode=mode_not_fast_compile).maker.env.toposort() f = theano.function([a, b], updates={s: lr1 * (T.dot(a, b) -
l2_reg * s)},
mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, -2e-06)] #[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, -2e-06)]
assert len(f)==1 assert len(f) == 1
assert f[0].op==gemm_inplace assert f[0].op == gemm_inplace
#test factored scalar with merge and neg #test factored scalar with merge and neg
f = theano.function([a,b],updates={s:s-lr1*(s*.0002+T.dot(a,b))},mode=mode_not_fast_compile).maker.env.toposort() f = theano.function([a,b],updates={s:s-lr1*(s*.0002+T.dot(a,b))},
mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 0.999998)] #[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 0.999998)]
assert len(f)==1 assert len(f) == 1
assert f[0].op==gemm_inplace assert f[0].op == gemm_inplace
def test_destroy_map0(self): def test_destroy_map0(self):
"""test that only first input can be overwritten""" """test that only first input can be overwritten"""
Z = as_tensor_variable(self.rand(2,2)) Z = as_tensor_variable(self.rand(2, 2))
try: try:
gemm_inplace(Z, 1.0, Z, Z, 1.0) gemm_inplace(Z, 1.0, Z, Z, 1.0)
except InconsistencyError, e: except InconsistencyError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
self.fail() self.fail()
def test_destroy_map1(self): def test_destroy_map1(self):
"""test that only first input can be overwritten""" """test that only first input can be overwritten"""
Z = as_tensor_variable(self.rand(2,2)) Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2,2)) A = as_tensor_variable(self.rand(2, 2))
try: try:
gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0) gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except InconsistencyError, e: except InconsistencyError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
self.fail() self.fail()
def test_destroy_map2(self): def test_destroy_map2(self):
"""test that only first input can be overwritten""" """test that only first input can be overwritten"""
Z = as_tensor_variable(self.rand(2,2)) Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2,2)) A = as_tensor_variable(self.rand(2, 2))
try: try:
gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0) gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except InconsistencyError, e: except InconsistencyError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
self.fail() self.fail()
def test_destroy_map3(self): def test_destroy_map3(self):
"""test that only first input can be overwritten""" """test that only first input can be overwritten"""
Z = as_tensor_variable(self.rand(2,2)) Z = as_tensor_variable(self.rand(2, 2))
A = as_tensor_variable(self.rand(2,2)) A = as_tensor_variable(self.rand(2, 2))
try: try:
gemm_inplace(Z, 1.0, Z, A, 1.0) gemm_inplace(Z, 1.0, Z, A, 1.0)
except InconsistencyError, e: except InconsistencyError, e:
...@@ -212,8 +248,8 @@ class t_gemm(TestCase): ...@@ -212,8 +248,8 @@ class t_gemm(TestCase):
def test_destroy_map4(self): def test_destroy_map4(self):
"""test that dot args can be aliased""" """test that dot args can be aliased"""
Z = shared(self.rand(2,2)) Z = shared(self.rand(2, 2))
A = shared(self.rand(2,2)) A = shared(self.rand(2, 2))
one = T.constant(1.0).astype(Z.dtype) one = T.constant(1.0).astype(Z.dtype)
f = inplace_func([], gemm_inplace(Z, one, A, A, one)) f = inplace_func([], gemm_inplace(Z, one, A, A, one))
f() f()
...@@ -222,26 +258,32 @@ class t_gemm(TestCase): ...@@ -222,26 +258,32 @@ class t_gemm(TestCase):
def test_transposes(self): def test_transposes(self):
# three square matrices which are not contiguous # three square matrices which are not contiguous
A = self.rand(4,5)[:,:4] A = self.rand(4, 5)[:, :4]
B = self.rand(4,5)[:,:4] B = self.rand(4, 5)[:, :4]
C = self.rand(4,5)[:,:4] C = self.rand(4, 5)[:, :4]
def t(z,x,y,a=1.0, b=0.0,l='c|py',dt='float64'): def t(z, x, y, a=1.0, b=0.0, l='c|py', dt='float64'):
z,a,x,y,b = [theano._asarray(p,dtype=dt) for p in z,a,x,y,b] z, a, x, y, b = [theano._asarray(p, dtype=dt)
for p in z, a, x, y, b]
z_orig = z.copy() z_orig = z.copy()
z_after = self._gemm(z, a, x, y, b) z_after = self._gemm(z, a, x, y, b)
tz,ta,tx,ty,tb = [shared(p) for p in z,a,x,y,b] tz, ta, tx, ty, tb = [shared(p) for p in z, a, x, y, b]
#f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l)) #f = inplace_func([tz,ta,tx,ty,tb], gemm_inplace(tz,ta,tx,ty,tb),
# mode = compile.Mode(optimizer = None, linker=l))
#f(z, a, x, y, b) #f(z, a, x, y, b)
f = inplace_func([], gemm_inplace(tz,ta,tx,ty,tb), mode = compile.Mode(optimizer = None, linker=l)) f = inplace_func([], gemm_inplace(tz, ta, tx, ty, tb),
mode=compile.Mode(optimizer=None, linker=l))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), (z_orig, z_after, z, z_after - z)) self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), (z_orig, z_after, z, z_after - z)) self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
f() f()
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)), (z_orig, z_after, z, z_after - z)) self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True)),
(z_orig, z_after, z, z_after - z))
#tz.value *= 0 # clear z's value #tz.value *= 0 # clear z's value
y_T = ty.get_value(borrow=True).T y_T = ty.get_value(borrow=True).T
...@@ -252,7 +294,7 @@ class t_gemm(TestCase): ...@@ -252,7 +294,7 @@ class t_gemm(TestCase):
# test that the transposed version of multiplication gives same answer # test that the transposed version of multiplication gives same answer
self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T)) self.assertTrue(_approx_eq(z_after, tz.get_value(borrow=True).T))
t(C,A,B) t(C, A, B)
t(C.T, A, B) t(C.T, A, B)
t(C, A.T, B, dt='float32') t(C, A.T, B, dt='float32')
t(C, A, B.T) t(C, A, B.T)
...@@ -261,15 +303,15 @@ class t_gemm(TestCase): ...@@ -261,15 +303,15 @@ class t_gemm(TestCase):
t(C.T, A, B.T) t(C.T, A, B.T)
t(C.T, A.T, B.T, dt='float32') t(C.T, A.T, B.T, dt='float32')
t(C, A[:,:2], B[:2, :]) t(C, A[:, :2], B[:2, :])
t(C.T, A[:,:2], B[:2, :], dt='float32') t(C.T, A[:, :2], B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:2, :]) t(C, A[:2, :].T, B[:2, :])
t(C.T, A[:2,:].T, B[:2, :], dt='float32') t(C.T, A[:2, :].T, B[:2, :], dt='float32')
t(C, A[:2,:].T, B[:, :2].T) t(C, A[:2, :].T, B[:, :2].T)
t(C.T, A[:2,:].T, B[:, :2].T) t(C.T, A[:2, :].T, B[:, :2].T)
try: try:
t(C.T, A[:2,:], B[:, :2].T) t(C.T, A[:2, :], B[:, :2].T)
except ValueError, e: except ValueError, e:
if e[0].find('aligned') >= 0: if e[0].find('aligned') >= 0:
return return
...@@ -278,12 +320,13 @@ class t_gemm(TestCase): ...@@ -278,12 +320,13 @@ class t_gemm(TestCase):
def test_non_contiguous(self): def test_non_contiguous(self):
# Like test_transposes but with matrices without any # Like test_transposes but with matrices without any
# continuous dimension # continuous dimension
A = self.rand(4,4,3) A = self.rand(4, 4, 3)
B = self.rand(4,4,3) B = self.rand(4, 4, 3)
C = self.rand(4,4,3) C = self.rand(4, 4, 3)
def t(z, x, y, a=1.0, b=0.0, l='c|py', dt='float64'): def t(z, x, y, a=1.0, b=0.0, l='c|py', dt='float64'):
z, a, x, y, b = [theano._asarray(p, dtype=dt) for p in z, a, x, y, b] z, a, x, y, b = [theano._asarray(p, dtype=dt)
for p in z, a, x, y, b]
z_orig = z.copy() z_orig = z.copy()
z_after = numpy.zeros_like(z_orig) z_after = numpy.zeros_like(z_orig)
for i in xrange(3): for i in xrange(3):
...@@ -300,10 +343,10 @@ class t_gemm(TestCase): ...@@ -300,10 +343,10 @@ class t_gemm(TestCase):
# will create cycles, so we update by hand. # will create cycles, so we update by hand.
z_i = f_i() z_i = f_i()
z = tz.get_value(borrow=True, return_internal_type=True) z = tz.get_value(borrow=True, return_internal_type=True)
z[:,:,i] = z_i z[:, :, i] = z_i
self.assertTrue( self.assertTrue(
_approx_eq(z_after[:,:,i], _approx_eq(z_after[:, :, i],
tz.get_value(borrow=True)[:,:,i]), tz.get_value(borrow=True)[:,:,i]),
(z_orig[:,:,i], z_after[:,:,i], (z_orig[:,:,i], z_after[:,:,i],
z[:,:,i], z_after[:,:,i] - z[:,:,i])) z[:,:,i], z_after[:,:,i] - z[:,:,i]))
...@@ -329,15 +372,17 @@ class t_gemm(TestCase): ...@@ -329,15 +372,17 @@ class t_gemm(TestCase):
t(C.transpose((1,0,2)), A, B.transpose((1,0,2))) t(C.transpose((1,0,2)), A, B.transpose((1,0,2)))
t(C.transpose((1,0,2)), A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32') t(C.transpose((1,0,2)), A.transpose((1,0,2)), B.transpose((1,0,2)), dt='float32')
def test_res_is_a(): def test_res_is_a():
X,Y,Z,a,b = XYZab() X, Y, Z, a, b = XYZab()
assert not res_is_a(a, T.sqrt) assert not res_is_a(a, T.sqrt)
assert not res_is_a(a+a, T.sqrt) assert not res_is_a(a + a, T.sqrt)
assert res_is_a(T.sqrt(a+a), T.sqrt) assert res_is_a(T.sqrt(a + a), T.sqrt)
#leave the maxclients stuff untested because it requires being in an env. #leave the maxclients stuff untested because it requires being in an env.
class t_as_scalar(TestCase): class t_as_scalar(TestCase):
def test0(self): def test0(self):
"""Test that it works on scalar constants""" """Test that it works on scalar constants"""
...@@ -346,7 +391,7 @@ class t_as_scalar(TestCase): ...@@ -346,7 +391,7 @@ class t_as_scalar(TestCase):
b2 = b.dimshuffle() b2 = b.dimshuffle()
assert b2.ndim == 0 assert b2.ndim == 0
d_a = T.DimShuffle([], [])(a) d_a = T.DimShuffle([], [])(a)
d_b = T.DimShuffle([True, True, True], [0,2,1])(b) d_b = T.DimShuffle([True, True, True], [0, 2, 1])(b)
d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a) d_a2 = T.DimShuffle([], ['x', 'x', 'x'])(a)
self.assertTrue(_as_scalar(a) == a) self.assertTrue(_as_scalar(a) == a)
...@@ -359,7 +404,7 @@ class t_as_scalar(TestCase): ...@@ -359,7 +404,7 @@ class t_as_scalar(TestCase):
"""Test that it fails on nonscalar constants""" """Test that it fails on nonscalar constants"""
a = T.constant(numpy.ones(5)) a = T.constant(numpy.ones(5))
self.assertTrue(None == _as_scalar(a)) self.assertTrue(None == _as_scalar(a))
self.assertTrue(None == _as_scalar(T.DimShuffle([False], [0,'x'])(a))) self.assertTrue(None == _as_scalar(T.DimShuffle([False], [0, 'x'])(a)))
def test2(self): def test2(self):
"""Test that it works on scalar variables""" """Test that it works on scalar variables"""
...@@ -375,26 +420,35 @@ class t_as_scalar(TestCase): ...@@ -375,26 +420,35 @@ class t_as_scalar(TestCase):
"""Test that it fails on nonscalar variables""" """Test that it fails on nonscalar variables"""
a = T.dmatrix() a = T.dmatrix()
self.assertTrue(None == _as_scalar(a)) self.assertTrue(None == _as_scalar(a))
self.assertTrue(None == _as_scalar(T.DimShuffle([False, False], [0,'x', 1])(a))) self.assertTrue(None == _as_scalar(T.DimShuffle([False, False],
[0, 'x', 1])(a)))
class T_real_matrix(TestCase): class T_real_matrix(TestCase):
def test0(self): def test0(self):
self.assertTrue(_is_real_matrix(T.DimShuffle([False,False], [1, 0])(T.dmatrix()))) self.assertTrue(_is_real_matrix(T.DimShuffle([False, False],
self.assertTrue(not _is_real_matrix(T.DimShuffle([False], ['x', 0])(T.dvector()))) [1, 0])(T.dmatrix())))
self.assertTrue(not _is_real_matrix(T.DimShuffle([False],
['x', 0])
(T.dvector())))
def fail(msg): def fail(msg):
print 'FAIL', msg print 'FAIL', msg
assert False assert False
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting """This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals.""" functions compute the same things as the originals."""
def XYZab(): def XYZab():
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
class Failure(Exception): class Failure(Exception):
pass pass
def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
try: try:
f = inplace_func( f = inplace_func(
[Param(ii, mutable=True, allow_downcast=True) for ii in i], [Param(ii, mutable=True, allow_downcast=True) for ii in i],
...@@ -418,8 +472,8 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0): ...@@ -418,8 +472,8 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
graphlen = len(f.maker.env.toposort()) graphlen = len(f.maker.env.toposort())
if max_graphlen and (graphlen <= max_graphlen): if max_graphlen and (graphlen <= max_graphlen):
theano.printing.debugprint(f) # theano.printing.debugprint(f)
assert False, 'graphlen=%i>%i'%(graphlen, max_graphlen) assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen)
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[rng.randn(*sh) for sh in ishapes])
...@@ -427,7 +481,8 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0): ...@@ -427,7 +481,8 @@ def just_gemm(i, o, ishapes = [(4,3), (3,5), (4,5), (), ()], max_graphlen=0):
r1 = g(*[rng.randn(*sh) for sh in ishapes]) r1 = g(*[rng.randn(*sh) for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0])) max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8: if max_abs_err > 1.0e-8:
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err) raise Failure('GEMM is computing the wrong output. max_rel_err =',
max_abs_err)
except Failure: except Failure:
for node in f.maker.env.toposort(): for node in f.maker.env.toposort():
print 'GRAPH', node print 'GRAPH', node
...@@ -539,7 +594,7 @@ def test_gemm_canonicalize(): ...@@ -539,7 +594,7 @@ def test_gemm_canonicalize():
can = [] can = []
_gemm_canonicalize((-d) * X - (a*X + Y - b*Z*c), 1.0, can, 0) _gemm_canonicalize((-d) * X - (a*X + Y - b*Z*c), 1.0, can, 0)
print can #print can
assert can[0][0].owner.op == T.neg assert can[0][0].owner.op == T.neg
assert can[0][0].owner.inputs[0] == d assert can[0][0].owner.inputs[0] == d
assert can[0][1] == X assert can[0][1] == X
...@@ -596,12 +651,12 @@ def test_gemm_nested(): ...@@ -596,12 +651,12 @@ def test_gemm_nested():
[a * Z - b * (c*T.dot(X,Y) + d*Z)], [a * Z - b * (c*T.dot(X,Y) + d*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()], ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=1) max_graphlen=1)
print "---------------------" #print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d], just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*Z)], [a * Z - b * (c*T.dot(X,Y) + d*Z + c*Z)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()], ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
max_graphlen=1) max_graphlen=1)
print "---------------------" #print "---------------------"
just_gemm([X,Y,Z,R,S,U,a,b,c,d], just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z + c*U)], [a * Z - b * (c*T.dot(X,Y) + d*Z + c*U)],
ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()], ishapes=[(2,3),(3,4),(2,4),(2,3),(3,4),(2,4),(),(),(),()],
...@@ -680,7 +735,7 @@ def test_inplace1(): ...@@ -680,7 +735,7 @@ def test_inplace1():
# with > 2 terms in the overall addition # with > 2 terms in the overall addition
f = inplace_func([X, Y, Z], f = inplace_func([X, Y, Z],
[Z + Z + T.dot(X,Y)], mode='FAST_RUN') [Z + Z + T.dot(X,Y)], mode='FAST_RUN')
theano.printing.debugprint(f) #theano.printing.debugprint(f)
# it doesn't work inplace because we didn't mark Z as mutable input # it doesn't work inplace because we didn't mark Z as mutable input
assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace] assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
...@@ -1527,8 +1582,8 @@ class TestBlasStrides(TestCase): ...@@ -1527,8 +1582,8 @@ class TestBlasStrides(TestCase):
f_nn = theano.function([], [], updates={a: tensor.dot(b, c)}, f_nn = theano.function([], [], updates={a: tensor.dot(b, c)},
mode=self.mode) mode=self.mode)
print 'class name:', self.__class__.__name__ #print 'class name:', self.__class__.__name__
theano.printing.debugprint(f_nn) #theano.printing.debugprint(f_nn)
f_nt = theano.function([], [], updates={a: tensor.dot(b, c_t.T)}, f_nt = theano.function([], [], updates={a: tensor.dot(b, c_t.T)},
mode=self.mode) mode=self.mode)
f_tn = theano.function([], [], updates={a: tensor.dot(b_t.T, c)}, f_tn = theano.function([], [], updates={a: tensor.dot(b_t.T, c)},
...@@ -1800,7 +1855,8 @@ class TestBlasStrides(TestCase): ...@@ -1800,7 +1855,8 @@ class TestBlasStrides(TestCase):
c.set_value(c_dev.copy()[::c_step], borrow=True) c.set_value(c_dev.copy()[::c_step], borrow=True)
a_n = (av[::a_step] a_n = (av[::a_step]
+ l * numpy.dot(bv[::b_step1, ::b_step2], cv[::c_step])) + l * numpy.dot(bv[::b_step1, ::b_step2],
cv[::c_step]))
f_n() f_n()
assert numpy.allclose(a.get_value(), a_n), (a.get_value(), a_n) assert numpy.allclose(a.get_value(), a_n), (a.get_value(), a_n)
...@@ -1818,7 +1874,6 @@ class TestBlasStrides(TestCase): ...@@ -1818,7 +1874,6 @@ class TestBlasStrides(TestCase):
self.cmp_gemv(1, (1, 0), 0) self.cmp_gemv(1, (1, 0), 0)
self.cmp_gemv(0, (0, 0), 0) self.cmp_gemv(0, (0, 0), 0)
def cmp_ger(self, a_shp, b_shp, c_shp): def cmp_ger(self, a_shp, b_shp, c_shp):
av = self.rand(*a_shp) av = self.rand(*a_shp)
bv = self.rand(b_shp) bv = self.rand(b_shp)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论