提交 29783b30 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed ScanSaveMem optimization to deal with the inputs being allocs, not

set_subtensor. This was the cause of the 3 failing tests in scan. I also revised the entire optimization, adding a few more comments and some documentation.
上级 5548fee1
......@@ -561,15 +561,37 @@ class ScanSaveMem(gof.Optimizer):
# If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot:
# In case the input is still an alloc node
if nw_inputs[offset+idx].owner:
# In case the input is still an alloc node, we
# actually have two options:
# a) the input is an alloc (due to an optimization
# that converts set_subtensor(0,0) in 0
# b) the input is an set subtensor
if ( nw_inputs[offset+idx].owner and
isinstance(nw_inputs[offset+idx].owner.op,
tensor.IncSubtensor)):
_nw_input = nw_inputs[offset+idx].owner.inputs[1]
nw_input = scan_utils.expand( _nw_input, val - init_l[i] )
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val - init_l[i]))
tmp = pre_constant_merge([tmp])[0]
nw_input = scan_utils.expand( _nw_input,tmp )
# If it is an alloc
elif ( nw_inputs[offset+idx].owner and
isinstance(nw_inputs[offset+idx].owner.op,
tensor.Alloc)):
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp]
# Else, if it was constant folded to a single value
elif isinstance(nw_inputs[offset+idx], tensor.Constant):
# The hope is that constant folding will fold
# this as well
nw_input = nw_inputs[offset+idx][:val]
tmp = pre_greedy_local_optimizer(list_opt_slice,
tensor.as_tensor_variable(val))
tmp = pre_constant_merge([tmp])[0]
nw_input = nw_inputs[offset+idx][:tmp]
else:
raise Exception(('Unforseen case. Please report'
' to theano-dev with an example'
......@@ -613,7 +635,17 @@ class ScanSaveMem(gof.Optimizer):
# 3.5 Remove unwanted orphane outputs
(inps, outs, info, node_ins, compress_map) = \
scan_utils.compress_outs(op, not_required, nw_inputs)
inv_compress_map = {}
for k,v in compress_map.items():
inv_compress_map[v] = k
node_ins = [ pre_greedy_local_optimizer(list_opt_slice, x) for x in
node_ins]
node_ins = pre_constant_merge(node_ins)
# 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
info['_scan_merge_visited'] = True
new_outs = scan_op.Scan(inps
, outs
, info).make_node(*node_ins).outputs
......@@ -641,7 +673,7 @@ class ScanSaveMem(gof.Optimizer):
nw_slice = (fslice,) + tuple(old_slices[1:])
nw_pos = compress_map[idx]
nw_pos = inv_compress_map[idx]
nw_out = new_outs[nw_pos]
......@@ -660,6 +692,7 @@ class ScanSaveMem(gof.Optimizer):
# 3.8. Get replace pairs for those outputs that change
# the number of stored intermediate steps
for pos, old_outs in old_outputs:
if len(old_outs) > 0:
nw_pos = compress_map[pos]
nw_out = new_outs[nw_pos]
for k,old in enumerate(old_outs):
......@@ -707,11 +740,11 @@ class ScanSaveMem(gof.Optimizer):
def apply(self, env):
nodelist = list(env.toposort())
old_new = []
nodelist = [x for x in env.toposort() if isinstance(x.op,
scan_op.Scan)]
for node in nodelist:
op = node.op
if isinstance(op, scan_op.Scan):
if not hasattr(node.op, '_scan_merge_visited'):
self.process_node(env, node)
# Just before specialize to have the other optimization
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论