提交 b2cf1984 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

transformed the pushout optimization from local to a global opt.

There is no reason for this, just that I had a lot of problems with this optimization. It turned out the problem to be caused by local_useless_subtensor optimization (fix which I will push later). In my attempts to debug it though, I ended up converting it into a global optimizer, and I don't want to go through the pain of reverting it and making sure also that it works.
上级 2302fc97
...@@ -115,16 +115,31 @@ optdb.register( 'scanOp_remove_constants_and_unused_inputs' ...@@ -115,16 +115,31 @@ optdb.register( 'scanOp_remove_constants_and_unused_inputs'
, 'scan') , 'scan')
@gof.local_optimizer([None]) class PushOutNonSeqScan(gof.Optimizer):
def scan_pushout_non_seq_operation(node):
if not isinstance(node.op, scan_op.Scan): def __init__(self):
return False gof.Optimizer.__init__(self)
def add_requirements(self,env):
env.extend(gof.toolbox.ReplaceValidate())
def apply(self, env):
nodelist = [x for x in env.toposort() if isinstance(x.op,
scan_op.Scan)]
for node in nodelist:
self.process_node(env, node)
def process_node(self, env, node):
# this flag tells if there was any change during the last iterations # this flag tells if there was any change during the last iterations
changed = True changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph( clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs) node.op.inputs, node.op.outputs)
local_env = gof.Env(clean_inputs, clean_outputs)
local_env = gof.Env(clean_inputs, clean_outputs)
max_iterations = 2*len(local_env.toposort()) + 3 max_iterations = 2*len(local_env.toposort()) + 3
counts = 0 counts = 0
to_remove = [] to_remove = []
...@@ -146,10 +161,11 @@ def scan_pushout_non_seq_operation(node): ...@@ -146,10 +161,11 @@ def scan_pushout_non_seq_operation(node):
op.n_nit_sot + op.n_nit_sot +
op.n_shared_outs +1 ) op.n_shared_outs +1 )
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
assert len(non_seqs) == len(outer_non_seqs)
while changed and counts < max_iterations: while changed and counts < max_iterations:
counts += 1 counts += 1
changed = False changed = False
for nd in local_env.toposort(): for nd in local_env.toposort():
if ( numpy.all([ (x in non_seqs) or if ( numpy.all([ (x in non_seqs) or
(x.owner in to_remove) or (x.owner in to_remove) or
...@@ -166,7 +182,7 @@ def scan_pushout_non_seq_operation(node): ...@@ -166,7 +182,7 @@ def scan_pushout_non_seq_operation(node):
# We have a candidate node to removable # We have a candidate node to removable
# Step 1. Reconstruct it on outside # Step 1. Reconstruct it on outside
to_remove += [nd] to_remove.append(nd)
outside_ins = [] outside_ins = []
for x in nd.inputs: for x in nd.inputs:
if x in non_seqs: if x in non_seqs:
...@@ -184,9 +200,11 @@ def scan_pushout_non_seq_operation(node): ...@@ -184,9 +200,11 @@ def scan_pushout_non_seq_operation(node):
nw_outer_node = nd.op.make_node(*outside_ins) nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx,y in enumerate(nd.outputs): for idx,y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y,'_replace') y_place_holder = scan_utils.safe_new(y,'_replace')
to_replace += [y] to_replace += [y]
replace_with_in += [y_place_holder] replace_with_in += [y_place_holder]
assert type(y) == type(nw_outer_node.outputs[idx])
replace_with_out += [nw_outer_node.outputs[idx]] replace_with_out += [nw_outer_node.outputs[idx]]
changed = True changed = True
...@@ -222,31 +240,31 @@ def scan_pushout_non_seq_operation(node): ...@@ -222,31 +240,31 @@ def scan_pushout_non_seq_operation(node):
clean_replace_with_in, clean_replace_with_in,
clean_replace_with_out): clean_replace_with_out):
if isinstance(repl_out, theano.Constant): if isinstance(repl_out, theano.Constant):
# Is this even possible !?
repl_in = repl_out.clone() repl_in = repl_out.clone()
else: else:
nw_inner += [repl_in] nw_inner += [repl_in]
nw_outer += [repl_out] nw_outer += [repl_out]
givens[to_repl] = repl_in givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs, _op_outs = scan_utils.clone(clean_outputs,
replace=givens) replace=givens)
_op_ins = clean_inputs + nw_inner _op_ins = clean_inputs + nw_inner
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs, '') op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node # Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info) nwScan = scan_op.Scan(op_ins, op_outs, op.info)
node = nwScan.make_node(* (node.inputs + nw_outer)) nw_node = nwScan.make_node(* (node.inputs + nw_outer))
return node.outputs env.replace_all_validate(zip(node.outputs, nw_node.outputs),
reason = 'scan_push_computation_out')
return True
else: else:
return False return False
optdb.register('scanOp_pushout_nonseqs_ops', optdb.register('scanOp_pushout_nonseqs_ops',
opt.in2out( scan_pushout_non_seq_operation, PushOutNonSeqScan(),
ignore_newtrees=True), #opt.out2in( scan_pushout_non_seq_operation),
1.90, # ignore_newtrees=True),
1.899,
'fast_run', 'fast_run',
'scan') 'scan')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论