Commit c23a8365 authored by Pascal Lamblin

Merge pull request #2624 from carriepl/scan_crash

[CRASH] Fix crash with scan grad and add unit test
@@ -1931,9 +1931,10 @@ class Scan(PureOp):
                 type_outs.append(vl.type.why_null)
                 # Replace the inner output with a zero tensor of
                 # the right shape
-                inner_out_sitsot[_p] = tensor.zeros(
+                inner_out_nitsot[_p] = tensor.zeros(
                     diff_inputs[_p].shape,
                     dtype=theano.config.floatX)
             if through_shared:
                 type_outs.append('through_shared')
             elif disconnected_dC_dinps_t[_p]:
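Background note (not part of the patch): in scan's internal terminology, nit-sot outputs ("no input taps, single output tap") are results that are only collected, declared with outputs_info=[None], while sit-sot outputs ("single input tap, single output tap") are fed back into the next step and are declared with an initial state. The change above routes the zero replacement for a NullType gradient into inner_out_nitsot instead of inner_out_sitsot. Below is a minimal sketch of the two output kinds; the variable names (squares, cumsum, acc0, f) are illustrative only, not taken from the patch.

import numpy
import theano
import theano.tensor as tensor

x = tensor.vector('x')

# nit-sot: each step's result is only collected, never fed back
# (declared with outputs_info=[None]).
squares, _ = theano.scan(fn=lambda x_t: x_t ** 2,
                         sequences=x,
                         outputs_info=[None])

# sit-sot: the previous step's output is passed back in as an extra
# argument (outputs_info carries the initial state).
acc0 = tensor.as_tensor_variable(numpy.asarray(0., dtype=x.dtype))
cumsum, _ = theano.scan(fn=lambda x_t, acc_tm1: acc_tm1 + x_t,
                        sequences=x,
                        outputs_info=acc0)

f = theano.function([x], [squares, cumsum])
print(f(numpy.arange(4, dtype=x.dtype)))  # [0, 1, 4, 9] and [0, 1, 3, 6]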
@@ -836,6 +836,36 @@ class T_Scan(unittest.TestCase):
                            n_steps=2)
         tensor.grad(a[-1], a0)
 
+    def test_grad_two_scans(self):
+
+        # data input & output
+        x = tensor.tensor3('x')
+        t = tensor.imatrix('t')
+
+        # forward pass
+        W = theano.shared(
+            numpy.random.randn(2, 2).astype('float32'),
+            name="W", borrow=True)
+
+        def forward_scanner(x_t):
+            a2_t = tensor.dot(x_t, W)
+            y_t = tensor.nnet.softmax(a2_t)
+            return y_t
+
+        y, _ = theano.scan(fn=forward_scanner, sequences=x,
+                           outputs_info=[None])
+
+        # loss function
+        def error_scanner(y_t, t_t):
+            return tensor.mean(tensor.nnet.categorical_crossentropy(y_t, t_t))
+
+        L, _ = theano.scan(fn=error_scanner, sequences=[y, t],
+                           outputs_info=[None])
+        L = tensor.mean(L)
+
+        # backward pass
+        gW = tensor.grad(L, [W])
+
     # simple rnn, one input, one state, weights for each; input/state are
     # vectors, weights are scalars; using shared variables and past
     # taps (sequences and outputs)
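Usage note (a hedged sketch, not part of the committed test): the new test only builds the gradient graph, which is the step this fix concerns. The snippet below shows one way the same graph could also be compiled and run on toy data; the shapes (3 steps, batch of 4, 2 classes) and the names f, x_val, t_val are assumptions made for illustration.

import numpy
import theano
from theano import tensor

x = tensor.tensor3('x')
t = tensor.imatrix('t')
W = theano.shared(numpy.random.randn(2, 2).astype('float32'), name="W")

# forward pass: one scan over the time dimension of x
y, _ = theano.scan(lambda x_t: tensor.nnet.softmax(tensor.dot(x_t, W)),
                   sequences=x, outputs_info=[None])

# loss: a second scan over the first scan's output
L, _ = theano.scan(
    lambda y_t, t_t: tensor.mean(tensor.nnet.categorical_crossentropy(y_t, t_t)),
    sequences=[y, t], outputs_info=[None])
L = tensor.mean(L)

# gradient through both scans -- the step exercised by the new unit test
gW = tensor.grad(L, [W])

# compile and run on toy data
f = theano.function([x, t], [L] + gW)
x_val = numpy.random.randn(3, 4, 2).astype(theano.config.floatX)
t_val = numpy.random.randint(0, 2, size=(3, 4)).astype('int32')
loss_val, gW_val = f(x_val, t_val)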