Commit 8136f204 — author: Pascal Lamblin; committer: GitHub

Merge pull request #6201 from aam-at/scan_until_fix

Fix bug in scan_until shape concatenation
...@@ -2561,7 +2561,7 @@ class Scan(PureOp): ...@@ -2561,7 +2561,7 @@ class Scan(PureOp):
n_zeros = inputs[0] - n_steps n_zeros = inputs[0] - n_steps
shp = (n_zeros,) shp = (n_zeros,)
if x.ndim > 1: if x.ndim > 1:
shp = shp + x.shape[1:] shp = shp + tuple(x.shape[i] for i in range(1, x.ndim))
z = tensor.zeros(shp, dtype=x.dtype) z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0) x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x) gradients.append(x)
...@@ -2589,7 +2589,7 @@ class Scan(PureOp): ...@@ -2589,7 +2589,7 @@ class Scan(PureOp):
n_zeros = inputs[0] - grad_steps n_zeros = inputs[0] - grad_steps
shp = (n_zeros,) shp = (n_zeros,)
if x.ndim > 1: if x.ndim > 1:
shp = shp + x.shape[1:] shp = shp + tuple(x.shape[i] for i in range(1, x.ndim))
z = tensor.zeros(shp, dtype=x.dtype) z = tensor.zeros(shp, dtype=x.dtype)
x = tensor.concatenate([x[::-1], z], axis=0) x = tensor.concatenate([x[::-1], z], axis=0)
gradients.append(x) gradients.append(x)
......
...@@ -5489,6 +5489,25 @@ class TestGradUntil(unittest.TestCase): ...@@ -5489,6 +5489,25 @@ class TestGradUntil(unittest.TestCase):
utt.assert_allclose(theano_output, self.numpy_output) utt.assert_allclose(theano_output, self.numpy_output)
utt.assert_allclose(theano_gradient, self.numpy_gradient) utt.assert_allclose(theano_gradient, self.numpy_gradient)
def test_grad_until_ndim_greater_one(self):
    """Check scan-with-until output and gradient for a matrix sequence.

    The fixture sequence is reshaped into a single column and repeated
    across five columns, so the scan iterates over the rows of a 2-D
    input. The scan squares each row and stops once every element of the
    row exceeds the threshold. Both the scan output and the gradient of
    its sum with respect to the input are compared against the fixture's
    reference arrays, tiled the same way.

    NOTE(review): ``self.seq``, ``self.threshold``, ``self.numpy_output``
    and ``self.numpy_gradient`` are presumably prepared by the test
    case's setup, which is not visible here -- confirm.
    """
    n_cols = 5

    def tile_array(inp):
        # Make ``inp`` a column vector, then repeat it ``n_cols`` times
        # along the second axis.
        column = inp.reshape((-1, 1))
        return np.tile(column, (1, n_cols))

    def step(x, u):
        # Terminate the scan as soon as all entries of the row exceed u.
        stop_condition = theano.scan_module.until(tensor.all(x > u))
        return x * x, stop_condition

    x_sym = tensor.matrix(name='x')
    tiled_seq = tile_array(self.seq)

    scan_out, _ = theano.scan(step,
                              sequences=x_sym,
                              non_sequences=[self.threshold])
    grad_wrt_x = theano.grad(scan_out.sum(), x_sym)
    compiled = theano.function([x_sym, self.threshold],
                               [scan_out, grad_wrt_x])

    theano_output, theano_gradient = compiled(tiled_seq, 5)

    expected_output = tile_array(self.numpy_output)
    expected_gradient = tile_array(self.numpy_gradient)
    utt.assert_allclose(theano_output, expected_output)
    utt.assert_allclose(theano_gradient, expected_gradient)
def test_grad_until_and_truncate(self): def test_grad_until_and_truncate(self):
n = 3 n = 3
r, _ = theano.scan(lambda x, u: (x * x, r, _ = theano.scan(lambda x, u: (x * x,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论