提交 31ab8fdd authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Fix `DiffOp` output type when input has partially known shape

上级 fde62163
......@@ -2,6 +2,7 @@ from collections.abc import Collection
from typing import Iterable, Tuple, Union
import numpy as np
import numpy.core.numeric
import aesara
from aesara.gradient import (
......@@ -482,7 +483,17 @@ class DiffOp(Op):
def make_node(self, x):
x = at.as_tensor_variable(x)
return Apply(self, [x], [x.type()])
axis = numpy.core.numeric.normalize_axis_index(self.axis, x.ndim)
shape = [None] * x.type.ndim
for i, shape_i in enumerate(x.type.shape):
if shape_i is None:
continue
if i == axis:
shape[i] = max(0, shape_i - self.n)
else:
shape[i] = shape_i
out_type = TensorType(dtype=x.type.dtype, shape=shape)
return Apply(self, [x], [out_type()])
def perform(self, node, inputs, output_storage):
x = inputs[0]
......
......@@ -302,6 +302,28 @@ class TestDiffOp(utt.InferShapeTester):
g = aesara.function([x], diff(x, n=n, axis=axis))
assert np.allclose(np.diff(a, n=n, axis=axis), g(a))
@pytest.mark.parametrize(
"x_type",
(
at.TensorType("float64", (None, None)),
at.TensorType("float64", (None, 30)),
at.TensorType("float64", (10, None)),
at.TensorType("float64", (10, 30)),
),
)
def test_output_type(self, x_type):
x = x_type("x")
x_test = np.empty((10, 30))
for axis in (-2, -1, 0, 1):
for n in (0, 1, 2, 10, 11):
out = diff(x, n=n, axis=axis)
out_test = np.diff(x_test, n=n, axis=axis)
for i in range(2):
if x.type.shape[i] is None:
assert out.type.shape[i] is None
else:
assert out.type.shape[i] == out_test.shape[i]
def test_infer_shape(self):
x = matrix("x")
a = np.random.random((30, 50)).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论