Commit fde62163, authored by Ricardo, committed by Ricardo Vieira

Fix `DiffOp` `infer_shape` when `n` is larger than axis dimension

Other changes: * Test negative axis * Add test for faulty gradient behavior when n is larger than input size * Add test for NotImplementedError for gradients of ndim != 1 * Remove vague aesara.function call in tests
上级 40b51621
@@ -505,6 +505,7 @@ class DiffOp(Op):
         app = at.concatenate([z, [0.0]])
         return pre - app

+        # FIXME: This fails when n is larger than the input size
         for k in range(self.n):
             z = _grad_helper(z)
         return [z]
@@ -512,7 +513,7 @@ class DiffOp(Op):
     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         out_shape = list(i0_shapes)
-        out_shape[self.axis] = out_shape[self.axis] - self.n
+        out_shape[self.axis] = at_max((0, out_shape[self.axis] - self.n))
         return [out_shape]
......
@@ -43,7 +43,6 @@ from aesara.tensor.extra_ops import (
     to_one_hot,
     unravel_index,
 )
-from aesara.tensor.math import sum as at_sum
 from aesara.tensor.subtensor import AdvancedIncSubtensor
 from aesara.tensor.type import (
     TensorType,
@@ -291,13 +290,6 @@ class TestBinCount(utt.InferShapeTester):
 class TestDiffOp(utt.InferShapeTester):
-    nb = 10  # Number of time iterating for n
-
-    def setup_method(self):
-        super().setup_method()
-        self.op_class = DiffOp
-        self.op = DiffOp()
-
     def test_diffOp(self):
         x = matrix("x")
         a = np.random.random((30, 50)).astype(config.floatX)
@@ -305,33 +297,42 @@ class TestDiffOp(utt.InferShapeTester):
         f = aesara.function([x], diff(x))
         assert np.allclose(np.diff(a), f(a))

-        for axis in range(len(a.shape)):
-            for k in range(TestDiffOp.nb):
-                g = aesara.function([x], diff(x, n=k, axis=axis))
-                assert np.allclose(np.diff(a, n=k, axis=axis), g(a))
+        for axis in (-2, -1, 0, 1):
+            for n in (0, 1, 2, a.shape[0], a.shape[0] + 1):
+                g = aesara.function([x], diff(x, n=n, axis=axis))
+                assert np.allclose(np.diff(a, n=n, axis=axis), g(a))

     def test_infer_shape(self):
         x = matrix("x")
         a = np.random.random((30, 50)).astype(config.floatX)
-        self._compile_and_check([x], [self.op(x)], [a], self.op_class)
+        # Test default n and axis
+        self._compile_and_check([x], [DiffOp()(x)], [a], DiffOp)

-        for axis in range(len(a.shape)):
-            for k in range(TestDiffOp.nb):
-                self._compile_and_check(
-                    [x], [diff(x, n=k, axis=axis)], [a], self.op_class
-                )
+        for axis in (-2, -1, 0, 1):
+            for n in (0, 1, 2, a.shape[0], a.shape[0] + 1):
+                self._compile_and_check([x], [diff(x, n=n, axis=axis)], [a], DiffOp)

     def test_grad(self):
-        x = vector("x")
         a = np.random.random(50).astype(config.floatX)

-        aesara.function([x], grad(at_sum(diff(x)), x))
-        utt.verify_grad(self.op, [a])
+        # Test default n and axis
+        utt.verify_grad(DiffOp(), [a])

-        for k in range(TestDiffOp.nb):
-            aesara.function([x], grad(at_sum(diff(x, n=k)), x))
-            utt.verify_grad(DiffOp(n=k), [a], eps=7e-3)
+        for n in (0, 1, 2, a.shape[0]):
+            utt.verify_grad(DiffOp(n=n), [a], eps=7e-3)
+
+    @pytest.mark.xfail(reason="gradient is wrong when n is larger than input size")
+    def test_grad_n_larger_than_input(self):
+        # Gradient is wrong when n is larger than the input size. Until it is fixed,
+        # this test ensures the behavior is documented
+        a = np.random.random(10).astype(config.floatX)
+        utt.verify_grad(DiffOp(n=11), [a], eps=7e-3)
+
+    def test_grad_not_implemented(self):
+        x = at.matrix("x")
+        with pytest.raises(NotImplementedError):
+            grad(diff(x).sum(), x)
class TestSqueeze(utt.InferShapeTester): class TestSqueeze(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论