提交 07706002 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Group subtensor specify_shape lift tests in class

上级 ccbab653
...@@ -247,7 +247,8 @@ def test_local_subtensor_of_alloc(): ...@@ -247,7 +247,8 @@ def test_local_subtensor_of_alloc():
assert xval.__getitem__(slices).shape == val.shape assert xval.__getitem__(slices).shape == val.shape
@pytest.mark.parametrize( class TestLocalSubtensorSpecifyShapeLift:
@pytest.mark.parametrize(
"x, s, idx, x_val, s_val", "x, s, idx, x_val, s_val",
[ [
( (
...@@ -293,8 +294,8 @@ def test_local_subtensor_of_alloc(): ...@@ -293,8 +294,8 @@ def test_local_subtensor_of_alloc():
np.array([2, 3, 5], dtype=np.int64), np.array([2, 3, 5], dtype=np.int64),
), ),
], ],
) )
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): def test_local_subtensor_SpecifyShape_lift(self, x, s, idx, x_val, s_val):
y = specify_shape(x, s)[idx] y = specify_shape(x, s)[idx]
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape) assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
...@@ -319,8 +320,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): ...@@ -319,8 +320,7 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
assert np.allclose(y_val, y_opt_val) assert np.allclose(y_val, y_opt_val)
@pytest.mark.parametrize(
@pytest.mark.parametrize(
"x, s, idx", "x, s, idx",
[ [
( (
...@@ -339,8 +339,8 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val): ...@@ -339,8 +339,8 @@ def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
(1, 0), (1, 0),
), ),
], ],
) )
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx): def test_local_subtensor_SpecifyShape_lift_fail(self, x, s, idx):
y = specify_shape(x, s)[idx] y = specify_shape(x, s)[idx]
# This optimization should appear in the canonicalizations # This optimization should appear in the canonicalizations
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论