提交 edcf97f3 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1566 from lamblin/fix_scan_linesearch

[WIP] Fix bug in remove_constants_and_unused_inputs_scan
......@@ -118,25 +118,26 @@ def remove_constants_and_unused_inputs_scan(node):
# Add outputs stuff
nw_inner += out_stuff_inner
nw_outer += out_stuff_outer
# Look through non sequences
nw_inner_nonseq = []
nw_outer_nonseq = []
for idx, (nw_in, nw_out) in enumerate(zip(non_seqs, outer_non_seqs)):
if isinstance(nw_out, tensor.Constant):
givens[nw_in] = nw_out.clone()
elif nw_in in all_ins:
identical_non_seqs = [x for x in nw_outer
if scan_utils.equal_computations(
[x], [nw_out])]
if identical_non_seqs:
identical_idx = outer_non_seqs.index(identical_non_seqs[0])
# If we have identical non sequences, the previous one
# must be in nw_inner or be a constant.
assert (non_seqs[identical_idx] in nw_inner or
isinstance(identical_non_seqs[0], tensor.Constant))
index = outer_non_seqs.index(identical_non_seqs[0])
givens[nw_in] = non_seqs[index]
# Indices of elements of nw_outer_nonseq that are equivalent
# to nw_out.
identical_nonseq_idx = [
i for (i, x) in enumerate(nw_outer_nonseq)
if scan_utils.equal_computations([x], [nw_out])]
if identical_nonseq_idx:
givens[nw_in] = nw_inner_nonseq[identical_nonseq_idx[0]]
else:
nw_inner += [nw_in]
nw_outer += [nw_out]
nw_inner_nonseq += [nw_in]
nw_outer_nonseq += [nw_out]
nw_inner.extend(nw_inner_nonseq)
nw_outer.extend(nw_outer_nonseq)
if len(nw_inner) != len(op_ins):
op_outs = scan_utils.clone(op_outs, replace=givens)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论