提交 5a70f9ae authored 作者: carriepl's avatar carriepl

Merge pull request #3221 from carriepl/scan_opt_bug

[ENH] Scan PushOutSeqScan
...@@ -116,6 +116,10 @@ def change_flags(**kwargs): ...@@ -116,6 +116,10 @@ def change_flags(**kwargs):
if v.fullname == k] if v.fullname == k]
assert len(l) == 1 assert len(l) == 1
l[0].__set__(None, old_val[k]) l[0].__set__(None, old_val[k])
# Make sure that the name of the decorated function remains the same.
inner.__name__ = f.__name__
return inner return inner
return change_flags_exec return change_flags_exec
......
...@@ -472,13 +472,12 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -472,13 +472,12 @@ class PushOutSeqScan(gof.Optimizer):
for nd in local_fgraph_topo: for nd in local_fgraph_topo:
if (nd not in to_remove_set and if (nd not in to_remove_set and
all([(x in inner_non_seqs_set) or all([(x in inner_non_seqs_set) or
(x.owner in to_remove_set) or (x.owner in to_remove_set) or
isinstance(x, tensor.Constant) or isinstance(x, tensor.Constant) or
(x in inner_seqs_set) for x in nd.inputs]) and (x in inner_seqs_set) for x in nd.inputs]) and
isinstance(nd.op, theano.tensor.Elemwise)): isinstance(nd.op, theano.tensor.Elemwise)):
to_remove_set.add(nd)
outside_ins = [] outside_ins = []
depends_on_seqs = False depends_on_seqs = False
...@@ -511,6 +510,8 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -511,6 +510,8 @@ class PushOutSeqScan(gof.Optimizer):
# scan. # scan.
continue continue
to_remove_set.add(nd)
# Do not call make_node for test_value # Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins, nw_outer_node = nd.op(*outside_ins,
**dict(return_list=True))[0].owner **dict(return_list=True))[0].owner
......
...@@ -2696,6 +2696,23 @@ class T_Scan(unittest.TestCase): ...@@ -2696,6 +2696,23 @@ class T_Scan(unittest.TestCase):
utt.assert_allclose(expected_output, scan_output) utt.assert_allclose(expected_output, scan_output)
utt.assert_allclose(expected_output, jacobian_outputs) utt.assert_allclose(expected_output, jacobian_outputs)
@theano.configparser.change_flags(on_opt_error='raise')
def test_pushout_seqs2(self):
# This test for a bug with PushOutSeqScan that was reported on the
# theano-user mailing list where the optimization raised an exception
# when applied on this graph.
x = tensor.matrix()
outputs, updates = theano.scan(
lambda x: [x*x, tensor.constant(0).copy().copy()],
n_steps=2,
sequences=[],
non_sequences=[],
outputs_info=[x, None])
# Compile a theano function where any optimization error will lead to
# an exception being raised
theano.function([x], outputs, updates=updates)
def test_sequence_dict(self): def test_sequence_dict(self):
# Test that we can specify sequences as a dictionary with # Test that we can specify sequences as a dictionary with
# only the 'input' key # only the 'input' key
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论