提交 c10edb16 authored 作者: nouiz's avatar nouiz

Merge pull request #1310 from lamblin/fix_scan_grad_testval

Use safe_new(x) instead of x.type() in scan grad.
...@@ -1446,7 +1446,7 @@ class Scan(PureOp): ...@@ -1446,7 +1446,7 @@ class Scan(PureOp):
# We are looking for x[t-1] for a given x[t] # We are looking for x[t-1] for a given x[t]
if idx >= self.n_mit_mot_outs: if idx >= self.n_mit_mot_outs:
Xt_placeholder = Xt.type() Xt_placeholder = safe_new(Xt)
Xts.append(Xt_placeholder) Xts.append(Xt_placeholder)
if Xt not in self.inner_nitsot_outs(self_outputs): if Xt not in self.inner_nitsot_outs(self_outputs):
# What we do here is loop through dC_douts and collect all # What we do here is loop through dC_douts and collect all
...@@ -1502,12 +1502,12 @@ class Scan(PureOp): ...@@ -1502,12 +1502,12 @@ class Scan(PureOp):
for pos, x in enumerate(dC_dinps_t[self.n_seqs:]): for pos, x in enumerate(dC_dinps_t[self.n_seqs:]):
opos = self.get_output_pos(pos) opos = self.get_output_pos(pos)
if opos >= 0: if opos >= 0:
dC_dXtm1s.append(dC_dXts[opos].type()) dC_dXtm1s.append(safe_new(dC_dXts[opos]))
if x.dtype != dC_dXts[opos].dtype: if x.dtype != dC_dXts[opos].dtype:
dC_dinps_t[pos + self.n_seqs] = \ dC_dinps_t[pos + self.n_seqs] = \
x.astype(dC_dXts[opos].dtype) x.astype(dC_dXts[opos].dtype)
else: else:
dC_dXtm1s.append(x.type()) dC_dXtm1s.append(safe_new(x))
for dx, dC_dXtm1 in enumerate(dC_dXtm1s): for dx, dC_dXtm1 in enumerate(dC_dXtm1s):
dC_dinps_t[dx + self.n_seqs] += dC_dXtm1 dC_dinps_t[dx + self.n_seqs] += dC_dXtm1
# Construct scan op # Construct scan op
......
...@@ -3809,10 +3809,35 @@ def test_compute_test_value(): ...@@ -3809,10 +3809,35 @@ def test_compute_test_value():
x.tag.test_value = xv x.tag.test_value = xv
y = theano.shared(numpy.arange(3, dtype=theano.config.floatX), y = theano.shared(numpy.arange(3, dtype=theano.config.floatX),
name='y') name='y')
z, _ = theano.scan( z, updates = theano.scan(
fn=lambda u, v: u + v, fn=lambda u, v: u + v,
sequences=[x, y]) sequences=[x, y])
assert not _ assert not updates
z.name = 'z'
# The gradient computation used to crash before 6af465e.
g = tensor.grad(z.sum(), x)
#f = theano.function([x], g)
#print f(xv)
finally:
theano.config.compute_test_value = backup
def test_compute_test_value_nonseq():
# Verify that test values can be used for non_sequences with scan.
backup = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
try:
x = tensor.vector('x')
xv = numpy.ones(3, dtype=theano.config.floatX)
x.tag.test_value = xv
y = theano.shared(
numpy.arange(9, dtype=theano.config.floatX).reshape(3, 3),
name='y')
z, updates = theano.scan(
fn=lambda u, v: u + v,
sequences=[x],
non_sequences=[y])
assert not updates
z.name = 'z' z.name = 'z'
# The gradient computation used to crash before 6af465e. # The gradient computation used to crash before 6af465e.
g = tensor.grad(z.sum(), x) g = tensor.grad(z.sum(), x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论