提交 80a2c70c authored 作者: Frederic Bastien's avatar Frederic Bastien

fix error introduced in fusion this week.

上级 47d5b5a4
...@@ -2311,6 +2311,8 @@ def local_elemwise_fusion_op(OP): ...@@ -2311,6 +2311,8 @@ def local_elemwise_fusion_op(OP):
for i in node.inputs: for i in node.inputs:
do_fusion = False do_fusion = False
catch = False catch = False
tmp_input=[]#used to remove duplicate input.
tmp_scalar=[]
if i.owner and isinstance(i.owner.op, OP) and len(i.clients)==1: if i.owner and isinstance(i.owner.op, OP) and len(i.clients)==1:
#if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops. #if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops.
do_fusion=True do_fusion=True
...@@ -2319,8 +2321,12 @@ def local_elemwise_fusion_op(OP): ...@@ -2319,8 +2321,12 @@ def local_elemwise_fusion_op(OP):
for ii in i.owner.inputs: for ii in i.owner.inputs:
if ii in inputs: if ii in inputs:
s_input.append(s_inputs[inputs.index(ii)]) s_input.append(s_inputs[inputs.index(ii)])
elif ii in tmp_input:
s_input.append(tmp_scalar[tmp_input.index(ii)])
else: else:
s_input.append(scalar.Scalar(ii.dtype).make_variable()) s_input.append(scalar.Scalar(ii.dtype).make_variable())
tmp_input.append(ii)
tmp_scalar.append(s_input[-1])
s_op=i.owner.op.scalar_op(*s_input) s_op=i.owner.op.scalar_op(*s_input)
i.owner.op.scalar_op.c_code(s_op.owner,"test_presence_of_c_code", i.owner.op.scalar_op.c_code(s_op.owner,"test_presence_of_c_code",
["x" for x in i.owner.inputs], ["x" for x in i.owner.inputs],
...@@ -2334,17 +2340,18 @@ def local_elemwise_fusion_op(OP): ...@@ -2334,17 +2340,18 @@ def local_elemwise_fusion_op(OP):
do_fusion=False do_fusion=False
if do_fusion: if do_fusion:
#we should not put duplicate input into s_inputs and inputs
nb_elemwise+=1 nb_elemwise+=1
inputs.extend(i.owner.inputs) inputs.extend(tmp_input)
s_inputs.extend(s_input) s_inputs.extend(tmp_scalar)
s_g.append(s_op) s_g.append(s_op)
else: else:
if i in inputs: if i in inputs:
s=s_inputs[inputs.index(i)] s=s_inputs[inputs.index(i)]
else: else:
s=scalar.Scalar(i.dtype).make_variable() s=scalar.Scalar(i.dtype).make_variable()
inputs.append(i) inputs.append(i)
s_inputs.append(s) s_inputs.append(s)
s_g.append(s) s_g.append(s)
#if no inputs have are an elemwise, there is nothing to fuse. #if no inputs have are an elemwise, there is nothing to fuse.
...@@ -2352,20 +2359,7 @@ def local_elemwise_fusion_op(OP): ...@@ -2352,20 +2359,7 @@ def local_elemwise_fusion_op(OP):
# print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse." # print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse."
return False return False
#remove duplicate inputs, we most keep the order.
inputs2=[]
s_inputs2=[]
for i,si in zip(inputs,s_inputs):
if i not in inputs2:
inputs2.append(i)
s_inputs2.append(si)
else:
assert si in s_inputs2
inputs = inputs2
s_inputs = s_inputs2
del inputs2, s_inputs2
assert len(s_inputs)==len(inputs) assert len(s_inputs)==len(inputs)
otype = node.outputs[0].type otype = node.outputs[0].type
s_new_out=node.op.scalar_op(*s_g) s_new_out=node.op.scalar_op(*s_g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论