提交 1f016055 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the fusion of elemwise respect the maximum number of ops in all cases and…

Make the fusion of elemwise respect the maximum number of ops in all cases, and make it fuse more elements to get closer to that maximum in some cases. This makes the new failure in the Buildbot occur in fewer cases.
上级 533f8738
...@@ -3014,22 +3014,30 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -3014,22 +3014,30 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
# There is a hard limit of 256 bytes for the formal argument list to a # There is a hard limit of 256 bytes for the formal argument list to a
# GPU kernel function. # GPU kernel function.
max_nb_input = max_input_fct(node) max_nb_input = max_input_fct(node)
#print len(node.inputs),max_nb_input # The number of input to the new fused op if we don't fuse more inputs.
new_nb_input = len(node.inputs) new_nb_input = len(node.inputs)
# Did we fused something.
# Needed as we can fuse unary op that don't change the number of input
# And their is case where the input inputs are the same as the current node.
# That won't change the number of inputs of the new op.
fused = False
for i in node.inputs: for i in node.inputs:
do_fusion = False do_fusion = False
catch = False catch = False
tmp_input=[]#used to remove duplicate input. tmp_input=[]#used to remove duplicate input.
tmp_scalar=[] tmp_scalar=[]
if ((new_nb_input+1)<=max_nb_input and
i.owner and # We should not check the number of inputs here
# As fusing op don't always change the number of input.
if (i.owner and
isinstance(i.owner.op, OP) and isinstance(i.owner.op, OP) and
len(i.clients)==1): len(i.clients)==1):
#if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops.
do_fusion=True do_fusion=True
try: try:
s_input = [] s_input = []
#we should not put duplicate input into s_inputs and inputs
for ii in i.owner.inputs: for ii in i.owner.inputs:
if ii in inputs: if ii in inputs:
s_input.append(s_inputs[inputs.index(ii)]) s_input.append(s_inputs[inputs.index(ii)])
...@@ -3040,6 +3048,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -3040,6 +3048,8 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
tmp_input.append(ii) tmp_input.append(ii)
tmp_scalar.append(s_input[-1]) tmp_scalar.append(s_input[-1])
s_op=i.owner.op.scalar_op(*s_input) s_op=i.owner.op.scalar_op(*s_input)
#if the scalar_op don't have a c implementation, we skip its fusion to allow the fusion of the other ops.
i.owner.op.scalar_op.c_code(s_op.owner,"test_presence_of_c_code", i.owner.op.scalar_op.c_code(s_op.owner,"test_presence_of_c_code",
["x" for x in i.owner.inputs], ["x" for x in i.owner.inputs],
"z",{}) "z",{})
...@@ -3051,9 +3061,18 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -3051,9 +3061,18 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
_logger.info("%s does not implement the c_code function. As well as being potentially slow, this disables loop fusion of this op." % str(i.owner.op.scalar_op)) _logger.info("%s does not implement the c_code function. As well as being potentially slow, this disables loop fusion of this op." % str(i.owner.op.scalar_op))
do_fusion=False do_fusion=False
if do_fusion: # Compute the number of inputs in case we fuse this input
#we should not put duplicate input into s_inputs and inputs # The -1 is that replace an existing input with others.
new_nb_input+=1 new_nb_input_ = new_nb_input + len(tmp_input) - 1
# If the new input is already an input of the current node, it was already counted.
# When new_nb_input was initialized to len(node.inputs)
for x in tmp_input:
if x in node.inputs:
new_nb_input_ -= 1
if do_fusion and (new_nb_input_ <= max_nb_input):
fused = True
new_nb_input = new_nb_input_
inputs.extend(tmp_input) inputs.extend(tmp_input)
s_inputs.extend(tmp_scalar) s_inputs.extend(tmp_scalar)
s_g.append(s_op) s_g.append(s_op)
...@@ -3066,12 +3085,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024): ...@@ -3066,12 +3085,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024):
s_inputs.append(s) s_inputs.append(s)
s_g.append(s) s_g.append(s)
#if no inputs have are an elemwise, there is nothing to fuse. if not fused:
if new_nb_input==len(node.inputs):
#print "local_elemwise_fusion: no elemwise in inputs. Nothing to fuse."
return False return False
assert len(s_inputs)==len(inputs) assert new_nb_input == len(inputs)
assert len(s_inputs) == len(inputs)
otype = node.outputs[0].type otype = node.outputs[0].type
s_new_out=node.op.scalar_op(*s_g) s_new_out=node.op.scalar_op(*s_g)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论