提交 bac982dd authored 作者: Razvan Pascanu's avatar Razvan Pascanu

merge no conflicts

...@@ -105,9 +105,9 @@ def rebuild_collect_shared( outputs ...@@ -105,9 +105,9 @@ def rebuild_collect_shared( outputs
, (v, v.type, v_update, v_update.type)) , (v, v.type, v_update, v_update.type))
update_d[v] = v_update update_d[v] = v_update
update_expr.append((v, v_update)) update_expr.append((v, v_update))
if not copy_inputs_over and not isinstance(v, Constant): if not copy_inputs_over or isinstance(v, Constant):
### Cloning shared variables implies copying their underlying ### Cloning shared variables implies copying their underlying
### memory buffer ?? ### memory buffer ?? No.
return clone_d.setdefault(v,v.clone()) return clone_d.setdefault(v,v.clone())
else: else:
return clone_d.setdefault(v,v) return clone_d.setdefault(v,v)
......
...@@ -112,6 +112,152 @@ optdb.register( 'scanOp_remove_constants_and_unused_inputs' ...@@ -112,6 +112,152 @@ optdb.register( 'scanOp_remove_constants_and_unused_inputs'
, 'fast_run' , 'fast_run'
, 'scan') , 'scan')
@gof.local_optimizer([None])
def scan_pushout_non_seq_operation(node):
    """
    Push computations that do not depend on the iteration out of a Scan.

    An inner-graph node is a candidate when every one of its inputs is a
    non-sequence of the scan, a constant, or the output of a node that was
    already selected. Each candidate is rebuilt on the outer inputs of
    ``node``; the outputs still needed by the remaining inner graph are then
    fed to a new Scan op as additional inputs.

    :param node: the apply node the optimizer is visiting.

    :return: ``False`` when ``node.op`` is not a ``scan_op.Scan`` or when
        nothing can be moved; otherwise the outputs of the rebuilt Scan node.

    :raises Exception: when a selected node has an input that cannot be
        mapped to an outer variable, or when the search uses up its
        iteration budget.
    """
    if not isinstance(node.op, scan_op.Scan):
        return False
    # Work on a reconstructed copy of the inner graph. Any failure here is
    # left to propagate: the previous bare ``except`` dropped into an ipdb
    # session and then crashed on the unbound ``local_env`` anyway.
    clean_inputs, clean_outputs = scan_utils.reconstruct_graph(
        node.op.inputs, node.op.outputs)
    local_env = gof.Env(clean_inputs, clean_outputs)
    max_iterations = 2 * len(local_env.toposort()) + 3
    counts = 0
    # Inner nodes selected for moving outside of the scan.
    to_remove = []
    # Parallel lists: an inner output, its inner placeholder and its
    # outer replacement share the same index.
    to_replace = []
    replace_with_in = []
    replace_with_out = []
    op = node.op
    # Construct the list of non_sequences to simplify a few things.
    # `st` counts the inner inputs that come before the non-sequences.
    st = op.n_seqs
    st += int(numpy.sum([len(x) for x in
                         op.tap_array[:(op.n_mit_mot + op.n_mit_sot)]]))
    st += op.n_sit_sot
    st += op.n_shared_outs
    non_seqs = clean_inputs[st:]
    # Same split for the outer inputs of the node.
    # NOTE(review): the extra 1 presumably skips the number-of-steps
    # input -- confirm against the Scan op input layout.
    st = (op.n_seqs +
          op.n_mit_mot +
          op.n_mit_sot +
          op.n_sit_sot +
          op.n_nit_sot +
          op.n_shared_outs + 1)
    outer_non_seqs = node.inputs[st:]
    # this flag tells if there was any change during the last iterations
    changed = True
    while changed and counts < max_iterations:
        counts += 1
        changed = False
        for nd in local_env.toposort():
            movable = all((x in non_seqs) or
                          (x.owner in to_remove) or
                          isinstance(x, tensor.Constant)
                          for x in nd.inputs)
            if (movable and
                    # we can do this because the assumption is that a
                    # viewOp or deepCopyOp will be just at the end of the
                    # function and not somewhere in the middle ..
                    not isinstance(nd.op, theano.compile.ViewOp) and
                    not isinstance(nd.op, theano.compile.DeepCopyOp) and
                    # and we have not already looked at this node
                    nd not in to_remove):
                # We have a candidate node to remove
                # Step 1. Reconstruct it on outside
                to_remove.append(nd)
                outside_ins = []
                for x in nd.inputs:
                    if x in non_seqs:
                        outside_ins.append(
                            outer_non_seqs[non_seqs.index(x)])
                    elif x in to_replace:
                        outside_ins.append(
                            replace_with_out[to_replace.index(x)])
                    elif isinstance(x, theano.Constant):
                        outside_ins.append(x.clone())
                    else:
                        raise Exception(
                            ('Error in the `scan_pushout_non_seq_operations`'
                             '. The optimization tries to move some '
                             'computation fron scan which is not allowed '
                             'to move. Report this on theano-users list'), x)
                nw_outer_node = nd.op.make_node(*outside_ins)
                # Step 2. Create variables for replacements
                for idx, y in enumerate(nd.outputs):
                    y_place_holder = scan_utils.safe_new(y, '_replace')
                    to_replace.append(y)
                    replace_with_in.append(y_place_holder)
                    # NOTE(review): this tests the output variable itself
                    # against CudaNdarrayType; it may have been meant to
                    # test its `.type` -- confirm before relying on it.
                    if (cuda.cuda_available and
                            isinstance(nw_outer_node.outputs[idx],
                                       CudaNdarrayType)):
                        nw_out = nw_outer_node.outputs[idx]
                        replace_with_out.append(host_from_gpu(nw_out))
                    else:
                        replace_with_out.append(nw_outer_node.outputs[idx])
                changed = True
    if counts >= max_iterations:
        raise Exception(('Error in the `scan_pushout_non_seq_operations`.'
                         ' The optimization exhausted the maximal number '
                         'of iterations allowed!'))
    # We need to check all candidate replacements and choose those that
    # make sense for us
    # Step 1. which elements of `to_replace` are used by remaining
    # components of the inner function
    clean_to_replace = []
    clean_replace_with_in = []
    clean_replace_with_out = []
    existent_nodes = [nd for nd in local_env.toposort()
                      if nd not in to_remove]
    to_keep = []
    for nd in existent_nodes:
        to_keep.extend(nd.inputs)
    for idx, out in enumerate(to_replace):
        if out in to_keep and out.owner not in existent_nodes:
            clean_to_replace.append(out)
            clean_replace_with_in.append(replace_with_in[idx])
            clean_replace_with_out.append(replace_with_out[idx])
    if len(clean_to_replace) > 0:
        # Build the substitution map for the inner graph and collect the
        # extra inner/outer inputs the new Scan needs.
        givens = {}
        nw_outer = []
        nw_inner = []
        for to_repl, repl_in, repl_out in zip(clean_to_replace,
                                              clean_replace_with_in,
                                              clean_replace_with_out):
            if isinstance(repl_out, theano.Constant):
                # Is this even possible !?
                # A constant is substituted directly; no new input needed.
                repl_in = repl_out.clone()
            else:
                nw_inner.append(repl_in)
                nw_outer.append(repl_out)
            givens[to_repl] = repl_in
        _op_outs = scan_utils.clone(clean_outputs,
                                    replace=givens)
        _op_ins = clean_inputs + nw_inner
        op_ins, op_outs = scan_utils.reconstruct_graph(_op_ins, _op_outs, '')
        # Reconstruct node
        nwScan = scan_op.Scan(op_ins, op_outs, op.info)
        node = nwScan.make_node(*(node.inputs + nw_outer))
        return node.outputs
    else:
        return False
