提交 baa52444 authored 作者: Ayush Kumar's avatar Ayush Kumar 提交者: Ricardo Vieira

Update kron tests to explicitly cover static and dynamic shapes (#1912)

上级 8741f7c6
...@@ -732,6 +732,7 @@ class TestKron(utt.InferShapeTester): ...@@ -732,6 +732,7 @@ class TestKron(utt.InferShapeTester):
super().setup_method() super().setup_method()
def test_vec_vec_kron_raises(self): def test_vec_vec_kron_raises(self):
"""Ensure kron raises an error for 1D inputs."""
x = vector() x = vector()
y = vector() y = vector()
with pytest.raises( with pytest.raises(
...@@ -739,29 +740,38 @@ class TestKron(utt.InferShapeTester): ...@@ -739,29 +740,38 @@ class TestKron(utt.InferShapeTester):
): ):
kron(x, y) kron(x, y)
@pytest.mark.parametrize("static_shape", [True, False])
@pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]) @pytest.mark.parametrize("shp0", [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)])
@pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]) @pytest.mark.parametrize("shp1", [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)])
def test_perform(self, shp0, shp1): def test_perform(self, static_shape, shp0, shp1):
"""Test kron execution and symbolic shape inference."""
if len(shp0) + len(shp1) == 2: if len(shp0) + len(shp1) == 2:
pytest.skip("Sum of shp0 and shp1 must be more than 2") pytest.skip("Sum of shp0 and shp1 must be more than 2")
x = tensor(dtype="floatX", shape=shp0)
a = np.asarray(self.rng.random(shp0)).astype(config.floatX) a = np.asarray(self.rng.random(shp0)).astype(config.floatX)
y = tensor(dtype="floatX", shape=shp1)
b = self.rng.random(shp1).astype(config.floatX) b = self.rng.random(shp1).astype(config.floatX)
# Using np.kron to evaluate expected numerical output and dimensionality
np_val = np.kron(a, b)
# Determine tensor shapes
shape_x = shp0 if static_shape else (None,) * len(shp0)
shape_y = shp1 if static_shape else (None,) * len(shp1)
shape_out = np_val.shape if static_shape else (None,) * np_val.ndim
x = tensor(dtype="floatX", shape=shape_x)
y = tensor(dtype="floatX", shape=shape_y)
kron_xy = kron(x, y) kron_xy = kron(x, y)
# Assert symbolic shape inference immediately after node creation
assert kron_xy.type.shape == shape_out
f = function([x, y], kron_xy) f = function([x, y], kron_xy)
out = f(a, b) out = f(a, b)
# Using np.kron to compare outputs
np_val = np.kron(a, b)
np.testing.assert_allclose(out, np_val) np.testing.assert_allclose(out, np_val)
# Regression test for issue #1867
assert kron_xy.type.shape == np_val.shape
@pytest.mark.parametrize( @pytest.mark.parametrize(
"i, shp0, shp1", "i, shp0, shp1",
[(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))], [(0, (2, 3), (6, 7)), (1, (2, 3), (4, 3, 5)), (2, (2, 4, 3), (4, 3, 5))],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论