提交 850c24f0 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Optimization that pushes out computation from scan ops.

上级 396211cb
......@@ -112,6 +112,152 @@ optdb.register( 'scanOp_remove_constants_and_unused_inputs'
, 'fast_run'
, 'scan')
@gof.local_optimizer([None])
def scan_pushout_non_seq_operation(node):
if not isinstance(node.op, scan_op.Scan):
return False
# this flag tells if there was any change during the last iterations
changed = True
try:
clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
node.op.inputs, node.op.outputs)
local_env = gof.Env(clean_inputs, clean_outputs)
except:
import ipdb; ipdb.set_trace()
max_iterations = 2*len(local_env.toposort()) + 3
counts = 0
to_remove = []
to_replace = []
replace_with_in = []
replace_with_out = []
op = node.op
# Construct the list of non_sequences to simplify a few things
st = op.n_seqs
st += int(numpy.sum([len(x) for x in
op.tap_array[:(op.n_mit_mot+op.n_mit_sot)] ]))
st += op.n_sit_sot
st += op.n_shared_outs
non_seqs = clean_inputs[st:]
st = ( op.n_seqs +
op.n_mit_mot +
op.n_mit_sot +
op.n_sit_sot +
op.n_nit_sot +
op.n_shared_outs +1 )
outer_non_seqs = node.inputs[st:]
while changed and counts < max_iterations:
counts += 1
changed = False
for nd in local_env.toposort():
if ( numpy.all([ (x in non_seqs) or
(x.owner in to_remove) or
isinstance(x, tensor.Constant)
for x in nd.inputs]) and
# we can do this because the assumption is that a
# viewOp or deepCopyOp will be just at the end of the
# function and not somewhere in the middle ..
not isinstance(nd.op,theano.compile.ViewOp) and
not isinstance(nd.op,theano.compile.DeepCopyOp) and
# and we didn't already looked at this node
not nd in to_remove
):
# We have a candidate node to removable
# Step 1. Reconstruct it on outside
to_remove += [nd]
outside_ins = []
for x in nd.inputs:
if x in non_seqs:
outside_ins +=[ outer_non_seqs[non_seqs.index(x)]]
elif x in to_replace:
outside_ins +=[replace_with_out[to_replace.index(x)]]
elif isinstance(x, theano.Constant):
outside_ins +=[x.clone()]
else:
raise Exception(
('Error in the `scan_pushout_non_seq_operations`'
'. The optimization tries to move some '
'computation fron scan which is not allowed '
'to move. Report this on theano-users list'),x )
nw_outer_node = nd.op.make_node(*outside_ins)
# Step 2. Create variables for replacements
for idx,y in enumerate(nd.outputs):
y_place_holder = scan_utils.safe_new(y,'_replace')
to_replace += [y]
replace_with_in += [y_place_holder]
if (cuda.cuda_available and
isinstance(nw_outer_node.outputs[idx],
CudaNdarrayType)):
nw_out = nw_outer_node.outputs[idx]
replace_with_out += [host_from_gpu(nw_out)]
else:
replace_with_out += [nw_outer_node.outputs[idx]]
changed = True
if counts >= max_iterations:
raise Exception( ('Error in the `scan_pushout_non_seq_operations`.'
' The optimization exhausted the maximal number '
'of iterations allowed!'))
# We need to check all candidate replacements and choose those that
# make sense for us
# Step 1. which elements of `to_replace` are used by remaining
# components of the inner function
clean_to_replace = []
clean_replace_with_in = []
clean_replace_with_out = []
existent_nodes = [ nd for nd in local_env.toposort()
if nd not in to_remove]
to_keep = []
for nd in existent_nodes:
to_keep += nd.inputs
for idx,out in enumerate(to_replace):
if out in to_keep and out.owner not in existent_nodes:
clean_to_replace += [out]
clean_replace_with_in += [replace_with_in[idx]]
clean_replace_with_out += [replace_with_out[idx]]
if len(clean_to_replace) > 0:
# We can finally put an end to all this madness
givens = {}
nw_outer = []
nw_inner = []
for to_repl, repl_in, repl_out in zip( clean_to_replace,
clean_replace_with_in,
clean_replace_with_out):
if isinstance(repl_out, theano.Constant):
# Is this even possible !?
repl_in = repl_out.clone()
else:
nw_inner += [repl_in]
nw_outer += [repl_out]
givens[to_repl] = repl_in
_op_outs = scan_utils.clone(clean_outputs,
replace=givens)
_op_ins = clean_inputs + nw_inner
op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs, '')
# Reconstruct node
nwScan = scan_op.Scan(op_ins, op_outs, op.info)
node = nwScan.make_node(* (node.inputs + nw_outer))
return node.outputs
else:
return False
optdb.register('scanOp_pushout_nonseqs_ops',
opt.in2out( scan_pushout_non_seq_operation,
ignore_newtrees=True),
1.90,
'fast_run',
'scan')
@gof.local_optimizer([None])
def scan_make_inplace(node):
op = node.op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论