提交 31ab8fdd authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Fix `DiffOp` output type when input has partially known shape

上级 fde62163
...@@ -2,6 +2,7 @@ from collections.abc import Collection ...@@ -2,6 +2,7 @@ from collections.abc import Collection
from typing import Iterable, Tuple, Union from typing import Iterable, Tuple, Union
import numpy as np import numpy as np
import numpy.core.numeric
import aesara import aesara
from aesara.gradient import ( from aesara.gradient import (
...@@ -482,7 +483,17 @@ class DiffOp(Op): ...@@ -482,7 +483,17 @@ class DiffOp(Op):
def make_node(self, x): def make_node(self, x):
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
return Apply(self, [x], [x.type()]) axis = numpy.core.numeric.normalize_axis_index(self.axis, x.ndim)
shape = [None] * x.type.ndim
for i, shape_i in enumerate(x.type.shape):
if shape_i is None:
continue
if i == axis:
shape[i] = max(0, shape_i - self.n)
else:
shape[i] = shape_i
out_type = TensorType(dtype=x.type.dtype, shape=shape)
return Apply(self, [x], [out_type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x = inputs[0] x = inputs[0]
......
...@@ -302,6 +302,28 @@ class TestDiffOp(utt.InferShapeTester): ...@@ -302,6 +302,28 @@ class TestDiffOp(utt.InferShapeTester):
g = aesara.function([x], diff(x, n=n, axis=axis)) g = aesara.function([x], diff(x, n=n, axis=axis))
assert np.allclose(np.diff(a, n=n, axis=axis), g(a)) assert np.allclose(np.diff(a, n=n, axis=axis), g(a))
@pytest.mark.parametrize(
"x_type",
(
at.TensorType("float64", (None, None)),
at.TensorType("float64", (None, 30)),
at.TensorType("float64", (10, None)),
at.TensorType("float64", (10, 30)),
),
)
def test_output_type(self, x_type):
x = x_type("x")
x_test = np.empty((10, 30))
for axis in (-2, -1, 0, 1):
for n in (0, 1, 2, 10, 11):
out = diff(x, n=n, axis=axis)
out_test = np.diff(x_test, n=n, axis=axis)
for i in range(2):
if x.type.shape[i] is None:
assert out.type.shape[i] is None
else:
assert out.type.shape[i] == out_test.shape[i]
def test_infer_shape(self): def test_infer_shape(self):
x = matrix("x") x = matrix("x")
a = np.random.random((30, 50)).astype(config.floatX) a = np.random.random((30, 50)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论