提交 0bef350f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Misunderstood what expand did. New try.

Hoping we can trust `init_l[i]` to be the actual length of the initial state.
上级 efafdf1a
...@@ -1336,23 +1336,31 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1336,23 +1336,31 @@ class ScanSaveMem(gof.Optimizer):
if global_nsteps is not None: if global_nsteps is not None:
for idx, val in enumerate(store_steps[op.n_mit_mot:]): for idx, val in enumerate(store_steps[op.n_mit_mot:]):
if val == 0: if val == 0:
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op.n_mit_sot + op.n_sit_sot: if idx < op.n_mit_sot + op.n_sit_sot:
in_idx = offset + idx in_idx = offset + idx
# Number of steps in the initial state
initl = init_l[op.n_mit_mot + idx]
# If the initial buffer has the form # If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input) # inc_subtensor(zeros(...)[...], _nw_input)
# we want to make the zeros tensor as small as # we want to make the zeros tensor as small as
# possible (nw_steps), and call inc_subtensor # possible (nw_steps + initl), and call
# on that instead. # inc_subtensor on that instead.
# Otherwise, simply take elements 0:nw_steps. # Otherwise, simply take 0:(nw_steps+initl).
if ((nw_inputs[in_idx].owner and if ((nw_inputs[in_idx].owner and
isinstance(nw_inputs[in_idx].owner.op, isinstance(nw_inputs[in_idx].owner.op,
tensor.IncSubtensor))): tensor.IncSubtensor) and
isinstance(
nw_inputs[in_idx].owner.op.idx_list[0],
slice))):
_nw_input = nw_inputs[in_idx].owner.inputs[1] _nw_input = nw_inputs[in_idx].owner.inputs[1]
nw_input = scan_utils.expand(_nw_input, nw_steps) nw_input = scan_utils.expand(_nw_input,
nw_steps)
nw_inputs[in_idx] = nw_input nw_inputs[in_idx] = nw_input
else: else:
nw_input = nw_inputs[in_idx][:nw_steps] nw_input = nw_inputs[in_idx][:(initl+nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot: elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs in_idx = offset + idx + op.n_shared_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论