提交 7d74c850 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Use safe_new(x) instead of x.type() in scan grad.

This copies the test_value, making it possible to use a test value for sequences. Add test.
上级 e9d83000
......@@ -1446,7 +1446,7 @@ class Scan(PureOp):
# We are looking for x[t-1] for a given x[t]
if idx >= self.n_mit_mot_outs:
Xt_placeholder = Xt.type()
Xt_placeholder = safe_new(Xt)
Xts.append(Xt_placeholder)
if Xt not in self.inner_nitsot_outs(self_outputs):
# What we do here is loop through dC_douts and collect all
......@@ -1502,12 +1502,12 @@ class Scan(PureOp):
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos)
if opos >= 0:
dC_dXtm1s.append(dC_dXts[opos].type())
dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype)
else:
dC_dXtm1s.append(x.type())
dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op
......
......@@ -3809,10 +3809,35 @@ def test_compute_test_value():
x.tag.test_value = xv
y = theano.shared(numpy.arange(3, dtype=theano.config.floatX),
name='y')
z, _ = theano.scan(
z, updates = theano.scan(
fn=lambda u, v: u + v,
sequences=[x, y])
assert not _
assert not updates
z.name = 'z'
# The gradient computation used to crash before 6af465e.
g = tensor.grad(z.sum(), x)
#f = theano.function([x], g)
#print f(xv)
finally:
theano.config.compute_test_value = backup
def test_compute_test_value_nonseq():
# Verify that test values can be used for non_sequences with scan.
backup = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
try:
x = tensor.vector('x')
xv = numpy.ones(3, dtype=theano.config.floatX)
x.tag.test_value = xv
y = theano.shared(
numpy.arange(9, dtype=theano.config.floatX).reshape(3, 3),
name='y')
z, updates = theano.scan(
fn=lambda u, v: u + v,
sequences=[x],
non_sequences=[y])
assert not updates
z.name = 'z'
# The gradient computation used to crash before 6af465e.
g = tensor.grad(z.sum(), x)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论