提交 74ca96f1 authored 作者: Frederic's avatar Frederic

make blas test use floatX.

上级 c0d9f466
...@@ -171,14 +171,14 @@ class t_gemm(TestCase): ...@@ -171,14 +171,14 @@ class t_gemm(TestCase):
self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0) self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0)
def test_factorised_scalar(self): def test_factorised_scalar(self):
a = T.dmatrix() a = T.matrix()
b = T.dmatrix() b = T.matrix()
c = T.dmatrix() c = T.matrix()
s = theano.shared(numpy.zeros((5, 5))) s = theano.shared(numpy.zeros((5, 5)).astype(config.floatX))
lr1 = T.constant(0.01).astype('float64') lr1 = T.constant(0.01).astype(config.floatX)
lr2 = T.constant(2).astype('float64') lr2 = T.constant(2).astype(config.floatX)
l2_reg = T.constant(0.0001).astype('float64') l2_reg = T.constant(0.0001).astype(config.floatX)
#test constant merge with gemm #test constant merge with gemm
f = theano.function([a, b], updates={s: lr1 * T.dot(a, b) + f = theano.function([a, b], updates={s: lr1 * T.dot(a, b) +
...@@ -418,7 +418,7 @@ class t_as_scalar(TestCase): ...@@ -418,7 +418,7 @@ class t_as_scalar(TestCase):
def test3(self): def test3(self):
"""Test that it fails on nonscalar variables""" """Test that it fails on nonscalar variables"""
a = T.dmatrix() a = T.matrix()
self.assertTrue(None == _as_scalar(a)) self.assertTrue(None == _as_scalar(a))
self.assertTrue(None == _as_scalar(T.DimShuffle([False, False], self.assertTrue(None == _as_scalar(T.DimShuffle([False, False],
[0, 'x', 1])(a))) [0, 'x', 1])(a)))
...@@ -427,7 +427,7 @@ class t_as_scalar(TestCase): ...@@ -427,7 +427,7 @@ class t_as_scalar(TestCase):
class T_real_matrix(TestCase): class T_real_matrix(TestCase):
def test0(self): def test0(self):
self.assertTrue(_is_real_matrix(T.DimShuffle([False, False], self.assertTrue(_is_real_matrix(T.DimShuffle([False, False],
[1, 0])(T.dmatrix()))) [1, 0])(T.matrix())))
self.assertTrue(not _is_real_matrix(T.DimShuffle([False], self.assertTrue(not _is_real_matrix(T.DimShuffle([False],
['x', 0]) ['x', 0])
(T.dvector()))) (T.dvector())))
...@@ -441,7 +441,7 @@ def fail(msg): ...@@ -441,7 +441,7 @@ def fail(msg):
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting """This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals.""" functions compute the same things as the originals."""
def XYZab(): def XYZab():
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() return T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
class Failure(Exception): class Failure(Exception):
...@@ -476,11 +476,16 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0): ...@@ -476,11 +476,16 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen) assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen)
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes]) r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0])) max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8: eps = 1.0e-8
if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure('GEMM is computing the wrong output. max_rel_err =', raise Failure('GEMM is computing the wrong output. max_rel_err =',
max_abs_err) max_abs_err)
except Failure: except Failure:
...@@ -519,8 +524,8 @@ def test_gemm_opt0(): ...@@ -519,8 +524,8 @@ def test_gemm_opt0():
def test_gemm_opt_double_gemm(): def test_gemm_opt_double_gemm():
"""This is the pattern that shows up in the autoencoder""" """This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X,Y,Z,a,b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar() R, S, c = T.matrix(), T.matrix(), T.scalar()
just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T], just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]) ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
...@@ -528,7 +533,7 @@ def test_gemm_opt_double_gemm(): ...@@ -528,7 +533,7 @@ def test_gemm_opt_double_gemm():
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()] ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
i = [X,Y,Z,a,b, R, S, c] i = [X,Y,Z,a,b, R, S, c]
o = [(a * T.dot(X,Y) o = [(a * T.dot(X,Y)
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))] + gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype(config.floatX)))]
try: try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o, f = inplace_func([Param(ii, mutable=True) for ii in i],o,
mode='FAST_RUN', on_unused_input='ignore') mode='FAST_RUN', on_unused_input='ignore')
...@@ -541,11 +546,14 @@ def test_gemm_opt_double_gemm(): ...@@ -541,11 +546,14 @@ def test_gemm_opt_double_gemm():
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph') # if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes]) r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX) for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234)) rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes]) r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX) for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0])) max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8: eps = 1.0e-8
if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err) raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err)
except Failure: except Failure:
for node in f.maker.env.toposort(): for node in f.maker.env.toposort():
...@@ -554,8 +562,8 @@ def test_gemm_opt_double_gemm(): ...@@ -554,8 +562,8 @@ def test_gemm_opt_double_gemm():
def test_gemm_canonicalize(): def test_gemm_canonicalize():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') R,S,U,c,d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar('c'), T.scalar('d')
u = T.row('u') u = T.row('u')
v = T.vector('v') v = T.vector('v')
w = T.col('w') w = T.col('w')
...@@ -605,8 +613,8 @@ def test_gemm_canonicalize(): ...@@ -605,8 +613,8 @@ def test_gemm_canonicalize():
assert can[3][0].owner.inputs == [c,b] assert can[3][0].owner.inputs == [c,b]
def test_gemm_factor(): def test_gemm_factor():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') R,S,U,c,d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar('c'), T.scalar('d')
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)]) assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X),(1.0, X)]) assert [(2.0, X)] == _factor_canonicalized([(1.0, X),(1.0, X)])
...@@ -644,8 +652,8 @@ def test_upcasting_scalar_nogemm(): ...@@ -644,8 +652,8 @@ def test_upcasting_scalar_nogemm():
#theano.printing.debugprint(f, print_type=True) #theano.printing.debugprint(f, print_type=True)
def test_gemm_nested(): def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d') R,S,U,c,d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar('c'), T.scalar('d')
just_gemm([X,Y,Z,R,S,U,a,b,c,d], just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z)], [a * Z - b * (c*T.dot(X,Y) + d*Z)],
...@@ -663,7 +671,7 @@ def test_gemm_nested(): ...@@ -663,7 +671,7 @@ def test_gemm_nested():
max_graphlen=3) max_graphlen=3)
def test_gemm_opt_wishlist(): def test_gemm_opt_wishlist():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X,Y,Z,a,b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
#with >2 additions of the same T.dot(X,Y term #with >2 additions of the same T.dot(X,Y term
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)]) just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
...@@ -699,8 +707,8 @@ def test_gemm_with_vector(): ...@@ -699,8 +707,8 @@ def test_gemm_with_vector():
my_just_gemm([Z - a*b*a*T.dot(X,Y) + v]) my_just_gemm([Z - a*b*a*T.dot(X,Y) + v])
def test_gemm_opt_vector_stuff(): def test_gemm_opt_vector_stuff():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar() X,Y,Z,a,b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
u,v = T.dvector(), T.dvector() u,v = T.vector(), T.vector()
f = inplace_func([a, u, v], a + T.dot(u,v), mode='FAST_RUN') f = inplace_func([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
if gemm_inplace in [n.op for n in f.maker.env.nodes]: if gemm_inplace in [n.op for n in f.maker.env.nodes]:
...@@ -712,8 +720,8 @@ def test_gemm_opt_vector_stuff(): ...@@ -712,8 +720,8 @@ def test_gemm_opt_vector_stuff():
def test_inplace0(): def test_inplace0():
#should fail to insert gemm_inplace because gemm_inplace would create cycles #should fail to insert gemm_inplace because gemm_inplace would create cycles
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b') X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R, S, c = T.dmatrix('R'), T.dmatrix('S'), T.dscalar('c') R, S, c = T.matrix('R'), T.matrix('S'), T.scalar('c')
f = inplace_func([Z, b, R, S], f = inplace_func([Z, b, R, S],
[Z * (Z + b * T.dot(R,S).T)], mode='FAST_RUN') [Z * (Z + b * T.dot(R,S).T)], mode='FAST_RUN')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论