提交 2dbb7baa authored 作者: Frederic Bastien's avatar Frederic Bastien

added some tests case for the fusion when the fusion don't work.

上级 9ad67442
...@@ -712,6 +712,7 @@ class test_fusion(unittest.TestCase): ...@@ -712,6 +712,7 @@ class test_fusion(unittest.TestCase):
iyv = theano._asarray(my_init(shp,num=70),dtype='int32') iyv = theano._asarray(my_init(shp,num=70),dtype='int32')
izv = theano._asarray(my_init(shp,num=70),dtype='int32') izv = theano._asarray(my_init(shp,num=70),dtype='int32')
fwx=fw+fx fwx=fw+fx
ftanx = theano.tensor.tan(fx)
cases = [ cases = [
(fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,fxv+fyv+fzv,'float32'),#0 (fx+fy+fz,(fx,fy,fz),(fxv,fyv,fzv),1,fxv+fyv+fzv,'float32'),#0
(fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,fxv*fyv*fzv,'float32'),#1 (fx*fy*fz,(fx,fy,fz),(fxv,fyv,fzv),1,fxv*fyv*fzv,'float32'),#1
...@@ -792,6 +793,12 @@ class test_fusion(unittest.TestCase): ...@@ -792,6 +793,12 @@ class test_fusion(unittest.TestCase):
(theano.tensor.pow(fx*fy+fz,fx*fy),(fx,fy,fz),(fxv,fyv,fzv),1,numpy.power(fxv*fyv+fzv,fxv*fyv),'float32'), (theano.tensor.pow(fx*fy+fz,fx*fy),(fx,fy,fz),(fxv,fyv,fzv),1,numpy.power(fxv*fyv+fzv,fxv*fyv),'float32'),
(fv+fy**fz,(fv,fy,fz),(fvv,fyv,fzv),2,fvv+fyv**fzv,'float32'),#fused with a dimshuffle (fv+fy**fz,(fv,fy,fz),(fvv,fyv,fzv),2,fvv+fyv**fzv,'float32'),#fused with a dimshuffle
(fv-fy+tanh(fz),(fv,fy,fz),(fvv,fyv,fzv),2,fvv-fyv+numpy.tanh(fzv),'float32'),#fused with a dimshuffle (fv-fy+tanh(fz),(fv,fy,fz),(fvv,fyv,fzv),2,fvv-fyv+numpy.tanh(fzv),'float32'),#fused with a dimshuffle
# Cases where the same input is reused many times.
(theano.tensor.mul(fx,fx,fx,fx),(fx,),(fxv,),1,fxv*fxv*fxv*fxv,'float32'),
(theano.tensor.mul(fx,ftanx,ftanx),(fx,),(fxv,),2,fxv*numpy.tan(fxv)*numpy.tan(fxv),'float32'),# TODO: This case is not fused!
(theano.tensor.mul(fx,ftanx,ftanx,fx),(fx,),(fxv,),2,fxv*numpy.tan(fxv)*numpy.tan(fxv)*fxv,'float32'),# TODO: This case is not fused!
(theano.tensor.mul(ftanx,ftanx,fx+fy),(fx,fy),(fxv,fyv),3,numpy.tan(fxv)*numpy.tan(fxv)*(fxv+fyv),'float32'),# TODO: This case is not fused!
] ]
if slice: if slice:
cases = cases[slice] cases = cases[slice]
...@@ -814,7 +821,7 @@ class test_fusion(unittest.TestCase): ...@@ -814,7 +821,7 @@ class test_fusion(unittest.TestCase):
t1=time.time() t1=time.time()
else: else:
out=shared_fn(numpy.zeros(shp, dtype=out_dtype),'out') out=shared_fn(numpy.zeros(shp, dtype=out_dtype),'out')
f = function(sym_inputs,[],updates=[(out,out+g)],mode=mode) f = function(sym_inputs,[],updates=[(out, g)],mode=mode)
t0=time.time() t0=time.time()
for x in range(nb_repeat): for x in range(nb_repeat):
f(*val_inputs) f(*val_inputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论