提交 0c358fb1 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

refractor test to cover more cases

上级 2bce61d0
...@@ -25,59 +25,33 @@ class TestScan(unittest.TestCase): ...@@ -25,59 +25,33 @@ class TestScan(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
def test001_scan_removed_nsteps_1(self): def test001_generator_one_scalar_output(self):
def f_pow2(x_tm1): def f_pow2(x_tm1):
return 2 * x_tm1 return 2 * x_tm1
for n_steps in [-1,1, 5, -5]:
state = theano.tensor.scalar('state') state = theano.tensor.scalar('state')
output, updates = scan_module.scan(f_pow2, output, updates = scan_module.scan(f_pow2,
[], [],
state, state,
[], [],
n_steps=1, n_steps=n_steps,
truncate_gradient=-1, truncate_gradient=-1,
go_backwards=False) go_backwards=False)
my_f = theano.function([state],
my_f = theano.function([state, n_steps],
output, output,
updates=updates, updates=updates,
allow_input_downcast=True) allow_input_downcast=True)
if abs(n_steps) == 1:
assert len([x for x in my_f.maker.env.toposort() assert len([x for x in my_f.maker.env.toposort()
if isinstance(x.op, scan_module.scan_op.ScanOp)]) if isinstance(x.op, scan_module.scan_op.ScanOp)]) == 0
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
state = rng.uniform() state = rng.uniform()
steps = 1
numpy_values = numpy.array([state * (2 ** (k + 1)) for k numpy_values = numpy.array([state * (2 ** (k + 1)) for k
in xrange(steps)]) in xrange(abs(n_steps))])
theano_values = my_f(state, steps) theano_values = my_f(state)
assert numpy.allclose(numpy_values, theano_values) assert numpy.allclose(numpy_values, theano_values)
# generator network, only one output , type scalar ; no sequence or
# non sequence arguments
def test002_generator_one_scalar_output(self):
def f_pow2(x_tm1):
return 2 * x_tm1
state = theano.tensor.scalar('state')
n_steps = theano.tensor.iscalar('nsteps')
output, updates = scan_module.scan(f_pow2,
[],
state,
[],
n_steps=n_steps,
truncate_gradient=-1,
go_backwards=False)
my_f = theano.function([state, n_steps],
output,
updates=updates,
allow_input_downcast=True)
rng = numpy.random.RandomState(utt.fetch_seed())
state = rng.uniform()
steps = 5
numpy_values = numpy.array([state * (2 ** (k + 1)) for k
in xrange(steps)])
theano_values = my_f(state, steps)
assert numpy.allclose(numpy_values, theano_values)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论