提交 74ca96f1 authored 作者: Frederic's avatar Frederic

make blas test use floatX.

上级 c0d9f466
......@@ -171,14 +171,14 @@ class t_gemm(TestCase):
self.cmp(self.rand(0, 0), -1.0, self.rand(0, 0), self.rand(0, 0), -1.0)
def test_factorised_scalar(self):
a = T.dmatrix()
b = T.dmatrix()
c = T.dmatrix()
s = theano.shared(numpy.zeros((5, 5)))
a = T.matrix()
b = T.matrix()
c = T.matrix()
s = theano.shared(numpy.zeros((5, 5)).astype(config.floatX))
lr1 = T.constant(0.01).astype('float64')
lr2 = T.constant(2).astype('float64')
l2_reg = T.constant(0.0001).astype('float64')
lr1 = T.constant(0.01).astype(config.floatX)
lr2 = T.constant(2).astype(config.floatX)
l2_reg = T.constant(0.0001).astype(config.floatX)
#test constant merge with gemm
f = theano.function([a, b], updates={s: lr1 * T.dot(a, b) +
......@@ -418,7 +418,7 @@ class t_as_scalar(TestCase):
def test3(self):
"""Test that it fails on nonscalar variables"""
a = T.dmatrix()
a = T.matrix()
self.assertTrue(None == _as_scalar(a))
self.assertTrue(None == _as_scalar(T.DimShuffle([False, False],
[0, 'x', 1])(a)))
......@@ -427,7 +427,7 @@ class t_as_scalar(TestCase):
class T_real_matrix(TestCase):
def test0(self):
self.assertTrue(_is_real_matrix(T.DimShuffle([False, False],
[1, 0])(T.dmatrix())))
[1, 0])(T.matrix())))
self.assertTrue(not _is_real_matrix(T.DimShuffle([False],
['x', 0])
(T.dvector())))
......@@ -441,7 +441,7 @@ def fail(msg):
"""This test suite ensures that Gemm is inserted where it belongs, and that the resulting
functions compute the same things as the originals."""
def XYZab():
return T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
return T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
class Failure(Exception):
......@@ -476,11 +476,16 @@ def just_gemm(i, o, ishapes=[(4, 3), (3, 5), (4, 5), (), ()], max_graphlen=0):
assert False, 'graphlen=%i>%i' % (graphlen, max_graphlen)
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes])
r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes])
r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX)
for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8:
eps = 1.0e-8
if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure('GEMM is computing the wrong output. max_rel_err =',
max_abs_err)
except Failure:
......@@ -519,8 +524,8 @@ def test_gemm_opt0():
def test_gemm_opt_double_gemm():
"""This is the pattern that shows up in the autoencoder"""
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
R, S, c = T.dmatrix(), T.dmatrix(), T.dscalar()
X,Y,Z,a,b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
R, S, c = T.matrix(), T.matrix(), T.scalar()
just_gemm([X,Y,Z,a,b, R, S, c], [Z *c + a * T.dot(X,Y) + b * T.dot(R,S).T],
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()])
......@@ -528,7 +533,7 @@ def test_gemm_opt_double_gemm():
ishapes=[(4,3), (3,5), (4,5), (), (), (5,9), (9,4), ()]
i = [X,Y,Z,a,b, R, S, c]
o = [(a * T.dot(X,Y)
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype('float64')))]
+ gemm_inplace(Z, b, S.T, R.T, T.constant(1.0).astype(config.floatX)))]
try:
f = inplace_func([Param(ii, mutable=True) for ii in i],o,
mode='FAST_RUN', on_unused_input='ignore')
......@@ -541,11 +546,14 @@ def test_gemm_opt_double_gemm():
# if node.op == gemm_inplace: raise Failure('gemm_inplace in graph')
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r0 = f(*[rng.randn(*sh) for sh in ishapes])
r0 = f(*[numpy.asarray(rng.randn(*sh), config.floatX) for sh in ishapes])
rng = numpy.random.RandomState(unittest_tools.fetch_seed(234))
r1 = g(*[rng.randn(*sh) for sh in ishapes])
r1 = g(*[numpy.asarray(rng.randn(*sh), config.floatX) for sh in ishapes])
max_abs_err = numpy.max(numpy.abs(r0[0] - r1[0]))
if max_abs_err > 1.0e-8:
eps = 1.0e-8
if config.floatX == 'float32':
eps = 1.0e-6
if max_abs_err > eps:
raise Failure('GEMM is computing the wrong output. max_rel_err =', max_abs_err)
except Failure:
for node in f.maker.env.toposort():
......@@ -554,8 +562,8 @@ def test_gemm_opt_double_gemm():
def test_gemm_canonicalize():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R,S,U,c,d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar('c'), T.scalar('d')
u = T.row('u')
v = T.vector('v')
w = T.col('w')
......@@ -605,8 +613,8 @@ def test_gemm_canonicalize():
assert can[3][0].owner.inputs == [c,b]
def test_gemm_factor():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R,S,U,c,d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar('c'), T.scalar('d')
assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
assert [(2.0, X)] == _factor_canonicalized([(1.0, X),(1.0, X)])
......@@ -644,8 +652,8 @@ def test_upcasting_scalar_nogemm():
#theano.printing.debugprint(f, print_type=True)
def test_gemm_nested():
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R,S,U,c,d = T.dmatrix('R'), T.dmatrix('S'), T.dmatrix('U'), T.dscalar('c'), T.dscalar('d')
X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R,S,U,c,d = T.matrix('R'), T.matrix('S'), T.matrix('U'), T.scalar('c'), T.scalar('d')
just_gemm([X,Y,Z,R,S,U,a,b,c,d],
[a * Z - b * (c*T.dot(X,Y) + d*Z)],
......@@ -663,7 +671,7 @@ def test_gemm_nested():
max_graphlen=3)
def test_gemm_opt_wishlist():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
X,Y,Z,a,b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
#with >2 additions of the same T.dot(X,Y term
just_gemm([X,Y,Z,a,b], [(b * b) * Z * a + (a * a) * T.dot(X,Y) + b * T.dot(X,Y)])
......@@ -699,8 +707,8 @@ def test_gemm_with_vector():
my_just_gemm([Z - a*b*a*T.dot(X,Y) + v])
def test_gemm_opt_vector_stuff():
X,Y,Z,a,b = T.dmatrix(), T.dmatrix(), T.dmatrix(), T.dscalar(), T.dscalar()
u,v = T.dvector(), T.dvector()
X,Y,Z,a,b = T.matrix(), T.matrix(), T.matrix(), T.scalar(), T.scalar()
u,v = T.vector(), T.vector()
f = inplace_func([a, u, v], a + T.dot(u,v), mode='FAST_RUN')
if gemm_inplace in [n.op for n in f.maker.env.nodes]:
......@@ -712,8 +720,8 @@ def test_gemm_opt_vector_stuff():
def test_inplace0():
#should fail to insert gemm_inplace because gemm_inplace would create cycles
X,Y,Z,a,b = T.dmatrix('X'), T.dmatrix('Y'), T.dmatrix('Z'), T.dscalar('a'), T.dscalar('b')
R, S, c = T.dmatrix('R'), T.dmatrix('S'), T.dscalar('c')
X,Y,Z,a,b = T.matrix('X'), T.matrix('Y'), T.matrix('Z'), T.scalar('a'), T.scalar('b')
R, S, c = T.matrix('R'), T.matrix('S'), T.scalar('c')
f = inplace_func([Z, b, R, S],
[Z * (Z + b * T.dot(R,S).T)], mode='FAST_RUN')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论