提交 fd41c1f1 authored 作者: Frederic Bastien's avatar Frederic Bastien

Prevent local_add_mul_fusion from fusing more inputs than the op (e.g. GpuElemwise) supports

上级 0e9d1ab3
...@@ -7347,18 +7347,23 @@ def local_add_mul_fusion(node): ...@@ -7347,18 +7347,23 @@ def local_add_mul_fusion(node):
s_op = node.op.scalar_op.__class__ s_op = node.op.scalar_op.__class__
new_inp = [] new_inp = []
fused = False fused = False
nb_inputs = len(node.inputs)
max_inputs = float('inf')
if hasattr(node.op, 'max_inputs'):
max_inputs = node.op.max_inputs(node)
for inp in node.inputs: for inp in node.inputs:
if (inp.owner and if (inp.owner and
isinstance(inp.owner.op, Elemwise) and isinstance(inp.owner.op, Elemwise) and
isinstance(inp.owner.op.scalar_op, s_op) and isinstance(inp.owner.op.scalar_op, s_op) and
# Do not duplicate the operation. # Do not duplicate the operation.
len(inp.clients) == 1): len(inp.clients) == 1 and
(nb_inputs+len(inp.owner.inputs) - 1) <= max_inputs):
new_inp.extend(inp.owner.inputs) new_inp.extend(inp.owner.inputs)
fused = True fused = True
else: else:
new_inp.append(inp) new_inp.append(inp)
# We ca not compare the number of inputs as Mul and Add could have # We can not compare the number of inputs as Mul and Add could have
# 0 or 1 inputs in some corner cases. # 0 or 1 inputs in some corner cases.
if fused: if fused:
output = node.op(*new_inp) output = node.op(*new_inp)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论