提交 ed009e03 authored 作者: --global's avatar --global

Make test fail when optimization error is raised

上级 4c5afa2d
......@@ -2524,12 +2524,20 @@ class T_Scan(unittest.TestCase):
# recurrent output of a scan node instead of taking the result
# returned by the scan() function
# Obtain a compilation mode that will cause the test to fail if an
# exception occurs in the optimization process
on_opt_error = theano.config.on_opt_error
theano.config.on_opt_error = "raise"
mode = theano.compile.get_default_mode()
theano.config.on_opt_error = on_opt_error
x = tensor.scalar()
seq = tensor.vector()
outputs_info=[x, tensor.zeros_like(x)]
(out1, out2), updates = theano.scan(lambda a, b, c : (a + b, b + c),
sequences=seq,
outputs_info=outputs_info)
outputs_info=outputs_info,
mode=mode)
# Obtain a reference to the scan outputs before the subtensor and
# compile a function with them as outputs
......@@ -2537,7 +2545,9 @@ class T_Scan(unittest.TestCase):
assert isinstance(out2.owner.op, tensor.subtensor.Subtensor)
out1_direct = out1.owner.inputs[0]
out2_direct = out2.owner.inputs[0]
fct = theano.function([x, seq], [out1_direct[:-1], out2_direct[:-1]])
fct = theano.function([x, seq],
[out1_direct[:-1], out2_direct[:-1]],
mode=mode)
# Test the function to ensure valid outputs
floatX = theano.config.floatX
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论