提交 7576ab6d authored 作者: abergeron's avatar abergeron

Merge pull request #3440 from carriepl/scan_push_out_seq_scan

Scan push out seq scan
...@@ -397,10 +397,24 @@ class PushOutNonSeqScan(gof.Optimizer): ...@@ -397,10 +397,24 @@ class PushOutNonSeqScan(gof.Optimizer):
# because the scan op expects for a tensor3, to which an # because the scan op expects for a tensor3, to which an
# subtensor is applied that takes only the last element # subtensor is applied that takes only the last element
if replace_with: if replace_with:
if len(node.outputs) == len(replace_with):
# Every output of the node has a replacement, the Scan
# node can be removed from the graph
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
replace_with.items(), replace_with.items(),
remove=[node], remove=[node],
reason='scanOp_pushout_nonseqs_ops') reason='scanOp_pushout_nonseqs_ops')
else:
# The node has some outputs for which no replacement has
# been established. This can occur for outputs that are
# not produced by apply nodes (since the optimizations
# only visits apply nodes) such as constants or inputs
# passed directly as outputs. The replacements can be
# performed but the Scan node can't be removed at this
# point.
fgraph.replace_all_validate(
replace_with.items(),
reason='scanOp_pushout_nonseqs_ops')
else: else:
return False return False
......
...@@ -2771,6 +2771,22 @@ class T_Scan(unittest.TestCase): ...@@ -2771,6 +2771,22 @@ class T_Scan(unittest.TestCase):
# an exception being raised # an exception being raised
theano.function([x], outputs, updates=updates) theano.function([x], outputs, updates=updates)
@theano.configparser.change_flags(on_opt_error='raise')
def test_pushout_nonseq(self):
# Test case originally reported by Daniel Renshaw. The crashed occured
# during the optimization PushOutNonSeqScan when it attempted to
# a scan node with two outputs but only providing a replacement for
# one of those outputs. This led the optimization to raise an
# exception.
outputs, _ = theano.scan(lambda x: (x * x, x),
non_sequences=[2], n_steps=2)
f = theano.function(inputs=[], outputs=outputs)
outs = f()
expected_outs = [[4, 4], [2, 2]]
utt.assert_allclose(outs, expected_outs)
def test_sequence_dict(self): def test_sequence_dict(self):
# Test that we can specify sequences as a dictionary with # Test that we can specify sequences as a dictionary with
# only the 'input' key # only the 'input' key
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论