提交 4ebf9e1a authored 作者: lamblin's avatar lamblin

Merge pull request #673 from nouiz/fusion

Enable fusion of node that have the same clients multiple time.
...@@ -4395,9 +4395,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -4395,9 +4395,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
# We should not check the number of inputs here # We should not check the number of inputs here
# As fusing op don't always change the number of input. # As fusing op don't always change the number of input.
# If a variable is used as multiple into to the same node,
# we still want to fusion. So we take the set.
if (i.owner and if (i.owner and
isinstance(i.owner.op, OP) and isinstance(i.owner.op, OP) and
len(i.clients) == 1): len(set([n for n, idx in i.clients])) == 1):
do_fusion = True do_fusion = True
try: try:
......
...@@ -887,13 +887,9 @@ class test_fusion(unittest.TestCase): ...@@ -887,13 +887,9 @@ class test_fusion(unittest.TestCase):
# Cases where the same input is reused many times. # Cases where the same input is reused many times.
(theano.tensor.mul(fx,fx,fx,fx),(fx,),(fxv,),1,fxv*fxv*fxv*fxv,'float32'), (theano.tensor.mul(fx,fx,fx,fx),(fx,),(fxv,),1,fxv*fxv*fxv*fxv,'float32'),
# TODO: This case is not fused! (theano.tensor.mul(fx,ftanx,ftanx),(fx,),(fxv,),1,fxv*numpy.tan(fxv)*numpy.tan(fxv),'float32'),
(theano.tensor.mul(fx,ftanx,ftanx),(fx,),(fxv,),2,fxv*numpy.tan(fxv)*numpy.tan(fxv),'float32'), (theano.tensor.mul(fx,ftanx,ftanx,fx),(fx,),(fxv,),1,fxv*numpy.tan(fxv)*numpy.tan(fxv)*fxv,'float32'),
# TODO: This case is not fused! (theano.tensor.mul(ftanx,ftanx,fx+fy),(fx,fy),(fxv,fyv),1,numpy.tan(fxv)*numpy.tan(fxv)*(fxv+fyv),'float32'),
(theano.tensor.mul(fx,ftanx,ftanx,fx),(fx,),(fxv,),2,fxv*numpy.tan(fxv)*numpy.tan(fxv)*fxv,'float32'),
# The next case test when one variable appear as many inputs to an op.
# In the past, this was not fused. (TODO) Now it is partially fused.
(theano.tensor.mul(ftanx,ftanx,fx+fy),(fx,fy),(fxv,fyv),2,numpy.tan(fxv)*numpy.tan(fxv)*(fxv+fyv),'float32'),
] ]
if slice: if slice:
cases = cases[slice] cases = cases[slice]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论