提交 e0ef024d authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Add tests for scan until when number of dim > 1

上级 9a3fb648
......@@ -5489,6 +5489,25 @@ class TestGradUntil(unittest.TestCase):
utt.assert_allclose(theano_output, self.numpy_output)
utt.assert_allclose(theano_gradient, self.numpy_gradient)
def test_grad_until_ndim_greater_one(self):
def tile_array(inp):
n_cols = 5
return np.tile(inp.reshape((-1, 1)), (1, n_cols))
X = tensor.matrix(name='x')
arr = tile_array(self.seq)
r, _ = theano.scan(lambda x, u: (x * x,
theano.scan_module.until(
tensor.all(x > u))),
sequences=X,
non_sequences=[self.threshold])
g = theano.grad(r.sum(), X)
f = theano.function([X, self.threshold], [r, g])
theano_output, theano_gradient = f(arr, 5)
utt.assert_allclose(theano_output, tile_array(self.numpy_output))
utt.assert_allclose(theano_gradient, tile_array(self.numpy_gradient))
def test_grad_until_and_truncate(self):
n = 3
r, _ = theano.scan(lambda x, u: (x * x,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论