提交 00cd33ba authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Fixed a bug in the optimization. The optimization was expecting the inputs

of scan to be in a certain form, which is not true anymore if constant folding replaces them by a constant.
上级 573650de
...@@ -367,8 +367,21 @@ class ScanSaveMem(gof.Optimizer): ...@@ -367,8 +367,21 @@ class ScanSaveMem(gof.Optimizer):
# If the memory for this output has been pre-allocated # If the memory for this output has been pre-allocated
# before going into the scan op (by an alloc node) # before going into the scan op (by an alloc node)
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op.n_mit_sot + op.n_sit_sot:
_nw_input = nw_inputs[offset+idx].owner.inputs[1] # In case the input is still an alloc node
nw_input = scan_utils.expand( _nw_input, val - init_l[i] ) if nw_inputs[offset+idx].owner:
_nw_input = nw_inputs[offset+idx].owner.inputs[1]
nw_input = scan_utils.expand( _nw_input, val - init_l[i] )
# Else, if it was constant folded to a single value
elif isinstance(nw_inputs[offset+idx], tensor.Constant):
# The hope is that constant folding will fold
# this as well
nw_input = nw_inputs[offset+idx][:val]
else:
raise Exception(('Unforseen case. Please report'
' to theano-dev with an example'
' script for this case to be'
' debuged'))
nw_inputs[offset+idx] = nw_input nw_inputs[offset+idx] = nw_input
replaced_outs.append(op.n_mit_mot + idx) replaced_outs.append(op.n_mit_mot + idx)
odx = op.n_mit_mot + idx odx = op.n_mit_mot + idx
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论