提交 aaed5273 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

improved test

The infershape mechanism differs for the two different type of outputs scan can have (those with no initial state and those with initial state)
上级 12422328
...@@ -2519,18 +2519,40 @@ class T_Scan(unittest.TestCase): ...@@ -2519,18 +2519,40 @@ class T_Scan(unittest.TestCase):
out = f(vx) out = f(vx)
assert out == 24 assert out == 24
def test_infershape_seq_shorter_nsteps(self):
raise KnownFailureTest('This is a generic problem with infershape'
' that has to be discussed and figured out')
x = tensor.vector('x')
[o1, o2], _ = theano.scan(lambda x,y: (x+1, y+x),
sequences = x,
outputs_info = [None, x[0]],
n_steps = 20)
f = theano.function([x], [o1.shape[0], o2.shape[0]], mode = mode_with_opt)
vx = numpy.ones((10,), dtype = theano.config.floatX)
out1, out2 = f(vx)
assert out1 == 10
assert out2 == 10
lssc = [x for x in f.maker.env.toposort()
if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 0
def test_infershape_nsteps_smaller_seq_length(self): def test_infershape_nsteps_smaller_seq_length(self):
x = tensor.vector('x') x = tensor.vector('x')
o, _ = theano.scan(lambda x: x+1, [o1, o2], _ = theano.scan(lambda x, y: (x+1, y+x),
sequences = x, sequences = x,
outputs_info = [None], outputs_info = [None, x[0]],
n_steps = 20) n_steps = 20)
f = theano.function([x], o.shape[0], mode = mode_with_opt) f = theano.function([x], [o1.shape[0], o2.shape[0]],
mode = mode_with_opt)
vx = numpy.ones((30,), dtype = theano.config.floatX) vx = numpy.ones((30,), dtype = theano.config.floatX)
out = f(vx) o1, o2 = f(vx)
assert out == 20 assert o1 == 20
assert o2 == 20
lssc = [x for x in f.maker.env.toposort() lssc = [x for x in f.maker.env.toposort()
if isinstance(x.op, theano.scan_module.scan_op.Scan)] if isinstance(x.op, theano.scan_module.scan_op.Scan)]
assert len(lssc) == 0 assert len(lssc) == 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论