提交 e9174a6d authored 作者: Frederic Bastien's avatar Frederic Bastien

Removed duplicate inputs when fusing Elemwise and GpuElemwise.

This allows more ops to be fused into a GpuElemwise (the GPU has a limit on the number of function inputs) and makes the graph a little simpler.
上级 ec195403
......@@ -2315,7 +2315,13 @@ def local_elemwise_fusion_op(OP):
#if the scalar_op has no c implementation, we skip its fusion to allow the fusion of the other ops.
do_fusion=True
try:
s_input = [scalar.Scalar(x.dtype).make_variable() for x in i.owner.inputs]
s_input = []
for ii in i.owner.inputs:
if ii in inputs:
s_input.append(s_inputs[inputs.index(ii)])
else:
s_input.append(scalar.Scalar(ii.dtype).make_variable())
#s_input = [scalar.Scalar(x.dtype).make_variable() for x in i.owner.inputs]
s_op=i.owner.op.scalar_op(*s_input)
i.owner.op.scalar_op.c_code(s_op.owner,"test_presence_of_c_code",
["x" for x in i.owner.inputs],
......@@ -2334,8 +2340,11 @@ def local_elemwise_fusion_op(OP):
s_inputs.extend(s_input)
s_g.append(s_op)
else:
if i in inputs:
s=s_inputs[inputs.index(i)]
else:
s=scalar.Scalar(i.dtype).make_variable()
inputs.append(i)
s=scalar.Scalar(i.dtype).make_variable()
s_inputs.append(s)
s_g.append(s)
......@@ -2344,6 +2353,21 @@ def local_elemwise_fusion_op(OP):
# print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse."
return False
#remove duplicate inputs; we must keep the order.
inputs2=[]
s_inputs2=[]
for i,si in zip(inputs,s_inputs):
if i not in inputs2:
inputs2.append(i)
s_inputs2.append(si)
else:
assert si in s_inputs2
inputs = inputs2
s_inputs = s_inputs2
del inputs2, s_inputs2
assert len(s_inputs)==len(inputs)
otype = node.outputs[0].type
s_new_out=node.op.scalar_op(*s_g)
try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论