提交 b2cf1984 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

transformed the pushout optimization from local to a global opt.

There is no reason for this, just that I had a lot of problems with this optimization. It turned out the problem to be caused by local_useless_subtensor optimization (fix which I will push later). In my attempts to debug it though, I ended up converting it into a global optimizer, and I don't want to go through the pain of reverting it and making sure also that it works.
上级 2302fc97
...@@ -115,138 +115,156 @@ optdb.register( 'scanOp_remove_constants_and_unused_inputs' ...@@ -115,138 +115,156 @@ optdb.register( 'scanOp_remove_constants_and_unused_inputs'
, 'scan') , 'scan')
@gof.local_optimizer([None]) class PushOutNonSeqScan(gof.Optimizer):
def scan_pushout_non_seq_operation(node):
if not isinstance(node.op, scan_op.Scan):
return False
# this flag tells if there was any change during the last iterations
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_env = gof.Env(clean_inputs, clean_outputs)
max_iterations = 2*len(local_env.toposort()) + 3
counts = 0
to_remove = []
to_replace = []
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ]))
st += op.n_sit_sot
st += op.n_shared_outs
non_seqs = clean_inputs[st:]
st = ( op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs +1 )
outer_non_seqs = node.inputs[st:]
while changed and counts < max_iterations: def __init__(self):
counts += 1 gof.Optimizer.__init__(self)
changed = False
for nd in local_env.toposort():
if ( numpy.all([ (x in non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not isinstance(nd.op,theano.compile.ViewOp) and
not isinstance(nd.op,theano.compile.DeepCopyOp) and
# and we didn't already looked at this node
not nd in to_remove
):
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
to_remove += [nd]
outside_ins = []
for x in nd.inputs:
if x in non_seqs:
outside_ins +=[ outer_non_seqs[non_seqs.index(x)]]
elif x in to_replace:
outside_ins +=[replace_with_out[to_replace.index(x)]]
elif isinstance(x, theano.Constant):
outside_ins +=[x.clone()]
else:
raise Exception(
('Error in the `scan_pushout_non_seq_operations`'
'. The optimization tries to move some '
'computation fron scan which is not allowed '
'to move. Report this on theano-users list'),x )
nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements
for idx,y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y,'_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
replace_with_out += [nw_outer_node.outputs[idx]]
changed = True
if counts >= max_iterations:
raise Exception( ('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!'))
# We need to check all candidate replacements and choose those that
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
clean_to_replace = []
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [ nd for nd in local_env.toposort()
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx,out in enumerate(to_replace):
if out in to_keep and out.owner not in existent_nodes:
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip( clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
# Is this even possible !?
repl_in = repl_out.clone()
else:
nw_inner += [repl_in]
nw_outer += [repl_out]
givens[to_repl] = repl_in
def add_requirements(self,env):
env.extend(gof.toolbox.ReplaceValidate())
_op_outs = scan_utils.clone(clean_outputs,
replace=givens)
_op_ins = clean_inputs + nw_inner def apply(self, env):
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs, '') nodelist = [x for x in env.toposort() if isinstance(x.op,
# Reconstruct node scan_op.Scan)]
nwScan = scan_op.Scan(op_ins, op_outs, op.info) for node in nodelist:
node = nwScan.make_node(* (node.inputs + nw_outer)) self.process_node(env, node)
return node.outputs
else: def process_node(self, env, node):
return False # this flag tells if there was any change during the last iterations
changed = True
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_env = gof.Env(clean_inputs, clean_outputs)
max_iterations = 2*len(local_env.toposort()) + 3
counts = 0
to_remove = []
to_replace = []
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ]))
st += op.n_sit_sot
st += op.n_shared_outs
non_seqs = clean_inputs[st:]
st = ( op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs +1 )
outer_non_seqs = node.inputs[st:]
assert len(non_seqs) == len(outer_non_seqs)
while changed and counts < max_iterations:
counts += 1
changed = False
for nd in local_env.toposort():
if ( numpy.all([ (x in non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not isinstance(nd.op,theano.compile.ViewOp) and
not isinstance(nd.op,theano.compile.DeepCopyOp) and
# and we didn't already looked at this node
not nd in to_remove
):
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
to_remove.append(nd)
outside_ins = []
for x in nd.inputs:
if x in non_seqs:
outside_ins +=[ outer_non_seqs[non_seqs.index(x)]]
elif x in to_replace:
outside_ins +=[replace_with_out[to_replace.index(x)]]
elif isinstance(x, theano.Constant):
outside_ins +=[x.clone()]
else:
raise Exception(
('Error in the `scan_pushout_non_seq_operations`'
'. The optimization tries to move some '
'computation fron scan which is not allowed '
'to move. Report this on theano-users list'),x )
nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements
for idx,y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y,'_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
assert type(y) == type(nw_outer_node.outputs[idx])
replace_with_out += [nw_outer_node.outputs[idx]]
changed = True
if counts >= max_iterations:
raise Exception( ('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!'))
# We need to check all candidate replacements and choose those that
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
clean_to_replace = []
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [ nd for nd in local_env.toposort()
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx,out in enumerate(to_replace):
if out in to_keep and out.owner not in existent_nodes:
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip( clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
repl_in = repl_out.clone()
else:
nw_inner += [repl_in]
nw_outer += [repl_out]
givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs,
replace=givens)
_op_ins = clean_inputs + nw_inner
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs)
# Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info)
nw_node = nwScan.make_node(* (node.inputs + nw_outer))
env.replace_all_validate(zip(node.outputs, nw_node.outputs),
reason = 'scan_push_computation_out')
return True
else:
return False
optdb.register('scanOp_pushout_nonseqs_ops', optdb.register('scanOp_pushout_nonseqs_ops',
opt.in2out( scan_pushout_non_seq_operation, PushOutNonSeqScan(),
ignore_newtrees=True), #opt.out2in( scan_pushout_non_seq_operation),
1.90, # ignore_newtrees=True),
1.899,
'fast_run', 'fast_run',
'scan') 'scan')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论