提交 3674d24c authored 作者: Razvan Pascanu's avatar Razvan Pascanu

changes to the test

上级 f4bdfea4
...@@ -25,9 +25,37 @@ class TestScan(unittest.TestCase): ...@@ -25,9 +25,37 @@ class TestScan(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
def test001_scan_removed_nsteps_1(self):
def f_pow2(x_tm1):
return 2 * x_tm1
state = theano.tensor.scalar('state')
output, updates = scan_module.scan(f_pow2,
[],
state,
[],
n_steps=1,
truncate_gradient=-1,
go_backwards=False)
my_f = theano.function([state, n_steps],
output,
updates=updates,
allow_input_downcast=True)
assert len([x for x in my_f.maker.env.toposort()
if isinstance(x.op, scan_module.scan_op.ScanOp)])
rng = numpy.random.RandomState(utt.fetch_seed())
state = rng.uniform()
steps = 1
numpy_values = numpy.array([state * (2 ** (k + 1)) for k
in xrange(steps)])
theano_values = my_f(state, steps)
assert numpy.allclose(numpy_values, theano_values)
# generator network, only one output , type scalar ; no sequence or # generator network, only one output , type scalar ; no sequence or
# non sequence arguments # non sequence arguments
def test001_generator_one_scalar_output(self): def test002_generator_one_scalar_output(self):
def f_pow2(x_tm1): def f_pow2(x_tm1):
return 2 * x_tm1 return 2 * x_tm1
state = theano.tensor.scalar('state') state = theano.tensor.scalar('state')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论