提交 069d112f authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: David Warde-Farley

Merge together 'ifelse' ops that have the same condition

上级 6747ddf7
......@@ -371,3 +371,53 @@ def cond_merge_ifs_false(node):
old_ins[pos] = var
return op.make_node(*old_ins).outputs
class CondMerge(gof.Optimizer):
""" Graph Optimizer that merges different cond ops """
def add_requirements(self,env):
env.extend(gof.toolbox.ReplaceValidate())
def apply(self, env):
nodelist = list(env.toposort())
cond_nodes = filter(lambda s: isinstance(s.op, IfElse), nodelist)
if len(cond_nodes) < 2:
return False
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
new_ins = ([merging_node.inputs[0]] +
mn_ts + pl_ts + mn_fs + pl_fs )
mn_name = '?'
if merging_node.op.name:
mn_name = merging_node.op.name
pl_name = '?'
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
if proposal.op.name:
pl_name = proposal.op.name
new_ifelse = IfElse(
n_outs = len(mn_ts+pl_ts),
as_view=False,
gpu = False,
name = mn_name+'&'+pl_name)
print 'here'
new_outs = new_ifelse.make_node(*new_ins).outputs
new_outs = [clone(x) for x in new_outs]
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
old_outs += [merging_node.outputs]
else:
old_outs += merging_node.outputs
if type(proposal.outputs) not in (list, tuple):
old_outs += [proposal.outputs]
else:
old_outs += proposal.outputs
pairs = zip(old_outs, new_outs)
env.replace_all_validate(pairs, reason='cond_merge')
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论