提交 d0a9a4cd authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix bug in scan merge optimization reported by Yao Li.

上级 d2947cae
......@@ -1217,8 +1217,13 @@ class ScanMerge(gof.Optimizer):
belongs_to_set_idx = -1
for pos, subset in enumerate(all_sets):
if self.belongs_to_set(nd, subset):
assert belongs_to_set_idx == -1
belongs_to_set_idx = pos
# It is possible that nd belongs to more than one subset.
# For instance, if we have 3 Scan nodes X, Y and Z, if Z
# depends on the output of X, then X and Z are incompatible
# and would create different subsets, but Y could be
# compatible with both X and Z. We choose the first one.
break
if belongs_to_set_idx == -1:
all_sets.append([nd])
......
......@@ -2428,6 +2428,36 @@ class T_Scan(unittest.TestCase):
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 2)
def test_merge_3scans(self):
# This test checks a case where we have 3 scans, two of them
# cannot be merged together, but the third one can be merged with
# either.
x = theano.tensor.vector()
y = theano.tensor.vector()
def sum(s):
return s + 1
sx, upx = theano.scan(sum, sequences=[x], n_steps=4, name='X')
# We need to use an expression of y rather than y so the toposort
# comes up with the 'Y' scan last.
sy, upy = theano.scan(sum, sequences=[2 * y + 2], n_steps=4, name='Y')
sz, upz = theano.scan(sum, sequences=[sx], n_steps=4, name='Z')
f = theano.function(
[x, y], [sy, sz],
mode=mode_with_opt.excluding('scanOp_pushout_seqs_ops'))
topo = f.maker.fgraph.toposort()
scans = [n for n in topo if isinstance(
n.op, theano.scan_module.scan_op.Scan)]
self.assertTrue(len(scans) == 2)
rng = numpy.random.RandomState(utt.fetch_seed())
x_val = rng.uniform(size=(4,)).astype(theano.config.floatX)
y_val = rng.uniform(size=(4,)).astype(theano.config.floatX)
# Run it so DebugMode can detect optimization problems.
f(x_val, y_val)
def test_hash(self):
x = theano.tensor.vector()
y = theano.tensor.vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论