提交 e099538a authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add optimizations to the mode used, to make test pass in FAST_COMPILE

上级 d906e876
......@@ -441,7 +441,11 @@ def test_dot22():
f(av,bv)
def test_dot22scalar():
m = theano.compile.get_default_mode().including('local_dot_to_dot22','local_dot22_to_dot22scalar','specialize')
## including does not seem to work for 'local_dot_to_dot22' and
## 'local_dot22_to_dot22scalar'
## TODO: exclude other optimizations in BlasOpt?
#m = theano.compile.get_default_mode().including('local_dot_to_dot22','local_dot22_to_dot22scalar','specialize')
m = theano.compile.get_default_mode().including('BlasOpt', 'specialize')
a=T.matrix()
b=T.matrix()
c=T.matrix()
......@@ -469,14 +473,16 @@ def test_dot22scalar():
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],0.1*c * 0.2*T.dot(a,b),mode=m)
## Here, canonicalize also seems needed
## TODO: add only the optimizations needed?
m2 = m.including('canonicalize')
f = theano.function([a,b,c],0.1*c * 0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * 0.2*a*T.dot(a,b),mode=m)
f = theano.function([a,b,c],c * 0.2*a*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
......@@ -490,7 +496,7 @@ def test_dot22scalar():
# assert _dot22scalar in [x.op for x in topo]
# assert len(topo)==2
f(av,bv,cv)
f = theano.function([a,b,c],c * a*0.2*T.dot(a,b),mode=m)
f = theano.function([a,b,c],c * a*0.2*T.dot(a,b),mode=m2)
topo = f.maker.env.toposort()
assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论