Unverified 提交 33d9421a authored 作者: Ayush Kumar's avatar Ayush Kumar 提交者: GitHub

Fix static shape inference in pt.linalg.kron and add regression test (#1898)

上级 db7fa079
......@@ -1176,7 +1176,7 @@ def kron(a, b):
b = ptb.expand_dims(b, tuple(range(a.ndim - b.ndim)))
a_reshaped = ptb.expand_dims(a, tuple(range(1, 2 * a.ndim, 2)))
b_reshaped = ptb.expand_dims(b, tuple(range(0, 2 * b.ndim, 2)))
out_shape = tuple(a.shape * b.shape)
out_shape = tuple(a.shape[i] * b.shape[i] for i in range(a.ndim))
output_out_of_shape = a_reshaped * b_reshaped
output_reshaped = output_out_of_shape.reshape(out_shape)
......
......@@ -744,16 +744,24 @@ class TestKron(utt.InferShapeTester):
def test_perform(self, shp0, shp1):
if len(shp0) + len(shp1) == 2:
pytest.skip("Sum of shp0 and shp1 must be more than 2")
x = tensor(dtype="floatX", shape=(None,) * len(shp0))
x = tensor(dtype="floatX", shape=shp0)
a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
y = tensor(dtype="floatX", shape=(None,) * len(shp1))
f = function([x, y], kron(x, y))
y = tensor(dtype="floatX", shape=shp1)
b = self.rng.random(shp1).astype(config.floatX)
kron_xy = kron(x, y)
f = function([x, y], kron_xy)
out = f(a, b)
# Using the np.kron to compare outputs
# Using np.kron to compare outputs
np_val = np.kron(a, b)
np.testing.assert_allclose(out, np_val)
# Regression test for issue #1867
assert kron_xy.type.shape == np_val.shape
@pytest.mark.parametrize(
"i, shp0, shp1",
[(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论