提交 6747ddf7 authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: David Warde-Farley

Two optimization for ifelse

They optimize cases of the form ifelse(cond, ifelse(cond, x,y), y), where cond is the same.
上级 ca144799
...@@ -323,3 +323,51 @@ ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run', ...@@ -323,3 +323,51 @@ ifelse_equilibrium.register('seq_ifelse', ifelse_seqopt, 'fast_run',
'ifelse') 'ifelse')
optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run', 'ifelse') optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run', 'ifelse')
@gof.local_optimizer([None])
def cond_merge_ifs_true(node):
op = node.op
if not isinstance(op, IfElse):
return False
t_ins = node.inputs[1:][:op.n_outs]
replace = {}
for idx,tval in enumerate(t_ins):
if (tval.owner and isinstance(tval.owner.op, IfElse) and
tval.owner.inputs[0] == node.inputs[0]):
ins_op = tval.owner.op
ins_t = tval.owner.inputs[1:][:ins_op.n_outs]
replace[idx+1] = ins_t[tval.owner.outputs.index(tval)]
if len(replace.items()) == 0:
return False
old_ins = list(node.inputs)
for pos,var in replace.items():
old_ins[pos] = var
return op.make_node(*old_ins).outputs
@gof.local_optimizer([None])
def cond_merge_ifs_false(node):
op = node.op
if not isinstance(op, IfElse):
return False
f_ins = node.inputs[1:][op.n_outs:]
replace = {}
for idx,fval in enumerate(f_ins):
if (fval.owner and isinstance(fval.owner.op, IfElse) and
fval.owner.inputs[0] == node.inputs[0]):
ins_op = fval.owner.op
ins_t = fval.owner.inputs[1:][ins_op.n_outs:]
replace[idx+1+op.n_outs] = ins_t[fval.owner.outputs.index(fval)]
if len(replace.items()) == 0:
return False
old_ins = list(node.inputs)
for pos,var in replace.items():
old_ins[pos] = var
return op.make_node(*old_ins).outputs
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论