# Register the push-out optimization in `optdb` under the name
# 'scanOp_pushout_nonseqs_ops'. The local optimizer is wrapped by
# `opt.in2out` with ignore_newtrees=True, and tagged 'fast_run' and 'scan'.
# NOTE(review): 1.90 is presumably the position of this entry in the
# optimizer database ordering -- confirm against the other registrations.
optdb.register('scanOp_pushout_nonseqs_ops',
opt.in2out( scan_pushout_non_seq_operation,
ignore_newtrees=True),
1.90,
'fast_run',
'scan')
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def scan_make_inplace(node): def scan_make_inplace(node):
op = node.op op = node.op
......
...@@ -726,13 +726,15 @@ def flatten(l): ...@@ -726,13 +726,15 @@ def flatten(l):
return sum(l , []) return sum(l , [])
def reconstruct_graph(inputs, outputs, tag): def reconstruct_graph(inputs, outputs, tag = None):
""" """
Different interface to clone, that allows you to pass inputs. Different interface to clone, that allows you to pass inputs.
Compared to clone, this method always replaces the inputs with Compared to clone, this method always replaces the inputs with
new variables of the same type, and returns those ( in the same new variables of the same type, and returns those ( in the same
order as the original inputs). order as the original inputs).
""" """
if tag is None:
tag = ''
nw_inputs = [safe_new(x,tag) for x in inputs] nw_inputs = [safe_new(x,tag) for x in inputs]
givens = {} givens = {}
for nw_x, x in zip(nw_inputs, inputs): for nw_x, x in zip(nw_inputs, inputs):
......
...@@ -184,7 +184,7 @@ class test_RopLop(unittest.TestCase): ...@@ -184,7 +184,7 @@ class test_RopLop(unittest.TestCase):
def test_max_argmax(self): def test_max_argmax(self):
self.check_map_rop_lop(TT.max(self.mx, axis=1), self.check_mat_rop_lop(TT.max(self.mx, axis=1),
(self.mat_in_shape[0],)) (self.mat_in_shape[0],))
def test_max_argmax(self): def test_max_argmax(self):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论