提交 d1d0f164 authored 作者: Frederic Bastien's avatar Frederic Bastien

fuse elemwise event when a variable is used for many into to the elemwise.

上级 9bec6b50
......@@ -3119,7 +3119,9 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s_inputs.extend(tmp_scalar)
s_g.append(s_op)
else:
if i in inputs:
# We must support the case where the same variable appear many
# time in the inputs
if inputs.count(i)==node.inputs.count(i):
s=s_inputs[inputs.index(i)]
else:
s=scalar.Scalar(i.dtype).make_variable()
......
......@@ -798,7 +798,9 @@ class test_fusion(unittest.TestCase):
(theano.tensor.mul(fx,fx,fx,fx),(fx,),(fxv,),1,fxv*fxv*fxv*fxv,'float32'),
(theano.tensor.mul(fx,ftanx,ftanx),(fx,),(fxv,),2,fxv*numpy.tan(fxv)*numpy.tan(fxv),'float32'),# TODO: This case is not fused!
(theano.tensor.mul(fx,ftanx,ftanx,fx),(fx,),(fxv,),2,fxv*numpy.tan(fxv)*numpy.tan(fxv)*fxv,'float32'),# TODO: This case is not fused!
(theano.tensor.mul(ftanx,ftanx,fx+fy),(fx,fy),(fxv,fyv),3,numpy.tan(fxv)*numpy.tan(fxv)*(fxv+fyv),'float32'),# TODO: This case is not fused!
# The next case test when one variable appear as many inputs to an op.
# In the past, this was not fused.
(theano.tensor.mul(ftanx,ftanx,fx+fy),(fx,fy),(fxv,fyv),2,numpy.tan(fxv)*numpy.tan(fxv)*(fxv+fyv),'float32'),# TODO: This case is not fused!
]
if slice:
cases = cases[slice]
......@@ -844,6 +846,14 @@ class test_fusion(unittest.TestCase):
if assert_len_topo:
if not len(topo_)==nb_elemwise:
fail3.append((id,topo_,nb_elemwise))
if nb_elemwise == 1:
# check that the number of input to the Composite Elemwise is ok
# when there is not variable that appear multiple time the in input
# of g
assert ((numpy.sum([not isinstance(x, theano.gof.Constant)
for x in topo_[0].inputs]) ==
len(sym_inputs)) or
len(set(g.owner.inputs)) != len(g.owner.inputs))
if not out_dtype==out.dtype:
fail4.append((id,out_dtype,out.dtype))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论