提交 c36bdaae authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new optimization

上级 958f0cab
...@@ -324,6 +324,200 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops', ...@@ -324,6 +324,200 @@ scan_seqopt.register('scanOp_pushout_nonseqs_ops',
1, 1,
'fast_run', 'fast_run',
'scan') 'scan')
# This is a global opt for historical reason
# It should be possible to change it to a local opt.
class PushOutSeqScan(gof.Optimizer):
def __init__(self):
gof.Optimizer.__init__(self)
def add_requirements(self, fgraph):
fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
scan_op.Scan)]
for node in nodelist:
self.process_node(fgraph, node)
def process_node(self, fgraph, node):
# this flag tells if there was any change during the last iterations
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_fgraph = gof.FunctionGraph(clean_inputs, clean_outputs)
max_iterations = 2 * len(local_fgraph.toposort()) + 3
counts = 0
to_remove = []
to_replace = []
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
inner_non_seqs = op.inner_non_seqs(clean_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs)
inner_seqs = op.inner_seqs(clean_inputs)
outer_seqs = op.outer_seqs(node.inputs)
assert len(inner_non_seqs) == len(outer_non_seqs)
assert len(inner_seqs) == len(outer_seqs)
while changed and counts < max_iterations:
counts += 1
changed = False
for nd in local_fgraph.toposort():
if (isinstance(nd.op, theano.tensor.Elemwise) and
numpy.all([(x in inner_non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant) or
(x in inner_seqs)
for x in nd.inputs]) and
not nd in to_remove):
to_remove.append(nd)
outside_ins = []
for x in nd.inputs:
if x in inner_non_seqs:
_idx = inner_non_seqs.index(x)
outside_ins += [outer_non_seqs[_idx]]
elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]]
elif x in to_replace:
outside_ins += [replace_with_out[\
to_replace.index(x)]]
elif isinstance(x, theano.Constant):
outside_ins += [x.clone()]
else:
raise Exception(
('Error in the `scan_pushout_non_seq_'
'operations`. The optimization tries '
'to move some computation fron scan '
'which is not allowed to move. Report '
'this on theano-users list'), x)
nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [nw_outer_node.outputs[idx]]
changed = True
elif (isinstance(nd.op, theano.tensor.DimShuffle) and
(nd.inputs[0] in inner_seqs or
nd.inputs[0].owner in to_remove) and
not nd in to_remove):
to_remove.append(nd)
x = nd.inputs[0]
if x in inner_seqs:
outside_ins = outer_seqs[inner_seqs.index(x)]
elif x in to_replace:
outside_ins = replace_with_out[to_replace.index(x)]
new_ord = (0,)
for old_ord in nd.op.new_order:
if isinstance(old_ord, int):
new_ord += (old_ord + 1,)
else:
new_ord += (old_ord,)
new_outer = outside_ins.dimshuffle(new_ord)
y = nd.outputs[0]
y_place_holder = scan_utils.safe_new(y, '_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [new_outer]
changed = True
if counts >= max_iterations:
raise Exception('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!')
# We need to check all candidate replacements and choose those that
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
clean_to_replace = []
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [nd for nd in local_fgraph.toposort()
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx, out in enumerate(to_replace):
if out in to_keep and out.owner not in existent_nodes:
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip(clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
repl_in = repl_out.clone()
else:
nw_inner += [repl_in]
nw_outer += [repl_out]
givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs,
replace=givens)
_op_ins = nw_inner + clean_inputs
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node
nw_info = op.info.copy()
nw_info['n_seqs'] += len(nw_inner)
nwScan = scan_op.Scan(op_ins, op_outs, nw_info)
nw_node = nwScan.make_node(* (node.inputs[:1] + nw_outer +
node.inputs[1:]))
fgraph.replace_all_validate_remove(
zip(node.outputs, nw_node.outputs),
remove=[node],
reason='scan_push_computation_out')
return True
elif (to_keep == [] and
not op.as_while and
not op.outer_mitmot(node)):
# Nothing in the inner graph should be kept
replace_with = {}
for idx, out in enumerate(to_replace):
if out in local_fgraph.outputs:
x = node.outputs[local_fgraph.outputs.index(out)]
_y = replace_with_out[idx]
ls = local_fgraph.outputs
if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx]
st = abs(numpy.min(op.mitsot_taps()))
y = tensor.set_subtensor(inp[st:], _y)
elif out in op.inner_sitsot_outs(ls):
odx = op.inner_sitsot_outs(ls).index(out)
inp = op.outer_sitsot(node)[odx]
y = tensor.set_subtensor(inp[1:], _y)
elif out in op.inner_nitsot_outs(ls):
y = _y
else:
y = _y[-1]
replace_with[x] = y
# We need to add one extra dimension to the outputs
if replace_with:
fgraph.replace_all_validate_remove(
replace_with.items(),
remove=[node],
reason='scan_push_seq_computation_out')
else:
return False
class ScanInplaceOptimizer(Optimizer): class ScanInplaceOptimizer(Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论