提交 51813bfe authored 作者: Razvan Pascanu's avatar Razvan Pascanu

no conflicts

...@@ -561,15 +561,37 @@ class ScanSaveMem(gof.Optimizer): ...@@ -561,15 +561,37 @@ class ScanSaveMem(gof.Optimizer):
# If the memory for this output has been pre-allocated # If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node) # before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op.n_mit_sot + op.n_sit_sot:
# In case the input is still an alloc node # In case the input is still an alloc node, we
if nw_inputs[offset+idx].owner: # actually have two options:
# a) the input is an alloc (due to an optimization
# that converts set_subtensor(0,0) into 0)
# b) the input is a set_subtensor
if ( nw_inputs[offset+idx].owner and
isinstance(nw_inputs[offset+idx].owner.op,
tensor.IncSubtensor)):
_nw_input = nw_inputs[offset+idx].owner.inputs[1] _nw_input = nw_inputs[offset+idx].owner.inputs[1]
nw_input = scan_utils.expand( _nw_input, val - init_l[i] ) tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val - init_l[i]))
tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand( _nw_input,tmp )
# If it is an alloc
elif ( nw_inputs[offset+idx].owner and
isinstance(nw_inputs[offset+idx].owner.op,
tensor.Alloc)):
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp]
# Else, if it was constant folded to a single value # Else, if it was constant folded to a single value
elif isinstance(nw_inputs[offset+idx], tensor.Constant): elif isinstance(nw_inputs[offset+idx], tensor.Constant):
# The hope is that constant folding will fold # The hope is that constant folding will fold
# this as well # this as well
nw_input = nw_inputs[offset+idx][:val]
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp]
else: else:
raise Exception(('Unforseen case. Please report' raise Exception(('Unforseen case. Please report'
' to theano-dev with an example' ' to theano-dev with an example'
...@@ -613,7 +635,17 @@ class ScanSaveMem(gof.Optimizer): ...@@ -613,7 +635,17 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphan outputs # 3.5 Remove unwanted orphan outputs
(inps, outs, info, node_ins, compress_map) = \ (inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs) scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {}
for k,v in compress_map.items():
inv_compress_map[v] = k
node_ins = [ pre_greedy_local_optimizer(list_opt_slice, x) for x in
node_ins]
node_ins = pre_constant_merge(node_ins)
# 3.6 Compose the new scan # 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
info['_scan_merge_visited'] = True
new_outs = scan_op.Scan(inps new_outs = scan_op.Scan(inps
, outs , outs
, info).make_node(*node_ins).outputs , info).make_node(*node_ins).outputs
...@@ -641,7 +673,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -641,7 +673,7 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (fslice,) + tuple(old_slices[1:]) nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = compress_map[idx] nw_pos = inv_compress_map[idx]
nw_out = new_outs[nw_pos] nw_out = new_outs[nw_pos]
...@@ -660,41 +692,42 @@ class ScanSaveMem(gof.Optimizer): ...@@ -660,41 +692,42 @@ class ScanSaveMem(gof.Optimizer):
# 3.8. Get replace pairs for those outputs that change # 3.8. Get replace pairs for those outputs that change
# the number of stored intermediate steps # the number of stored intermediate steps
for pos, old_outs in old_outputs: for pos, old_outs in old_outputs:
nw_pos = compress_map[pos] if len(old_outs) > 0:
nw_out = new_outs[nw_pos] nw_pos = compress_map[pos]
for k,old in enumerate(old_outs): nw_out = new_outs[nw_pos]
# Get the correct slice for k,old in enumerate(old_outs):
cnf_slice, old_slices = slices[pos][k] # Get the correct slice
if type(cnf_slice[0]) is slice: cnf_slice, old_slices = slices[pos][k]
start = ( cnf_slice[0].start - nw_steps - if type(cnf_slice[0]) is slice:
init_l[pos] + store_steps[pos] ) start = ( cnf_slice[0].start - nw_steps -
if ( cnf_slice[0].stop is not None and init_l[pos] + store_steps[pos] )
cnf_slice[0].stop != sys.maxint ): if ( cnf_slice[0].stop is not None and
stop = ( cnf_slice[0].stop - nw_steps - cnf_slice[0].stop != sys.maxint ):
init_l[pos] + store_steps[pos]) stop = ( cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos])
else:
stop = None
nw_slice = ( (slice(sanitize(start),
sanitize(stop),
sanitize(cnf_slice[0].step)),) +
tuple(old_slices[1:]) )
else: else:
stop = None position = (cnf_slice[0] - nw_steps -
nw_slice = ( (slice(sanitize(start), init_l[pos] + store_steps[pos] )
sanitize(stop),
sanitize(cnf_slice[0].step)),) +
tuple(old_slices[1:]) )
else: nw_slice = (sanitize(position),) + tuple(old_slices[1:])
position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos] ) subtens = tensor.basic.Subtensor(nw_slice)
sl_ins = tensor.basic.Subtensor.collapse(
nw_slice = (sanitize(position),) + tuple(old_slices[1:]) nw_slice
, lambda entry: isinstance(entry
subtens = tensor.basic.Subtensor(nw_slice) , tensor.Variable))
sl_ins = tensor.basic.Subtensor.collapse( new_o = subtens.make_node(new_outs[nw_pos],
nw_slice *sl_ins).outputs[0]
, lambda entry: isinstance(entry if new_o.ndim > 0:
, tensor.Variable)) new_o = new_o[::cnf_slice[1]]
new_o = subtens.make_node(new_outs[nw_pos], old_new += [(old, new_o)]
*sl_ins).outputs[0]
if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]]
old_new += [(old, new_o)]
# 3.9. Get replace pairs for all other nodes # 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None: if flag_store or global_nsteps is not None:
...@@ -707,11 +740,11 @@ class ScanSaveMem(gof.Optimizer): ...@@ -707,11 +740,11 @@ class ScanSaveMem(gof.Optimizer):
def apply(self, env): def apply(self, env):
nodelist = list(env.toposort())
old_new = [] nodelist = [x for x in env.toposort() if isinstance(x.op,
scan_op.Scan)]
for node in nodelist: for node in nodelist:
op = node.op if not hasattr(node.op, '_scan_merge_visited'):
if isinstance(op, scan_op.Scan):
self.process_node(env, node) self.process_node(env, node)
# Just before specialize to have the other optimization # Just before specialize to have the other optimization
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论