提交 8e3c356f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parametrize local_useless_subtensor tests

上级 35ad2538
......@@ -198,33 +198,42 @@ def test_local_useless_inc_subtensor_no_opt():
assert any(isinstance(n.op, IncSubtensor) for n in topo)
def test_local_useless_subtensor():
class TestLocalUselessSubtensor:
x = matrix("x")
s = aes.int32("s")
# Test default
for dims in [
(slice(0, None),),
(slice(0, None), slice(0, None)),
]:
f = function([x], exp(x).__getitem__(dims), mode=mode_opt)
@pytest.mark.parametrize(
"dims",
[
(slice(0, None),),
(slice(0, None), slice(0, None)),
],
)
def test_local_useless_subtensor_1(self, dims):
# Test default
f = function([self.x], exp(self.x).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort()
assert prog[0].op == exp
assert len(prog) == 1
# TODO FIXME: Assert something
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
x_c = specify_shape(x, (2, 3))
# Test constant
for dims, res in [
((slice(0, 2),), True),
((slice(0, 2), slice(0, None)), True),
((slice(0, 2), slice(0, 3)), True),
((slice(0, None), slice(0, 3)), True),
((slice(0, 3), slice(0, 13)), True),
((slice(0, 3), slice(0, 2)), False),
((slice(0, 1), slice(0, None)), False),
((slice(0, 1), 1), False),
]:
f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt)
@pytest.mark.parametrize(
"dims, res",
[
((slice(0, 2),), True),
((slice(0, 2), slice(0, None)), True),
((slice(0, 2), slice(0, 3)), True),
((slice(0, None), slice(0, 3)), True),
((slice(0, 3), slice(0, 13)), True),
((slice(0, 3), slice(0, 2)), False),
((slice(0, 1), slice(0, None)), False),
((slice(0, 1), 1), False),
],
)
def test_local_useless_subtensor_2(self, dims, res):
x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort()
if res:
assert isinstance(prog[0].op, SpecifyShape), dims
......@@ -234,8 +243,8 @@ def test_local_useless_subtensor():
assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
# Test Variable
for idx, (dims, res) in enumerate(
@pytest.mark.parametrize(
"dims, res",
[
((slice(0, x.shape[0]),), True),
((slice(0, x.shape[1]),), False),
......@@ -276,9 +285,10 @@ def test_local_useless_subtensor():
False,
),
((slice(0, at.scalar_from_tensor(x.shape[0])),), True),
]
):
f = function([x], exp(x).__getitem__(dims), mode=mode_opt)
],
)
def test_local_useless_subtensor_3(self, dims, res):
f = function([self.x], exp(self.x).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort()
if res:
assert prog[0].op == exp, dims
......@@ -286,15 +296,19 @@ def test_local_useless_subtensor():
else:
assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
# Test mix Variable and Constant
# Currently not supported
for idx, (dims, res) in enumerate(
@pytest.mark.parametrize(
"dims, res",
[
((slice(0, x.shape[0]), slice(0, 3)), False),
((slice(0, 3), slice(0, x.shape[1])), False),
]
):
f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt)
],
)
def test_local_useless_subtensor_4(self, dims, res):
# Test mix Variable and Constant
# Currently not supported
x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort()
if res:
assert prog[0].op == exp, dims
......@@ -303,14 +317,15 @@ def test_local_useless_subtensor():
assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
# Test scalar variable
s = aes.int32("s")
for idx, (dims, res) in enumerate(
@pytest.mark.parametrize(
"dims, res",
[
((slice(0, s),), False),
]
):
f = function([x, s], exp(x).__getitem__(dims), mode=mode_opt)
],
)
def test_local_useless_subtensor_5(self, dims, res):
# Test scalar variable
f = function([self.x, self.s], exp(self.x).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort()
if res:
assert prog[0].op == exp, dims
......@@ -320,20 +335,25 @@ def test_local_useless_subtensor():
f([[1, 2, 3], [4, 5, 6]], 1)
f([[1, 2, 3], [4, 5, 6]], 3)
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op
for dims, res in (
([0, 1], True),
([1, 0], False),
([0, 0], False),
([0, 0, 1], False),
(at.arange(2), True),
(at.arange(0, 2), True),
(at.arange(0, 2, 2), False),
(at.arange(0, 2, -1), False),
(at.arange(1, 2), False),
):
f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt)
@pytest.mark.parametrize(
"dims, res",
[
([0, 1], True),
([1, 0], False),
([0, 0], False),
([0, 0, 1], False),
(at.arange(2), True),
(at.arange(0, 2), True),
(at.arange(0, 2, 2), False),
(at.arange(0, 2, -1), False),
(at.arange(1, 2), False),
],
)
def test_local_useless_subtensor_6(self, dims, res):
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op
x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort()
if res:
assert isinstance(prog[0].op, SpecifyShape), dims
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论