提交 982144a0 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add checks for correctness of the function output

上级 f221a38e
......@@ -2528,15 +2528,19 @@ class T_Scan(unittest.TestCase):
def inner_fct(previous_val):
new_val = previous_val + srng.uniform()
condition = theano.scan_module.until(previous_val > 50)
condition = theano.scan_module.until(previous_val > 5)
return new_val, condition
out, updates = theano.scan(inner_fct,
outputs_info=x,
n_steps=100)
n_steps=10)
g_out = tensor.grad(out.sum(), x)
fct = theano.function([x], out)
fct = theano.function([x], [out, g_out])
for i in xrange(-5, 5):
output, g_output = fct(i)
assert len(output) == g_output
# The following test will fail in DebugMode if there are
# some problems in Scan.infer_shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论