提交 f69a07ad 作者: Pascal Lamblin

Fix gradient of scan when some outputs are disconnected

Add test case that was reported on theano-users.
上级 370953c5
...@@ -1720,9 +1720,12 @@ class Scan(PureOp): ...@@ -1720,9 +1720,12 @@ class Scan(PureOp):
offset = self.n_mit_mot offset = self.n_mit_mot
for idx in xrange(self.n_mit_sot): for idx in xrange(self.n_mit_sot):
if isinstance(dC_douts[idx + offset].type, DisconnectedType):
outer_inp_mitmot.append(outs[idx + offset].zeros_like())
else:
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
mitmot_inp_taps.append([]) mitmot_inp_taps.append([])
mitmot_out_taps.append([]) mitmot_out_taps.append([])
outer_inp_mitmot.append(dC_douts[idx + offset][::-1])
idx_tap = idx + self.n_mit_mot idx_tap = idx + self.n_mit_mot
inner_inp_mitmot.append(dC_dXts[out_pos]) inner_inp_mitmot.append(dC_dXts[out_pos])
out_pos += 1 out_pos += 1
......
...@@ -1852,6 +1852,80 @@ class T_Scan(unittest.TestCase): ...@@ -1852,6 +1852,80 @@ class T_Scan(unittest.TestCase):
analytic_grad = reset_rng_grad_fn(v_u, v_x0, vW_in) analytic_grad = reset_rng_grad_fn(v_u, v_x0, vW_in)
utt.assert_allclose(analytic_grad[0][:2], numpy.zeros((2, 2))) utt.assert_allclose(analytic_grad[0][:2], numpy.zeros((2, 2)))
def test_grad_multiple_outs_some_disconnected(self):
    """Train a small scan-based RNN whose cost uses only one scan output.

    The scan returns two outputs, ``h`` and ``y``, but the cost is built
    from ``y`` alone, so no gradient flows to the cost directly through
    the ``h`` output.  The test takes the gradient of that cost with
    respect to all parameters, runs 100 gradient-descent updates and
    checks that the final cost is below 0.02.
    """
    # Created on Tue Oct 07 13:28:51 2014
    # @author: vaneetke
    rng = numpy.random.RandomState(utt.fetch_seed())
    n_hid = 3
    n_in = 1
    n_out = 1

    # Small random initial values for every parameter.
    W_hh_v = asarrayX(rng.uniform(size=(n_hid, n_hid), low=-.01, high=.01))
    h0_v = asarrayX(rng.uniform(size=(2, n_hid), low=-.01, high=.01))
    b_h_v = asarrayX(rng.uniform(size=(n_hid), low=-.01, high=.01))
    W_ih_v = asarrayX(rng.uniform(size=(n_in, n_hid), low=-.01, high=.01))
    W_ho_v = asarrayX(rng.uniform(size=(n_hid, n_out), low=-.01, high=.01))
    b_o_v = asarrayX(rng.uniform(size=(n_out), low=-.01, high=.01))

    # parameters of the rnn
    b_h = theano.shared(b_h_v)
    h0 = theano.shared(h0_v)
    W_ih = theano.shared(W_ih_v)
    W_hh = theano.shared(W_hh_v)
    W_ho = theano.shared(W_ho_v)
    b_o = theano.shared(b_o_v)
    params = [W_ih, W_hh, b_h, W_ho, b_o, h0]

    # first dimension is time
    x = tensor.matrix()

    # sequences: x_t
    # prior results: h_tm2, h_tm1
    # non-sequences: W_ih, W_hh, b_h, W_ho, b_o
    def one_step(x_t, h_tm2, h_tm1, W_ih, W_hh, b_h, W_ho, b_o):
        # Only the tap at t-2 enters the recurrence; h_tm1 is received
        # (because taps=[-2, -1] below) but not used.
        h_t = tensor.tanh(theano.dot(x_t, W_ih)
                          + theano.dot(h_tm2, W_hh) + b_h)
        y_t = theano.dot(h_t, W_ho) + b_o
        return [h_t, y_t]

    # hidden and outputs of the entire sequence
    [h, y], _ = theano.scan(
        fn=one_step,
        sequences=dict(input=x),
        # corresponds to the return type of one_step
        outputs_info=[dict(initial=h0, taps=[-2, -1]), None],
        non_sequences=[W_ih, W_hh, b_h, W_ho, b_o])

    # target values
    t = tensor.matrix()

    # learning rate
    lr = asarrayX(0.1)
    learning_rate = theano.shared(lr)

    # The cost is a function of y (and the targets) only; h is not used.
    cost = ((0.5 * ((y - t) ** 2.0).mean())
            + (0.5 * (y.std() - t.std()) ** 2.0))

    # Use the `tensor` name, like every other tensor call in this test.
    gparams = tensor.grad(cost, params)
    updates = [(param, param - gparam * learning_rate)
               for param, gparam in zip(params, gparams)]
    learn_rnn_fn = theano.function(inputs=[x, t],
                                   outputs=cost,
                                   updates=updates)
    eval_rnn_fn = theano.function(inputs=[x],
                                  outputs=y)

    # artificial data: predict the next sample of a sine wave
    x_v = numpy.arange(0., 100., 0.21, dtype=theano.config.floatX)
    x_v = x_v.reshape(len(x_v), 1)
    s_v = numpy.sin(x_v)
    t_v = numpy.roll(s_v, -1)[:-1]
    s_v = s_v[:-1]

    for i in xrange(100):
        # Numeric cost of this step; kept separate from the symbolic `cost`.
        cost_v = learn_rnn_fn(s_v, t_v)
        # Run the forward pass as well; its result is not checked.
        eval_rnn_fn(s_v)

    assert cost_v < 0.02
def test_draw_as_input_to_scan(self): def test_draw_as_input_to_scan(self):
trng = theano.tensor.shared_randomstreams.RandomStreams(123) trng = theano.tensor.shared_randomstreams.RandomStreams(123)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论