提交 6a2333d3 authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: David Warde-Farley

Two more optimization that deal with lifting the ifelse ops (in certain

conditions)
上级 15362d8c
......@@ -475,3 +475,112 @@ def cond_remove_identical(node):
acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Reshape,
theano.tensor.basic.Shape,
theano.tensor.basic.SpecifyShape,
theano.tensor.basic.MaxAndArgmax,
theano.tensor.basic.Subtensor,
theano.tensor.basic.IncSubtensor,
theano.tensor.basic.Rebroadcast,
theano.tensor.basic.Alloc,
theano.tensor.elemwise.Elemwise,
theano.tensor.elemwise.DimShuffle )
@gof.local_optimizer([None])
def cond_lift_single_if(main_node):
if not (isinstance(main_node.op, acceptable_ops)):
return False
all_inp_nodes = set()
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
ifnodes = [ x for x in list(all_inp_nodes) if x and isinstance(x.op, IfElse)]
# if we have multiple ifs as inputs .. it all becomes quite complicated
# :)
if len(ifnodes) != 1:
return False
node = ifnodes[0]
op = node.op
ts = node.inputs[1:][:op.n_outs]
fs = node.inputs[1:][op.n_outs:]
outs = main_node.outputs
mop = main_node.op
true_ins = []
false_ins = []
for x in main_node.inputs:
if x in node.outputs:
idx = node.outputs.index(x)
true_ins.append(ts[idx])
false_ins.append(fs[idx])
else:
true_ins.append(x)
false_ins.append(x)
true_eval = mop.make_node(*true_ins).outputs
false_eval = mop.make_node(*false_ins).outputs
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval)
if type(nw_outs) not in (tuple, list):
nw_outs = [nw_outs]
return nw_outs
@gof.local_optimizer([None])
def cond_merge_random_op(main_node):
if isinstance(main_node.op, IfElse):
return False
all_inp_nodes = set()
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
cond_nodes = [ x for x in list(all_inp_nodes) if x and isinstance(x.op, IfElse)]
if len(cond_nodes) < 2:
return False
merging_node = cond_nodes[0]
for proposal in cond_nodes[1:]:
if (proposal.inputs[0] == merging_node.inputs[0] and
not find_up(proposal, merging_node) and
not find_up(merging_node, proposal)):
# Create a list of replacements for proposal
mn_ts = merging_node.inputs[1:][:merging_node.op.n_outs]
mn_fs = merging_node.inputs[1:][merging_node.op.n_outs:]
pl_ts = proposal.inputs[1:][:proposal.op.n_outs]
pl_fs = proposal.inputs[1:][proposal.op.n_outs:]
new_ins = ([merging_node.inputs[0]] +
mn_ts + pl_ts + mn_fs + pl_fs )
mn_name = '?'
if merging_node.op.name:
mn_name = merging_node.op.name
pl_name = '?'
mn_n_ts = len(mn_ts)
mn_n_fs = len(mn_fs)
if proposal.op.name:
pl_name = proposal.op.name
new_ifelse = IfElse(
n_outs = len(mn_ts+pl_ts),
as_view=False,
gpu = False,
name = mn_name+'&'+pl_name)
new_outs = new_ifelse.make_node(*new_ins).outputs
old_outs = []
if type(merging_node.outputs) not in (list, tuple):
old_outs += [merging_node.outputs]
else:
old_outs += merging_node.outputs
if type(proposal.outputs) not in (list, tuple):
old_outs += [proposal.outputs]
else:
old_outs += proposal.outputs
pairs = zip(old_outs, new_outs)
main_outs = clone(main_node.outputs, replace=pairs)
return main_outs
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论