提交 29783b30 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed ScanSaveMem optimization to deal with the inputs being allocs, not

set_subtensor. This was the cause of the 3 failing tests in scan. I also revised the entire optimization, adding more comments.
上级 5548fee1
......@@ -561,15 +561,37 @@ class ScanSaveMem(gof.Optimizer):
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot:
# In case the input is still an alloc node
if nw_inputs[offset+idx].owner:
# In case the input is still an alloc node, we
# actually have two options:
# a) the input is an alloc (due to an optimization
# that converts set_subtensor(0,0) into 0)
# b) the input is a set_subtensor
if ( nw_inputs[offset+idx].owner and
isinstance(nw_inputs[offset+idx].owner.op,
tensor.IncSubtensor)):
_nw_input = nw_inputs[offset+idx].owner.inputs[1]
nw_input = scan_utils.expand( _nw_input, val - init_l[i] )
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val - init_l[i]))
tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand( _nw_input,tmp )
# If it is an alloc
elif ( nw_inputs[offset+idx].owner and
isinstance(nw_inputs[offset+idx].owner.op,
tensor.Alloc)):
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp]
# Else, if it was constant folded to a single value
elif isinstance(nw_inputs[offset+idx], tensor.Constant):
# The hope is that constant folding will fold
# this as well
nw_input = nw_inputs[offset+idx][:val]
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp]
else:
raise Exception(('Unforseen case. Please report'
' to theano-dev with an example'
......@@ -613,7 +635,17 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphan outputs
(inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {}
for k,v in compress_map.items():
inv_compress_map[v] = k
node_ins = [ pre_greedy_local_optimizer(list_opt_slice, x) for x in
node_ins]
node_ins = pre_constant_merge(node_ins)
# 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
info['_scan_merge_visited'] = True
new_outs = scan_op.Scan(inps
, outs
, info).make_node(*node_ins).outputs
......@@ -641,7 +673,7 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = compress_map[idx]
nw_pos = inv_compress_map[idx]
nw_out = new_outs[nw_pos]
......@@ -660,41 +692,42 @@ class ScanSaveMem(gof.Optimizer):
# 3.8. Get replace pairs for those outputs that change
# the number of stored intermediate steps
for pos, old_outs in old_outputs:
nw_pos = compress_map[pos]
nw_out = new_outs[nw_pos]
for k,old in enumerate(old_outs):
# Get the correct slice
cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice:
start = ( cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos] )
if ( cnf_slice[0].stop is not None and
cnf_slice[0].stop != sys.maxint ):
stop = ( cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos])
if len(old_outs) > 0:
nw_pos = compress_map[pos]
nw_out = new_outs[nw_pos]
for k,old in enumerate(old_outs):
# Get the correct slice
cnf_slice, old_slices = slices[pos][k]
if type(cnf_slice[0]) is slice:
start = ( cnf_slice[0].start - nw_steps -
init_l[pos] + store_steps[pos] )
if ( cnf_slice[0].stop is not None and
cnf_slice[0].stop != sys.maxint ):
stop = ( cnf_slice[0].stop - nw_steps -
init_l[pos] + store_steps[pos])
else:
stop = None
nw_slice = ( (slice(sanitize(start),
sanitize(stop),
sanitize(cnf_slice[0].step)),) +
tuple(old_slices[1:]) )
else:
stop = None
nw_slice = ( (slice(sanitize(start),
sanitize(stop),
sanitize(cnf_slice[0].step)),) +
tuple(old_slices[1:]) )
position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos] )
else:
position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos] )
nw_slice = (sanitize(position),) + tuple(old_slices[1:])
subtens = tensor.basic.Subtensor(nw_slice)
sl_ins = tensor.basic.Subtensor.collapse(
nw_slice
, lambda entry: isinstance(entry
, tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0]
if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]]
old_new += [(old, new_o)]
nw_slice = (sanitize(position),) + tuple(old_slices[1:])
subtens = tensor.basic.Subtensor(nw_slice)
sl_ins = tensor.basic.Subtensor.collapse(
nw_slice
, lambda entry: isinstance(entry
, tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos],
*sl_ins).outputs[0]
if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]]
old_new += [(old, new_o)]
# 3.9. Get replace pairs for all other nodes
if flag_store or global_nsteps is not None:
......@@ -707,11 +740,11 @@ class ScanSaveMem(gof.Optimizer):
def apply(self, env):
nodelist = list(env.toposort())
old_new = []
nodelist = [x for x in env.toposort() if isinstance(x.op,
scan_op.Scan)]
for node in nodelist:
op = node.op
if isinstance(op, scan_op.Scan):
if not hasattr(node.op, '_scan_merge_visited'):
self.process_node(env, node)
# Just before specialize to have the other optimization
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论