提交 aeb26e55 authored 作者: Razvan Pascanu's avatar Razvan Pascanu 提交者: David Warde-Farley

reordered optimization order

I want to order them according to the order in the sequential shell that calls them.
上级 0f75cc90
......@@ -377,6 +377,66 @@ where, each of the optimization do the following things:
optdb.register('ifelse_equilibriumOpt', ifelse_equilibrium, .5, 'fast_run',
'ifelse')
acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Reshape,
theano.tensor.basic.Shape,
theano.tensor.basic.SpecifyShape,
theano.tensor.basic.MaxAndArgmax,
theano.tensor.basic.Subtensor,
theano.tensor.basic.IncSubtensor,
theano.tensor.basic.Rebroadcast,
theano.tensor.basic.Alloc,
theano.tensor.elemwise.Elemwise,
theano.tensor.elemwise.DimShuffle)
@gof.local_optimizer([None])
def ifelse_lift_single_if_through_acceptable_ops(main_node):
"""This optimization lifts up certain ifelse instances.
op(ifelse(c, x, y)) -> ifelse(c, op(x), op(y))
if `op` is in the `acceptable_ops` list, and there is no other if as
input to that specific `op`, and the if has no other clients !?
"""
if not (isinstance(main_node.op, acceptable_ops)):
return False
all_inp_nodes = set()
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
ifnodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
# if we have multiple ifs as inputs .. it all becomes quite complicated
# :)
if len(ifnodes) != 1:
return False
node = ifnodes[0]
op = node.op
ts = node.inputs[1:][:op.n_outs]
fs = node.inputs[1:][op.n_outs:]
outs = main_node.outputs
mop = main_node.op
true_ins = []
false_ins = []
for x in main_node.inputs:
if x in node.outputs:
idx = node.outputs.index(x)
true_ins.append(ts[idx])
false_ins.append(fs[idx])
else:
true_ins.append(x)
false_ins.append(x)
true_eval = mop.make_node(*true_ins).outputs
false_eval = mop.make_node(*false_ins).outputs
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval)
if type(nw_outs) not in (tuple, list):
nw_outs = [nw_outs]
return nw_outs
@gof.local_optimizer([None])
def cond_merge_ifs_true(node):
......@@ -527,60 +587,6 @@ def cond_remove_identical(node):
return rval
acceptable_ops = (theano.tensor.basic.Dot,
theano.tensor.basic.Reshape,
theano.tensor.basic.Shape,
theano.tensor.basic.SpecifyShape,
theano.tensor.basic.MaxAndArgmax,
theano.tensor.basic.Subtensor,
theano.tensor.basic.IncSubtensor,
theano.tensor.basic.Rebroadcast,
theano.tensor.basic.Alloc,
theano.tensor.elemwise.Elemwise,
theano.tensor.elemwise.DimShuffle)
@gof.local_optimizer([None])
def cond_lift_single_if(main_node):
if not (isinstance(main_node.op, acceptable_ops)):
return False
all_inp_nodes = set()
for inp in main_node.inputs:
all_inp_nodes.add(inp.owner)
ifnodes = [x for x in list(all_inp_nodes)
if x and isinstance(x.op, IfElse)]
# if we have multiple ifs as inputs .. it all becomes quite complicated
# :)
if len(ifnodes) != 1:
return False
node = ifnodes[0]
op = node.op
ts = node.inputs[1:][:op.n_outs]
fs = node.inputs[1:][op.n_outs:]
outs = main_node.outputs
mop = main_node.op
true_ins = []
false_ins = []
for x in main_node.inputs:
if x in node.outputs:
idx = node.outputs.index(x)
true_ins.append(ts[idx])
false_ins.append(fs[idx])
else:
true_ins.append(x)
false_ins.append(x)
true_eval = mop.make_node(*true_ins).outputs
false_eval = mop.make_node(*false_ins).outputs
#true_eval = clone(outs, replace = dict(zip(node.outputs, ts)))
#false_eval = clone(outs, replace = dict(zip(node.outputs, fs)))
nw_outs = ifelse(node.inputs[0], true_eval, false_eval)
if type(nw_outs) not in (tuple, list):
nw_outs = [nw_outs]
return nw_outs
@gof.local_optimizer([None])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论