提交 69ec63f0 authored 作者: Frederic Bastien's avatar Frederic Bastien

code cleanup and fuse more case of elemwise.

上级 0d9b26fc
...@@ -1242,9 +1242,10 @@ def local_elemwise_fusion(node): ...@@ -1242,9 +1242,10 @@ def local_elemwise_fusion(node):
s_inputs = []#inputs of the new scalar op. s_inputs = []#inputs of the new scalar op.
s_g=[]#graph of scalar, what will by done in the inner loop. s_g=[]#graph of scalar, what will by done in the inner loop.
for i in node.inputs: for i in node.inputs:
if i.owner and isinstance(i.owner.op,T.Elemwise): if i.owner and isinstance(i.owner.op,T.Elemwise) and len(i.clients)<=1:
if len(i.clients)>1: if len(i.clients)>1:
#should we put this in the first if, then we would go to the elif to don't fuse it? #should we put this in the first if, then we would go to the elif to don't fuse it?
#if one of the inputs have more then 1 clients and it is an intermediate result. We don't fuse.
print "local_elemwise_fusion: Elemwise inputs have more then 1 client. Don't optimise for now" print "local_elemwise_fusion: Elemwise inputs have more then 1 client. Don't optimise for now"
return False return False
...@@ -1254,36 +1255,22 @@ def local_elemwise_fusion(node): ...@@ -1254,36 +1255,22 @@ def local_elemwise_fusion(node):
s_inputs.extend(s_input) s_inputs.extend(s_input)
s_op=i.owner.op.scalar_op(*s_input) s_op=i.owner.op.scalar_op(*s_input)
s_g.append(s_op) s_g.append(s_op)
elif not i.owner: else:
if i.owner and isinstance(i.owner.op,T.Elemwise) and len(i.clients)>1:
#should we put this in the first if, then we would go to the elif to don't fuse it?
print "local_elemwise_fusion: inputs have more then 1 client. Don't fuse it for now.!"
return False
inputs.append(i) inputs.append(i)
s=scalar.Scalar(i.dtype).make_variable() s=scalar.Scalar(i.dtype).make_variable()
s_inputs.append(s) s_inputs.append(s)
s_g.append(s) s_g.append(s)
else:
print "local_elemwise_fusion: have an owner that is not an Elemwise."
return False
#if no inputs have are an elemwise, their is nothing to fuse.
if nb_elemwise==0: if nb_elemwise==0:
# print "local_elemwise_fusion: node have no elemwise in inputs. Nothing to fuse."
return False return False
#if one of the inputs have more then 1 clients and it is an intermediate result. We don't fuse.
if any([len(x.clients)!=1 and x.owner for x in node.inputs]):#len(node.inputs[0].clients)!=1:
print "local_elemwise_fusion: node have more then 1 clients.", [x.clients for x in node.inputs]
return False
otype = node.outputs[0].type otype = node.outputs[0].type
# print "local_elemwise_fusion"
# print [type(x) for x in s_inputs]
# print "node",node
# print "node.inputs",node.inputs
# print "node.inputs[0].op",node.inputs[0].op
# print "node.inputs[1].op",node.inputs[1].op
# print "node.outputs",node.outputs
# print "s_g",s_g, [type(x) for x in s_g]
# print "s_inpust",s_inputs
s_new_out=node.op.scalar_op(*s_g) s_new_out=node.op.scalar_op(*s_g)
# print "s_new_out",s_new_out
#create the composite op. #create the composite op.
C = scalar.Composite(s_inputs,[s_new_out]) C = scalar.Composite(s_inputs,[s_new_out])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论