提交 1dff4010 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Test for infershape when using scan as a repeat until

上级 a2480e41
...@@ -2515,6 +2515,17 @@ class T_Scan(unittest.TestCase): ...@@ -2515,6 +2515,17 @@ class T_Scan(unittest.TestCase):
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 1 assert len(lssc) == 1
def test_while_infershape(self):
x = tensor.vector('x')
def lambda_fn(x_t):
return x_t + 1, theano.scan_module.until( x_t > 3)
o, _ = theano.scan(lambda_fn, x)
f = theano.function([x], o.shape[0], mode=mode_with_opt)
vx = numpy.zeros((50,), dtype = theano.config.floatX)
vx[23] = 4
out = f(vx)
assert out == 24
def test_grad_multiple_seqs_different_nsteps(self): def test_grad_multiple_seqs_different_nsteps(self):
# Example provided Michael Forbes # Example provided Michael Forbes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论