提交 9b23a287 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Test for the bug regarding grad of states with multiple taps.

上级 38d12f58
......@@ -2677,6 +2677,39 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(f(vx, vA), vR)
def test_grad_multiple_taps_state(self):
# The test is based on the code provided by Timothy Lillicrap
def onestep(xdl,xprev, w):
xnew = tensor.tanh(tensor.dot(w,xprev))
return xnew
xinit = tensor.tensor3('xinit')
w = tensor.matrix('w')
(xseq, updates) = theano.scan(n_steps = 10,
fn = onestep,
outputs_info = [dict(initial = xinit, taps=[-4,-1])],
non_sequences = w)
loss = (xseq[-1]**2).sum()
cost_fn = theano.function([xinit, w],
loss,
no_default_updates = True,
allow_input_downcast = True)
gw, gx = tensor.grad(loss, [w, xinit])
grad_fn = theano.function([xinit, w], [gx,gw],
allow_input_downcast = True)
rng = numpy.random.RandomState(utt.fetch_seed())
v_x = numpy.array(rng.uniform(size=(5,2,3), low=-.5, high=.5),
dtype=theano.config.floatX)
v_w = numpy.array(rng.uniform(size=(2,2)), dtype= theano.config.floatX)
analytic_grad = grad_fn(v_x, v_w)
num_grad = multiple_outputs_numeric_grad(cost_fn,
[v_x, v_w])
max_err, max_err_pos = num_grad.max_err(analytic_grad)
if max_err > 1e-2:
raise Exception(theano.tensor.verify_grad.E_grad,
(max_err, 1e-2, max_err_pos))
def test_speed():
#
# This function prints out the speed of very simple recurrent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论