提交 08d73bae authored 作者: Razvan Pascanu's avatar Razvan Pascanu

two new tests

Conflicts: theano/scan_module/tests/test_scan.py
上级 7c6efc3f
...@@ -3306,6 +3306,70 @@ class T_Scan(unittest.TestCase): ...@@ -3306,6 +3306,70 @@ class T_Scan(unittest.TestCase):
theano.scan_module.scan_op.Scan)] theano.scan_module.scan_op.Scan)]
assert len(scan_nodes) == 1 assert len(scan_nodes) == 1
def test_eliminate_seqs(self):
U = tensor.vector('U')
sh = theano.shared(asarrayX(2.))
x1 = tensor.vector('x1')
x2 = tensor.scalar('x2')
def rec_fn(*args):
u_t = args[0]
return [(u_t + 1, # mitsot
u_t + 2, # sitsot
u_t + 3), # nitsot
{sh: u_t + 4}] # shared
[X1, X2, X3], updates = theano.scan(
rec_fn,
U,
[dict(initial=x1, taps=[-1, -3]), x2, None],
n_steps=None,
truncate_gradient=-1,
go_backwards=False)
f = theano.function([U, x1, x2], [X1, X2, X3],
updates=updates,
mode=theano.Mode(linker='py'),
allow_input_downcast=True)
rng = numpy.random.RandomState(utt.fetch_seed())
v_u = asarrayX(rng.uniform(size=(5,)))
outs = f(v_u, [0, 0, 0], 0)
assert numpy.allclose(outs[0], v_u + 1)
assert numpy.allclose(outs[1], v_u + 2)
assert numpy.allclose(outs[2], v_u + 3)
assert numpy.allclose(sh.get_value(), v_u[-1] + 4)
def test_eliminate_nonseqs(self):
W = tensor.scalar('W')
sh = theano.shared(asarrayX(2.))
x1 = tensor.vector('x1')
x2 = tensor.scalar('x2')
def rec_fn(*args):
w = args[-1]
return [(w + 1., # mitsot
w + 2., # sitsot
w + 3.), # nitsot
{sh: w + 4.}] # shared
[X1, X2, X3], updates = theano.scan(
rec_fn,
[],
[dict(initial=x1, taps=[-1, -3]), x2, None],
W,
n_steps=5,
truncate_gradient=-1,
go_backwards=False)
f = theano.function([W, x1, x2], [X1, X2, X3],
updates=updates,
mode=theano.Mode(linker='py'),
allow_input_downcast=True)
rng = numpy.random.RandomState(utt.fetch_seed())
v_w = asarrayX(rng.uniform())
outs = f(v_w, [0, 0, 0], 0)
assert numpy.allclose(outs[0], v_w + 1)
assert numpy.allclose(outs[1], v_w + 2)
assert numpy.allclose(outs[2], v_w + 3)
assert numpy.allclose(sh.get_value(), v_w + 4)
def test_speed(): def test_speed():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论