提交 8e3c356f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Parametrize local_useless_subtensor tests

上级 35ad2538
...@@ -198,33 +198,42 @@ def test_local_useless_inc_subtensor_no_opt(): ...@@ -198,33 +198,42 @@ def test_local_useless_inc_subtensor_no_opt():
assert any(isinstance(n.op, IncSubtensor) for n in topo) assert any(isinstance(n.op, IncSubtensor) for n in topo)
def test_local_useless_subtensor(): class TestLocalUselessSubtensor:
x = matrix("x") x = matrix("x")
s = aes.int32("s")
# Test default @pytest.mark.parametrize(
for dims in [ "dims",
(slice(0, None),), [
(slice(0, None), slice(0, None)), (slice(0, None),),
]: (slice(0, None), slice(0, None)),
f = function([x], exp(x).__getitem__(dims), mode=mode_opt) ],
)
def test_local_useless_subtensor_1(self, dims):
# Test default
f = function([self.x], exp(self.x).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert prog[0].op == exp assert prog[0].op == exp
assert len(prog) == 1 assert len(prog) == 1
# TODO FIXME: Assert something
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
x_c = specify_shape(x, (2, 3)) @pytest.mark.parametrize(
# Test constant "dims, res",
for dims, res in [ [
((slice(0, 2),), True), ((slice(0, 2),), True),
((slice(0, 2), slice(0, None)), True), ((slice(0, 2), slice(0, None)), True),
((slice(0, 2), slice(0, 3)), True), ((slice(0, 2), slice(0, 3)), True),
((slice(0, None), slice(0, 3)), True), ((slice(0, None), slice(0, 3)), True),
((slice(0, 3), slice(0, 13)), True), ((slice(0, 3), slice(0, 13)), True),
((slice(0, 3), slice(0, 2)), False), ((slice(0, 3), slice(0, 2)), False),
((slice(0, 1), slice(0, None)), False), ((slice(0, 1), slice(0, None)), False),
((slice(0, 1), 1), False), ((slice(0, 1), 1), False),
]: ],
f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt) )
def test_local_useless_subtensor_2(self, dims, res):
x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert isinstance(prog[0].op, SpecifyShape), dims assert isinstance(prog[0].op, SpecifyShape), dims
...@@ -234,8 +243,8 @@ def test_local_useless_subtensor(): ...@@ -234,8 +243,8 @@ def test_local_useless_subtensor():
assert any(isinstance(node.op, Subtensor) for node in prog) assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
# Test Variable @pytest.mark.parametrize(
for idx, (dims, res) in enumerate( "dims, res",
[ [
((slice(0, x.shape[0]),), True), ((slice(0, x.shape[0]),), True),
((slice(0, x.shape[1]),), False), ((slice(0, x.shape[1]),), False),
...@@ -276,9 +285,10 @@ def test_local_useless_subtensor(): ...@@ -276,9 +285,10 @@ def test_local_useless_subtensor():
False, False,
), ),
((slice(0, at.scalar_from_tensor(x.shape[0])),), True), ((slice(0, at.scalar_from_tensor(x.shape[0])),), True),
] ],
): )
f = function([x], exp(x).__getitem__(dims), mode=mode_opt) def test_local_useless_subtensor_3(self, dims, res):
f = function([self.x], exp(self.x).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert prog[0].op == exp, dims assert prog[0].op == exp, dims
...@@ -286,15 +296,19 @@ def test_local_useless_subtensor(): ...@@ -286,15 +296,19 @@ def test_local_useless_subtensor():
else: else:
assert any(isinstance(node.op, Subtensor) for node in prog) assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
# Test mix Variable and Constant
# Currently not supported @pytest.mark.parametrize(
for idx, (dims, res) in enumerate( "dims, res",
[ [
((slice(0, x.shape[0]), slice(0, 3)), False), ((slice(0, x.shape[0]), slice(0, 3)), False),
((slice(0, 3), slice(0, x.shape[1])), False), ((slice(0, 3), slice(0, x.shape[1])), False),
] ],
): )
f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt) def test_local_useless_subtensor_4(self, dims, res):
# Test mix Variable and Constant
# Currently not supported
x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert prog[0].op == exp, dims assert prog[0].op == exp, dims
...@@ -303,14 +317,15 @@ def test_local_useless_subtensor(): ...@@ -303,14 +317,15 @@ def test_local_useless_subtensor():
assert any(isinstance(node.op, Subtensor) for node in prog) assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
# Test scalar variable @pytest.mark.parametrize(
s = aes.int32("s") "dims, res",
for idx, (dims, res) in enumerate(
[ [
((slice(0, s),), False), ((slice(0, s),), False),
] ],
): )
f = function([x, s], exp(x).__getitem__(dims), mode=mode_opt) def test_local_useless_subtensor_5(self, dims, res):
# Test scalar variable
f = function([self.x, self.s], exp(self.x).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert prog[0].op == exp, dims assert prog[0].op == exp, dims
...@@ -320,20 +335,25 @@ def test_local_useless_subtensor(): ...@@ -320,20 +335,25 @@ def test_local_useless_subtensor():
f([[1, 2, 3], [4, 5, 6]], 1) f([[1, 2, 3], [4, 5, 6]], 1)
f([[1, 2, 3], [4, 5, 6]], 3) f([[1, 2, 3], [4, 5, 6]], 3)
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector @pytest.mark.parametrize(
# or ARange op "dims, res",
for dims, res in ( [
([0, 1], True), ([0, 1], True),
([1, 0], False), ([1, 0], False),
([0, 0], False), ([0, 0], False),
([0, 0, 1], False), ([0, 0, 1], False),
(at.arange(2), True), (at.arange(2), True),
(at.arange(0, 2), True), (at.arange(0, 2), True),
(at.arange(0, 2, 2), False), (at.arange(0, 2, 2), False),
(at.arange(0, 2, -1), False), (at.arange(0, 2, -1), False),
(at.arange(1, 2), False), (at.arange(1, 2), False),
): ],
f = function([x], exp(x_c).__getitem__(dims), mode=mode_opt) )
def test_local_useless_subtensor_6(self, dims, res):
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op
x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert isinstance(prog[0].op, SpecifyShape), dims assert isinstance(prog[0].op, SpecifyShape), dims
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论