提交 3d43bab4 authored 作者: Cesar Laurent's avatar Cesar Laurent

Small fixes

上级 967e3d8c
......@@ -2201,6 +2201,7 @@ class Scan(PureOp):
# Construct scan op
# Seqs
if self.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs = [x[n_steps - 1::-1] for x in inputs[1:1 + self.n_seqs]]
else:
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
......@@ -2222,6 +2223,7 @@ class Scan(PureOp):
for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType):
if self.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs.append(x[n_steps - 1::-1])
else:
outer_inp_seqs.append(x[::-1])
......@@ -2232,20 +2234,25 @@ class Scan(PureOp):
# fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test
# that.
if self.as_while:
n = n_steps.tag.test_value
else:
n = inputs[0].tag.test_value
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs)):
mintap = np.min(taps)
if hasattr(x[::-1][:mintap], 'test_value'):
assert (x[::-1][:mintap].tag.test_value.shape[0] ==
grad_steps.tag.test_value)
assert (x[::-1][:mintap].tag.test_value.shape[0] == n)
for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'):
assert (x[::-1][:-1].tag.test_value.shape[0] ==
grad_steps.tag.test_value)
assert (x[::-1][:-1].tag.test_value.shape[0] == n)
for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'):
assert (x[grad_steps - 1::-1].tag.test_value.shape[0] ==
grad_steps.tag.test_value)
if self.as_while:
assert (x[n_steps - 1::-1].tag.test_value.shape[0] ==
n)
else:
assert (x[::-1].tag.test_value.shape[0] == n)
outer_inp_seqs += [x[::-1][:np.min(taps)]
for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))]
......
......@@ -5535,7 +5535,7 @@ class TestGradUntil(unittest.TestCase):
def setUp(self):
self.x = tensor.vector(name='x')
self.until = tensor.scalar(name='until', dtype='int64')
self.threshold = tensor.scalar(name='threshold', dtype='int64')
self.seq = np.arange(15, dtype=theano.config.floatX)
self.numpy_output = self.seq[:7]**2
z = np.zeros(8, dtype=theano.config.floatX)
......@@ -5545,9 +5545,9 @@ class TestGradUntil(unittest.TestCase):
r, _ = theano.scan(lambda x, u: (x * x,
theano.scan_module.until(x > u)),
sequences=self.x,
non_sequences=[self.until])
non_sequences=[self.threshold])
g = theano.grad(r.sum(), self.x)
f = theano.function([self.x, self.until], [r, g])
f = theano.function([self.x, self.threshold], [r, g])
theano_output, theano_gradient = f(self.seq, 5)
utt.assert_allclose(theano_output, self.numpy_output)
......@@ -5558,10 +5558,10 @@ class TestGradUntil(unittest.TestCase):
r, _ = theano.scan(lambda x, u: (x * x,
theano.scan_module.until(x > u)),
sequences=self.x,
non_sequences=[self.until],
non_sequences=[self.threshold],
truncate_gradient=n)
g = theano.grad(r.sum(), self.x)
f = theano.function([self.x, self.until], [r, g])
f = theano.function([self.x, self.threshold], [r, g])
theano_output, theano_gradient = f(self.seq, 5)
self.numpy_gradient[:7 - n] = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论