提交 15362d8c authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: David Warde-Farley

Remove identical inputs to the ifelse op (identical on both branches)

上级 069d112f
......@@ -421,3 +421,57 @@ class CondMerge(gof.Optimizer):
env.replace_all_validate(pairs, reason='cond_merge')
@gof.local_optimizer([None])
def cond_remove_identical(node):
op = node.op
if not isinstance(op, IfElse):
return False
ts = node.inputs[1:][:op.n_outs]
fs = node.inputs[1:][op.n_outs:]
# sync outs
out_map = {}
for idx in xrange(len(node.outputs)):
if idx not in out_map:
for jdx in xrange(idx+1,len(node.outputs)):
if (ts[idx] == ts[jdx] and
fs[idx] == fs[jdx] and
jdx not in out_map):
out_map[jdx] = idx
if len(out_map.keys()) == 0:
return False
nw_ts = []
nw_fs = []
inv_map = {}
pos = 0
for idx in xrange(len(node.outputs)):
if idx not in out_map:
inv_map[idx] = pos
pos = pos + 1
nw_ts.append(ts[idx])
nw_fs.append(fs[idx])
new_ifelse = IfElse(n_outs = len(nw_ts),
as_view = op.as_view,
gpu = op.gpu,
name = op.name)
new_ins = [node.inputs[0]] + nw_ts + nw_fs
new_outs = new_ifelse.make_node(*new_ins).outputs
rval = []
for idx in xrange(len(node.outputs)):
if idx in out_map.keys():
rval += [new_outs[inv_map[out_map[idx]]] ]
else:
rval += [new_outs[inv_map[idx]]]
return rval
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论