提交 140d61b7 · 作者: nouiz

Merge pull request #315 from pascanur/scan_grax_fixed

Scan grad computation fixed
......@@ -1289,7 +1289,7 @@ class Scan(PureOp):
for idx in xrange(self.n_mit_mot + self.n_mit_sot):
mintap = numpy.min(self.tap_array[idx])
maxtap = numpy.max(self.tap_array[idx])
seq = scan_outputs[offset+idx][::-1]
seq = scan_outputs[offset+idx]
for k in self.tap_array[idx]:
# We cut the sequence such that seq[i] to correspond to
# seq[i-k]
......@@ -1300,9 +1300,9 @@ class Scan(PureOp):
if maxtap == mintap and maxtap != 0:
nw_seq =seq[:abs(maxtap)]
elif maxtap -k != 0 :
nw_seq = seq[dim_offset +k -mintap: -(maxtap -k)]
nw_seq = seq[dim_offset +k -mintap - 1: -(maxtap -k + 1)][::-1]
else:
nw_seq = seq[dim_offset +k -mintap: ]
nw_seq = seq[dim_offset +k -mintap - 1: -1 ][::-1]
if getattr(seq,'name', None) is not None:
nw_seq.name = seq.name + '[%d:]'%k
scan_seqs.append(nw_seq)
......
......@@ -2677,6 +2677,39 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(f(vx, vA), vR)
def test_grad_multiple_taps_state(self):
    # Regression test: gradient of scan when one recurrent output state
    # is consumed through multiple taps ([-4, -1]).  Compares scan's
    # analytic gradient against a numeric finite-difference gradient.
    # The test is based on the code provided by Timothy Lillicrap
    def onestep(xdl,xprev, w):
        # One recurrence step.  With taps=[-4, -1] below, scan passes the
        # tap values in that order, so xdl is presumably the t-4 state and
        # xprev the t-1 state -- confirm against scan's tap ordering.
        # Only xprev feeds the new state; the unused older tap exercises
        # the multi-tap path of scan's gradient computation.
        xnew = tensor.tanh(tensor.dot(w,xprev))
        return xnew

    # Symbolic inputs: xinit carries the initial history required by the
    # deepest tap (first axis = time), w is the recurrent weight matrix.
    xinit = tensor.tensor3('xinit')
    w = tensor.matrix('w')
    (xseq, updates) = theano.scan(n_steps = 10,
                                  fn = onestep,
                                  outputs_info = [dict(initial = xinit, taps=[-4,-1])],
                                  non_sequences = w)
    # Scalar loss so a single gradient can be taken and checked.
    loss = (xseq[-1]**2).sum()
    cost_fn = theano.function([xinit, w],
                              loss,
                              no_default_updates = True,
                              allow_input_downcast = True)
    gw, gx = tensor.grad(loss, [w, xinit])
    grad_fn = theano.function([xinit, w], [gx,gw],
                              allow_input_downcast = True)
    # Deterministic test data from the shared unittest seed.
    rng = numpy.random.RandomState(utt.fetch_seed())
    v_x = numpy.array(rng.uniform(size=(5,2,3), low=-.5, high=.5),
                      dtype=theano.config.floatX)
    v_w = numpy.array(rng.uniform(size=(2,2)), dtype= theano.config.floatX)
    analytic_grad = grad_fn(v_x, v_w)
    num_grad = multiple_outputs_numeric_grad(cost_fn,
                                             [v_x, v_w])
    # max_err is the largest relative discrepancy between the numeric and
    # analytic gradients; fail loudly if it exceeds the 1e-2 tolerance.
    max_err, max_err_pos = num_grad.max_err(analytic_grad)
    if max_err > 1e-2:
        raise Exception(theano.tensor.verify_grad.E_grad,
                        (max_err, 1e-2, max_err_pos))
def test_speed():
#
# This function prints out the speed of very simple recurrent
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论