提交 aebaa276 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix test in fast_compile

上级 ec573e63
......@@ -15,6 +15,12 @@ from theano import Param, shared
from test_basic import (_approx_eq, as_tensor_variable, inplace_func,
compile, constant, inplace, eval_outputs)
if config.mode == 'FAST_COMPILE':
mode_not_fast_compile = 'FAST_RUN'
else: mode_not_fast_compile = config.mode
mode_blas_opt = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
class t_gemm(TestCase):
"""This test suite is supposed to establish that gemm works as it is supposed to."""
def setUp(self):
......@@ -109,19 +115,19 @@ class t_gemm(TestCase):
l2_reg=T.constant(0.0001)
#test constant merge with gemm
f = theano.function([a,b],updates={s:lr1*T.dot(a,b)+l2_reg*lr2*s}).maker.env.toposort()
f = theano.function([a,b],updates={s:lr1*T.dot(a,b)+l2_reg*lr2*s},mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 2e-06)]
assert len(f)==1
assert f[0].op==gemm_inplace
#test factored scalar with merge
f = theano.function([a,b],updates={s:lr1*(T.dot(a,b)-l2_reg*s)}).maker.env.toposort()
f = theano.function([a,b],updates={s:lr1*(T.dot(a,b)-l2_reg*s)},mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, 0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, -2e-06)]
assert len(f)==1
assert f[0].op==gemm_inplace
#test factored scalar with merge and neg
f = theano.function([a,b],updates={s:s-lr1*(s*.0002+T.dot(a,b))}).maker.env.toposort()
f = theano.function([a,b],updates={s:s-lr1*(s*.0002+T.dot(a,b))},mode=mode_not_fast_compile).maker.env.toposort()
#[Gemm{inplace}(<TensorType(float64, matrix)>, -0.01, <TensorType(float64, matrix)>, <TensorType(float64, matrix)>, 0.999998)]
assert len(f)==1
assert f[0].op==gemm_inplace
......@@ -533,12 +539,9 @@ def test_inplace1():
assert [n.op for n in f.maker.env.nodes] == [gemm_no_inplace]
def test_dot22():
if config.mode == 'FAST_COMPILE':
m = 'FAST_RUN'
else: m = config.mode
a=T.matrix()
b=T.matrix()
f = theano.function([a,b],T.dot(a,b),mode=m)
f = theano.function([a,b],T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22 in [x.op for x in topo]
av=numpy.random.rand(5,5)
......@@ -550,7 +553,7 @@ def test_dot22scalar():
## 'local_dot22_to_dot22scalar'
## TODO: exclude other optimizations in BlasOpt?
#m = theano.compile.get_default_mode().including('local_dot_to_dot22','local_dot22_to_dot22scalar','specialize')
m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
#m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
a=T.matrix()
b=T.matrix()
c=T.matrix()
......@@ -559,20 +562,20 @@ def test_dot22scalar():
cv=numpy.random.rand(5,5)
if True:
f = theano.function([a,b],0.2*T.dot(a,b),mode=m)
f = theano.function([a,b],0.2*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==1
f(av,bv)
if True:
f = theano.function([a,b,c],0.2*c*T.dot(a,b),mode=m)
f = theano.function([a,b,c],0.2*c*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * 0.2*T.dot(a,b),mode=m)
f = theano.function([a,b,c],c * 0.2*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
......@@ -580,7 +583,7 @@ def test_dot22scalar():
## Here, canonicalize also seems needed
## TODO: add only the optimizations needed?
m2 = m.including('canonicalize')
m2 = mode_blas_opt.including('canonicalize')
f = theano.function([a,b,c],0.1*c * 0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
......@@ -593,7 +596,7 @@ def test_dot22scalar():
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],0.2*c *a*T.dot(a,b),mode=m)
f = theano.function([a,b,c],0.2*c *a*T.dot(a,b),mode=mode_blas_opt)
topo = f.maker.env.toposort()
#currently the canonizer don't always merge all Mul together...
#that force the optimizer to make a recursive search witch it don't do now.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论