提交 4f15d801 authored 作者: Frederic's avatar Frederic

Fuse nested add/mul elemwise ops into one big add/mul elemwise before doing fusion

with Composite. This allows pickling bigger Composite graphs because it simplifies the graph, so we hit the max recursion limit later.
上级 337f6c66
...@@ -4888,11 +4888,40 @@ class FusionOptimizer(Optimizer): ...@@ -4888,11 +4888,40 @@ class FusionOptimizer(Optimizer):
print >> stream, blanc, " time_toposort", prof[7] print >> stream, blanc, " time_toposort", prof[7]
def local_add_mul_fusion(node):
    """Merge an add (or mul) with one of its inputs that is the same op.

    Flattening chained add/mul nodes into a single node with more inputs
    is preferable to fusing them inside a Composite node: it keeps the
    inner graph of the Composite smaller, which allows more computation
    to be put in a Composite before the max recursion limit is hit when
    pickling it.

    :param node: the apply node to inspect.

    :return: ``False`` when ``node`` is not an Elemwise whose scalar op
        is ``scalar.Add`` or ``scalar.Mul``; otherwise a one-element
        list holding the output of the merged node, built from the
        first input that is produced by an Elemwise of the same scalar
        op class.  When no input qualifies, the function falls through
        and returns ``None``.
    """
    outer_op = node.op
    if not (isinstance(outer_op, Elemwise) and
            isinstance(outer_op.scalar_op, (scalar.Add, scalar.Mul))):
        return False

    # Only an input computed by the very same kind of scalar op
    # (Add into Add, Mul into Mul) can be merged.
    scalar_cls = outer_op.scalar_op.__class__

    for candidate in node.inputs:
        producer = candidate.owner
        if not producer:
            continue
        if not isinstance(producer.op, Elemwise):
            continue
        if not isinstance(producer.op.scalar_op, scalar_cls):
            continue

        # Drop the intermediate result and feed its own inputs directly
        # to the outer op; they are appended after the remaining inputs.
        remaining = list(node.inputs)
        remaining.remove(candidate)
        return [outer_op(*(remaining + producer.inputs))]
if config.tensor.local_elemwise_fusion: if config.tensor.local_elemwise_fusion:
_logger.debug("enabling optimization fusion elemwise in fast_run") _logger.debug("enabling optimization fusion elemwise in fast_run")
#Must be after gpu(48.5) and before AddDestroyHandler(49.5) #Must be after gpu(48.5) and before AddDestroyHandler(49.5)
fuse_seqopt = gof.SequenceDB()
fuse_seqopt.register('local_add_mul_fusion',
FusionOptimizer(local_add_mul_fusion),
0, 'fast_run', 'fusion')
fuse_seqopt.register('composite_elemwise_fusion',
FusionOptimizer(local_elemwise_fusion),
1, 'fast_run', 'fusion')
compile.optdb.register('elemwise_fusion', compile.optdb.register('elemwise_fusion',
FusionOptimizer(local_elemwise_fusion), 49, fuse_seqopt, 49,
'fast_run', 'fusion', 'local_elemwise_fusion', 'fast_run', 'fusion', 'local_elemwise_fusion',
'FusionOptimizer') 'FusionOptimizer')
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论