提交 774a8ba1 authored 作者: Frederic Bastien's avatar Frederic Bastien

first version of elemwise fusion. DISABLED.

上级 a8a59e90
...@@ -1227,6 +1227,78 @@ register_canonicalize(local_transposed_dot, name='local_transposed_dot') ...@@ -1227,6 +1227,78 @@ register_canonicalize(local_transposed_dot, name='local_transposed_dot')
# # Loop fusion # # # Loop fusion #
# ############### # ###############
@gof.local_optimizer([T.Elemwise, T.Elemwise])
def local_elemwise_fusion(node):
    """Fuse an Elemwise node with the single Elemwise that feeds it.

    The scalar op of the feeding Elemwise and the scalar op of ``node`` are
    combined into one ``scalar.Composite``; the replacement is a single
    ``T.Elemwise`` built on that Composite and applied directly to the
    original (unfused) inputs.

    The rewrite is only attempted when all of the following hold:

    * ``node.op`` is a ``T.Elemwise``;
    * ``node`` has exactly two inputs;
    * exactly one of them is produced by a ``T.Elemwise`` and the other has
      no owner;
    * every input of ``node`` has exactly one client;
    * the second input (if any) of the feeding Elemwise has no owner.

    :param node: the apply node to try to rewrite.
    :return: ``False`` when the rewrite does not apply, otherwise the list
        of outputs of the new fused node.
    """
    # TODO: implement Composite.__eq__ by using CLinker.cmodule_key() to
    # compare the graph.
    # TODO: check dtype and broadcastable(type)
    if not isinstance(node.op, T.Elemwise):
        return False
    if isinstance(node.op.scalar_op, scalar.Composite):
        print("local_elemwise_fusion of Composite")

    nb_elemwise = 0
    inputs = []    # tensor inputs of the fused Elemwise
    s_inputs = []  # scalar placeholders, one per entry of `inputs`
    s_g = []       # scalar arguments of node's scalar op, in input order
    for i in node.inputs:
        owner = i.owner
        if not owner:
            # Plain input: it is both an input of the Composite and a direct
            # argument of the outer scalar op.  The original code forgot to
            # add it to s_g, so the outer op lost this operand.
            s_var = scalar.Scalar(i.dtype).make_variable()
            inputs.append(i)
            s_inputs.append(s_var)
            s_g.append(s_var)
        elif isinstance(owner.op, T.Elemwise):
            if len(i.clients) > 1:
                print("local_elemwise_fusion: Elemwise inputs have more then 1 client. Don't optimise for now")
                return False
            # Guard the index: a unary Elemwise has no inputs[1].
            if len(owner.inputs) > 1 and owner.inputs[1].owner is not None:
                print("local_elemwise_fusion: Elemwise inputs inputs[1] have an owner")
                return False
            nb_elemwise += 1
            inputs.extend(owner.inputs)
            s_input = [scalar.Scalar(x.dtype).make_variable()
                       for x in owner.inputs]
            s_inputs.extend(s_input)
            # Scalar graph of the inner op; it becomes one argument of the
            # outer scalar op.
            s_g.append(owner.op.scalar_op(*s_input))
        else:
            print("local_elemwise_fusion: have an owner.")
            return False

    # TODO: test nb_clients?
    if len(node.inputs) != 2:
        print("local_elemwise_fusion: node have more then 2 inputs.")
        return False
    if nb_elemwise != 1:
        print("local_elemwise_fusion: node have more then 1 elemwise in its inputs.")
        return False
    if any(len(x.clients) != 1 for x in node.inputs):
        print("local_elemwise_fusion: node have more then 1 clients.")
        return False

    # Scalar expression evaluated in the inner loop of the fused op.
    new_out = node.op.scalar_op(*s_g)
    # Create the composite op and the Elemwise node that applies it.
    C = scalar.Composite(s_inputs, [new_out])
    n = T.Elemwise(C).make_node(*inputs)
    assert len(n.outputs) == 1
    assert node.outputs[0].dtype == n.outputs[0].dtype
    return n.outputs
#register_specialize(local_elemwise_fusion)
# def make_composite(inputs, outputs): # def make_composite(inputs, outputs):
# scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs] # scalar_inputs = [scalar.Scalar(dtype = i.type.dtype)() for i in inputs]
# def transform(r): # def transform(r):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论