提交 7e777e65 authored 作者: carriepl's avatar carriepl

Add advanced test for scan.infer_shape

上级 50c3fd00
......@@ -2549,6 +2549,43 @@ class T_Scan(unittest.TestCase):
output, g_output = fct(i)
assert len(output) == g_output
def test_infer_shape2(self):
# Ensure that the shape inference can remove the Scan node in the
# case of a complicated inner graph involving sequences and recurrent
# states
seq = tensor.lvector()
sitsot_init = tensor.lscalar()
mitsot_init = tensor.lvector()
def step(seq1, sitsot_m1, mitsot_m2, mitsot_m1):
# Every iteration, the sitsot state decreases and the mitsot state
# increases such that their total value remains identical. This
# is because this value will be used as the shape of a nitsot
# output and the outputs of every iteration need to have the same
# shape
diff = mitsot_m1 + seq1
next_mitsot_val = mitsot_m2 + diff
next_sitsot_val = sitsot_m1 - diff
nitsot_out = tensor.AllocEmpty('float32')(next_mitsot_val +
next_sitsot_val)
return next_sitsot_val, next_mitsot_val, nitsot_out
out, updates = theano.scan(fn=step,
sequences=seq,
outputs_info=[sitsot_init,
{'initial' : mitsot_init,
'taps' : [-2, -1]},
None],
n_steps=5)
f = theano.function([seq, sitsot_init, mitsot_init], out[2].shape)
assert(len(scan_nodes_from_fct(f)) == 0)
output_shape = f(numpy.arange(5), 5, [1, 2])
assert(all(output_shape == (5,6)))
# The following test will fail in DebugMode if there are
# some problems in Scan.infer_shape
def test_remove_stuff(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论