提交 982144a0 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier

Add checks for correctness of the function output

上级 f221a38e
...@@ -2528,15 +2528,19 @@ class T_Scan(unittest.TestCase): ...@@ -2528,15 +2528,19 @@ class T_Scan(unittest.TestCase):
def inner_fct(previous_val): def inner_fct(previous_val):
new_val = previous_val + srng.uniform() new_val = previous_val + srng.uniform()
condition = theano.scan_module.until(previous_val > 50) condition = theano.scan_module.until(previous_val > 5)
return new_val, condition return new_val, condition
out, updates = theano.scan(inner_fct, out, updates = theano.scan(inner_fct,
outputs_info=x, outputs_info=x,
n_steps=100) n_steps=10)
g_out = tensor.grad(out.sum(), x) g_out = tensor.grad(out.sum(), x)
fct = theano.function([x], out) fct = theano.function([x], [out, g_out])
for i in xrange(-5, 5):
output, g_output = fct(i)
assert len(output) == g_output
# The following test will fail in DebugMode if there are # The following test will fail in DebugMode if there are
# some problems in Scan.infer_shape # some problems in Scan.infer_shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论