提交 2abc014b authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fix formatting of the new introduced test

上级 fc22aaf9
...@@ -2489,18 +2489,17 @@ class T_Scan(unittest.TestCase): ...@@ -2489,18 +2489,17 @@ class T_Scan(unittest.TestCase):
# get the right values in # get the right values in
c = theano.tensor.vector('c') c = theano.tensor.vector('c')
x = theano.tensor.scalar('x') x = theano.tensor.scalar('x')
_max_coefficients_supported = 100 _max_coefficients_supported = 1000
full_range = theano.tensor.arange(_max_coefficients_supported) full_range = theano.tensor.arange(_max_coefficients_supported)
components, updates = theano.scan(fn=lambda coeff, power, components, updates = theano.scan(
free_var: fn=lambda coeff, power, free_var: coeff * (free_var ** power),
coeff * (free_var ** power),
outputs_info=None, outputs_info=None,
sequences=[c, full_range], sequences=[c, full_range],
non_sequences=x) non_sequences=x)
P = components.sum() P = components.sum()
dP = theano.tensor.grad(P, x) dP = theano.tensor.grad(P, x)
tf = theano.function([c,x], dP) tf = theano.function([c, x], dP)
assert tf([1.0,2.0,-3.0,4.0], 2.0) == 38 assert tf([1.0, 2.0, -3.0, 4.0], 2.0) == 38
def test_return_steps(self): def test_return_steps(self):
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论