提交 c4708d8e authored 作者: carriepl's avatar carriepl

Merge pull request #2 from lamblin/scan_crash_infer_shape

Misunderstood what expand did. New try.
......@@ -1336,23 +1336,31 @@ class ScanSaveMem(gof.Optimizer):
if global_nsteps is not None:
for idx, val in enumerate(store_steps[op.n_mit_mot:]):
if val == 0:
# val == 0 means that we want to keep all intermediate
# results for that state, including the initial values.
if idx < op.n_mit_sot + op.n_sit_sot:
in_idx = offset + idx
# Number of steps in the initial state
initl = init_l[op.n_mit_mot + idx]
# If the initial buffer has the form
# inc_subtensor(zeros(...)[...], _nw_input)
# we want to make the zeros tensor as small as
# possible (nw_steps), and call inc_subtensor
# on that instead.
# Otherwise, simply take elements 0:nw_steps.
# possible (nw_steps + initl), and call
# inc_subtensor on that instead.
# Otherwise, simply take 0:(nw_steps+initl).
if ((nw_inputs[in_idx].owner and
isinstance(nw_inputs[in_idx].owner.op,
tensor.IncSubtensor))):
tensor.IncSubtensor) and
isinstance(
nw_inputs[in_idx].owner.op.idx_list[0],
slice))):
_nw_input = nw_inputs[in_idx].owner.inputs[1]
nw_input = scan_utils.expand(_nw_input, nw_steps)
nw_input = scan_utils.expand(_nw_input,
nw_steps)
nw_inputs[in_idx] = nw_input
else:
nw_input = nw_inputs[in_idx][:nw_steps]
nw_input = nw_inputs[in_idx][:(initl+nw_steps)]
elif idx < op.n_mit_sot + op.n_sit_sot + op.n_nit_sot:
in_idx = offset + idx + op.n_shared_outs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论