提交 3d43bab4 authored 作者: Cesar Laurent

Small fixes

上级 967e3d8c
...@@ -2201,6 +2201,7 @@ class Scan(PureOp): ...@@ -2201,6 +2201,7 @@ class Scan(PureOp):
# Construct scan op # Construct scan op
# Seqs # Seqs
if self.as_while: if self.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs = [x[n_steps - 1::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[n_steps - 1::-1] for x in inputs[1:1 + self.n_seqs]]
else: else:
outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]] outer_inp_seqs = [x[::-1] for x in inputs[1:1 + self.n_seqs]]
...@@ -2222,6 +2223,7 @@ class Scan(PureOp): ...@@ -2222,6 +2223,7 @@ class Scan(PureOp):
for x in self.outer_nitsot_outs(dC_douts): for x in self.outer_nitsot_outs(dC_douts):
if not isinstance(x.type, DisconnectedType): if not isinstance(x.type, DisconnectedType):
if self.as_while: if self.as_while:
# equivalent to x[:n_steps][::-1]
outer_inp_seqs.append(x[n_steps - 1::-1]) outer_inp_seqs.append(x[n_steps - 1::-1])
else: else:
outer_inp_seqs.append(x[::-1]) outer_inp_seqs.append(x[::-1])
...@@ -2232,20 +2234,25 @@ class Scan(PureOp): ...@@ -2232,20 +2234,25 @@ class Scan(PureOp):
# fct add and we want to keep it for all Scan op. This is # fct add and we want to keep it for all Scan op. This is
# used in T_Scan.test_grad_multiple_outs_taps to test # used in T_Scan.test_grad_multiple_outs_taps to test
# that. # that.
if self.as_while:
n = n_steps.tag.test_value
else:
n = inputs[0].tag.test_value
for taps, x in zip(self.mitsot_taps(), for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs)): self.outer_mitsot_outs(outs)):
mintap = np.min(taps) mintap = np.min(taps)
if hasattr(x[::-1][:mintap], 'test_value'): if hasattr(x[::-1][:mintap], 'test_value'):
assert (x[::-1][:mintap].tag.test_value.shape[0] == assert (x[::-1][:mintap].tag.test_value.shape[0] == n)
grad_steps.tag.test_value)
for x in self.outer_sitsot_outs(outs): for x in self.outer_sitsot_outs(outs):
if hasattr(x[::-1][:-1].tag, 'test_value'): if hasattr(x[::-1][:-1].tag, 'test_value'):
assert (x[::-1][:-1].tag.test_value.shape[0] == assert (x[::-1][:-1].tag.test_value.shape[0] == n)
grad_steps.tag.test_value)
for x in self.outer_nitsot_outs(outs): for x in self.outer_nitsot_outs(outs):
if hasattr(x[::-1].tag, 'test_value'): if hasattr(x[::-1].tag, 'test_value'):
assert (x[grad_steps - 1::-1].tag.test_value.shape[0] == if self.as_while:
grad_steps.tag.test_value) assert (x[n_steps - 1::-1].tag.test_value.shape[0] ==
n)
else:
assert (x[::-1].tag.test_value.shape[0] == n)
outer_inp_seqs += [x[::-1][:np.min(taps)] outer_inp_seqs += [x[::-1][:np.min(taps)]
for taps, x in zip(self.mitsot_taps(), for taps, x in zip(self.mitsot_taps(),
self.outer_mitsot_outs(outs))] self.outer_mitsot_outs(outs))]
......
...@@ -5535,7 +5535,7 @@ class TestGradUntil(unittest.TestCase): ...@@ -5535,7 +5535,7 @@ class TestGradUntil(unittest.TestCase):
def setUp(self): def setUp(self):
self.x = tensor.vector(name='x') self.x = tensor.vector(name='x')
self.until = tensor.scalar(name='until', dtype='int64') self.threshold = tensor.scalar(name='threshold', dtype='int64')
self.seq = np.arange(15, dtype=theano.config.floatX) self.seq = np.arange(15, dtype=theano.config.floatX)
self.numpy_output = self.seq[:7]**2 self.numpy_output = self.seq[:7]**2
z = np.zeros(8, dtype=theano.config.floatX) z = np.zeros(8, dtype=theano.config.floatX)
...@@ -5545,9 +5545,9 @@ class TestGradUntil(unittest.TestCase): ...@@ -5545,9 +5545,9 @@ class TestGradUntil(unittest.TestCase):
r, _ = theano.scan(lambda x, u: (x * x, r, _ = theano.scan(lambda x, u: (x * x,
theano.scan_module.until(x > u)), theano.scan_module.until(x > u)),
sequences=self.x, sequences=self.x,
non_sequences=[self.until]) non_sequences=[self.threshold])
g = theano.grad(r.sum(), self.x) g = theano.grad(r.sum(), self.x)
f = theano.function([self.x, self.until], [r, g]) f = theano.function([self.x, self.threshold], [r, g])
theano_output, theano_gradient = f(self.seq, 5) theano_output, theano_gradient = f(self.seq, 5)
utt.assert_allclose(theano_output, self.numpy_output) utt.assert_allclose(theano_output, self.numpy_output)
...@@ -5558,10 +5558,10 @@ class TestGradUntil(unittest.TestCase): ...@@ -5558,10 +5558,10 @@ class TestGradUntil(unittest.TestCase):
r, _ = theano.scan(lambda x, u: (x * x, r, _ = theano.scan(lambda x, u: (x * x,
theano.scan_module.until(x > u)), theano.scan_module.until(x > u)),
sequences=self.x, sequences=self.x,
non_sequences=[self.until], non_sequences=[self.threshold],
truncate_gradient=n) truncate_gradient=n)
g = theano.grad(r.sum(), self.x) g = theano.grad(r.sum(), self.x)
f = theano.function([self.x, self.until], [r, g]) f = theano.function([self.x, self.threshold], [r, g])
theano_output, theano_gradient = f(self.seq, 5) theano_output, theano_gradient = f(self.seq, 5)
self.numpy_gradient[:7 - n] = 0 self.numpy_gradient[:7 - n] = 0
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论