提交 26a71eb7 authored 作者: Frederic's avatar Frederic

Fix extra scan discovered by example in gh-3663. This revert 50c3fd00.

上级 d3e30815
......@@ -1607,28 +1607,32 @@ class Scan(PureOp):
inner_ins_shapes = []
out_equivalent = OrderedDict()
# The two following blocks are commented as it cause in some
# cases extra scans in the graph. See gh-XXX for the
# investigation.
# We skip the first outer input as it is the total or current number
# of iterations.
# sequences
seqs_shape = [x[1:] for x in input_shapes[1:1 + self.n_seqs]]
inner_seqs = self.inputs[:self.n_seqs]
outer_seqs = node.inputs[1:1 + self.n_seqs]
for in_s, out_s in izip(inner_seqs, outer_seqs):
out_equivalent[in_s] = out_s[0]
# inner_seqs = self.inputs[:self.n_seqs]
# outer_seqs = node.inputs[1:1 + self.n_seqs]
# for in_s, out_s in izip(inner_seqs, outer_seqs):
# out_equivalent[in_s] = out_s[0]
# mit_mot, mit_sot, sit_sot
outer_inp_idx = 1 + self.n_seqs
inner_inp_idx = self.n_seqs
# outer_inp_idx = 1 + self.n_seqs
# inner_inp_idx = self.n_seqs
n_outs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outs_shape = []
for idx in xrange(n_outs):
mintap = abs(min(self.tap_array[idx]))
# mintap = abs(min(self.tap_array[idx]))
for k in self.tap_array[idx]:
outs_shape += [input_shapes[idx + self.n_seqs + 1][1:]]
corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
inner_inp_idx += 1
outer_inp_idx += 1
# corresponding_tap = node.inputs[outer_inp_idx][mintap + k]
# out_equivalent[self.inputs[inner_inp_idx]] = corresponding_tap
# inner_inp_idx += 1
# outer_inp_idx += 1
# shared_outs
offset = 1 + self.n_seqs + n_outs
......
......@@ -212,7 +212,9 @@ class TestPushOutScanOutputDot(object):
# not be the result of a Dot
scan_node = [node for node in f_opt.maker.fgraph.toposort()
if isinstance(node.op, Scan)][0]
assert len(scan_node.op.outputs) == 1
# NOTE: WHEN INFER_SHAPE IS REENABLED, BELLOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.outputs) == 2
assert not isinstance(scan_node.op.outputs[0], T.Dot)
# Ensure that the function compiled with the optimization produces
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论