提交 e0d91807 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove unnecessary code in local_useless_AdvancedSubtensor1 and refactor tests

上级 2ed28f39
...@@ -11,7 +11,6 @@ from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_opti ...@@ -11,7 +11,6 @@ from aesara.graph.opt import TopoOptimizer, copy_stack_trace, in2out, local_opti
from aesara.raise_op import Assert from aesara.raise_op import Assert
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
ARange,
Join, Join,
MakeVector, MakeVector,
ScalarFromTensor, ScalarFromTensor,
...@@ -908,8 +907,6 @@ def local_useless_subtensor(fgraph, node): ...@@ -908,8 +907,6 @@ def local_useless_subtensor(fgraph, node):
# is not a useless subtensor # is not a useless subtensor
return False return False
for pos, idx in enumerate(cdata):
length_pos = shape_of[node.inputs[0]][pos] length_pos = shape_of[node.inputs[0]][pos]
if isinstance(idx.stop, (int, np.integer)): if isinstance(idx.stop, (int, np.integer)):
...@@ -997,21 +994,6 @@ def local_useless_AdvancedSubtensor1(fgraph, node): ...@@ -997,21 +994,6 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
return False return False
if np.any(idx != np.arange(length)): if np.any(idx != np.arange(length)):
return False return False
elif idx.owner is not None and isinstance(idx.owner.op, ARange):
try:
start, stop, step = map(
lambda x: get_scalar_constant_value(x, only_process_constants=True),
idx.owner.inputs,
)
except NotScalarConstantError:
return False
if start != 0:
return False
if stop != length:
return False
if step != 1:
return False
else: else:
return False return False
......
...@@ -201,25 +201,31 @@ def test_local_useless_inc_subtensor_no_opt(): ...@@ -201,25 +201,31 @@ def test_local_useless_inc_subtensor_no_opt():
class TestLocalUselessSubtensor: class TestLocalUselessSubtensor:
x = matrix("x") x = matrix("x")
s = aes.int32("s") s = aes.int32("s")
mode = mode_opt.including(
"local_useless_subtensor", "local_useless_AdvancedSubtensor1"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dims", "idx",
[ [
(slice(0, None),), (slice(0, None),),
(slice(0, None), slice(0, None)), (slice(0, None), slice(0, None)),
], ],
) )
def test_local_useless_subtensor_1(self, dims): def test_local_useless_subtensor_1(self, idx):
# Test default f = function([self.x], exp(self.x).__getitem__(idx), mode=self.mode)
f = function([self.x], exp(self.x).__getitem__(dims), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert prog[0].op == exp assert prog[0].op == exp
assert len(prog) == 1 assert len(prog) == 1
# TODO FIXME: Assert something
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX)
idx_val = idx
exp_res = np.exp(x_val)[idx_val]
res = f(x_val)
assert np.allclose(res, exp_res)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dims, res", "idx, res",
[ [
((slice(0, 2),), True), ((slice(0, 2),), True),
((slice(0, 2), slice(0, None)), True), ((slice(0, 2), slice(0, None)), True),
...@@ -231,112 +237,150 @@ class TestLocalUselessSubtensor: ...@@ -231,112 +237,150 @@ class TestLocalUselessSubtensor:
((slice(0, 1), 1), False), ((slice(0, 1), 1), False),
], ],
) )
def test_local_useless_subtensor_2(self, dims, res): def test_local_useless_subtensor_2(self, idx, res):
x_c = specify_shape(self.x, (2, 3)) x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt) f = function([self.x], exp(x_c).__getitem__(idx), mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert isinstance(prog[0].op, SpecifyShape), dims assert isinstance(prog[0].op, SpecifyShape)
assert prog[1].op == exp, (dims, prog) assert prog[1].op == exp
assert len(prog) == 2, dims assert len(prog) == 2
else: else:
assert any(isinstance(node.op, Subtensor) for node in prog) assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX)
idx_val = idx
exp_res = np.exp(x_val)[idx_val]
res = f(x_val)
assert np.allclose(res, exp_res)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dims, res", "idx_fn, res",
[ [
((slice(0, x.shape[0]),), True), (lambda x: (slice(0, x.shape[0]),), True),
((slice(0, x.shape[1]),), False), (lambda x: (slice(0, x.shape[1]),), False),
( (
( lambda x: (
slice(0, x.shape[0]), slice(0, x.shape[0]),
slice(0, x.shape[1]), slice(0, x.shape[1]),
), ),
True, True,
), ),
( (
( lambda x: (
slice(0, x.shape[0]), slice(0, x.shape[0]),
slice(0, x.shape[0]), slice(0, x.shape[0]),
), ),
False, False,
), ),
( (
( lambda x: (
slice(0, x.shape[1]), slice(0, x.shape[1]),
slice(0, x.shape[0]), slice(0, x.shape[0]),
), ),
False, False,
), ),
( (
( lambda x: (
slice(0, x.shape[1]), slice(0, x.shape[1]),
slice(0, x.shape[1]), slice(0, x.shape[1]),
), ),
False, False,
), ),
((slice(0, x.shape[1]), 2), False), (lambda x: (slice(0, x.shape[1]), 2), False),
( (
( lambda x: (
slice(0, x.shape[1]), slice(0, x.shape[1]),
slice(x.shape[0] - x.shape[0], x.shape[1]), slice(x.shape[0] - x.shape[0], x.shape[1]),
), ),
False, False,
), ),
((slice(0, at.scalar_from_tensor(x.shape[0])),), True), (
lambda x: (
slice(
0,
at.scalar_from_tensor(x.shape[0])
if isinstance(x, Variable)
else x.shape[0],
),
),
True,
),
], ],
) )
def test_local_useless_subtensor_3(self, dims, res): def test_local_useless_subtensor_3(self, idx_fn, res):
f = function([self.x], exp(self.x).__getitem__(dims), mode=mode_opt) idx = idx_fn(self.x)
f = function([self.x], exp(self.x).__getitem__(idx), mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert prog[0].op == exp, dims assert prog[0].op == exp
assert len(prog) == 1, dims assert len(prog) == 1
else: else:
assert any(isinstance(node.op, Subtensor) for node in prog) assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX)
idx_val = idx_fn(x_val)
exp_res = np.exp(x_val)[idx_val]
res = f(x_val)
assert np.allclose(res, exp_res)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dims, res", "idx_fn, res",
[ [
((slice(0, x.shape[0]), slice(0, 3)), False), (lambda x: (slice(0, x.shape[0]), slice(0, 3)), False),
((slice(0, 3), slice(0, x.shape[1])), False), (lambda x: (slice(0, 3), slice(0, x.shape[1])), False),
], ],
) )
def test_local_useless_subtensor_4(self, dims, res): def test_local_useless_subtensor_4(self, idx_fn, res):
# Test mix Variable and Constant # Test mix Variable and Constant
# Currently not supported # Currently not supported
x_c = specify_shape(self.x, (2, 3)) x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt) idx = idx_fn(self.x)
f = function([self.x], exp(x_c).__getitem__(idx), mode=self.mode)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert prog[0].op == exp, dims assert prog[0].op == exp
assert len(prog) == 1, dims assert len(prog) == 1
else: else:
assert any(isinstance(node.op, Subtensor) for node in prog) assert any(isinstance(node.op, Subtensor) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX)
idx_val = idx_fn(x_val)
exp_res = np.exp(x_val)[idx_val]
res = f(x_val)
assert np.allclose(res, exp_res)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dims, res", "idx_fn, res",
[ [
((slice(0, s),), False), (lambda s: (slice(0, s),), False),
], ],
) )
def test_local_useless_subtensor_5(self, dims, res): def test_local_useless_subtensor_5(self, idx_fn, res):
# Test scalar variable # Test scalar variable
f = function([self.x, self.s], exp(self.x).__getitem__(dims), mode=mode_opt) idx = idx_fn(self.s)
f = function([self.x, self.s], exp(self.x).__getitem__(idx), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert prog[0].op == exp, dims assert prog[0].op == exp
assert len(prog) == 1, dims assert len(prog) == 1
else: else:
assert any(isinstance(node.op, Subtensor) for node in prog) assert any(isinstance(node.op, Subtensor) for node in prog)
f([[1, 2, 3], [4, 5, 6]], 1)
f([[1, 2, 3], [4, 5, 6]], 3) x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX)
idx_val = idx_fn(1)
exp_res = np.exp(x_val)[idx_val]
res = f(x_val, 1)
assert np.allclose(res, exp_res)
idx_val = idx_fn(3)
exp_res = np.exp(x_val)[idx_val]
res = f(x_val, 3)
assert np.allclose(res, exp_res)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dims, res", "idx, res",
[ [
([0, 1], True), ([0, 1], True),
([1, 0], False), ([1, 0], False),
...@@ -349,19 +393,24 @@ class TestLocalUselessSubtensor: ...@@ -349,19 +393,24 @@ class TestLocalUselessSubtensor:
(at.arange(1, 2), False), (at.arange(1, 2), False),
], ],
) )
def test_local_useless_subtensor_6(self, dims, res): def test_local_useless_subtensor_6(self, idx, res):
# Test AdvancedSubtensor1 case when all rows are selected by a list/vector # Test AdvancedSubtensor1 case when all rows are selected by a list/vector
# or ARange op # or ARange op
x_c = specify_shape(self.x, (2, 3)) x_c = specify_shape(self.x, (2, 3))
f = function([self.x], exp(x_c).__getitem__(dims), mode=mode_opt) f = function([self.x], exp(x_c).__getitem__(idx), mode=mode_opt)
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
if res: if res:
assert isinstance(prog[0].op, SpecifyShape), dims assert isinstance(prog[0].op, SpecifyShape)
assert prog[1].op == exp, dims assert prog[1].op == exp
assert len(prog) == 2, dims assert len(prog) == 2
else: else:
assert any(isinstance(node.op, AdvancedSubtensor1) for node in prog) assert any(isinstance(node.op, AdvancedSubtensor1) for node in prog)
f([[0, 1, 2], [3, 4, 5]]) # let debugmode test something
x_val = np.array([[0, 1, 2], [3, 4, 5]], dtype=aesara.config.floatX)
idx_val = idx.eval() if isinstance(idx, Variable) else idx
exp_res = np.exp(x_val)[idx_val]
res = f(x_val)
assert np.allclose(res, exp_res)
def test_local_subtensor_remove_broadcastable_index(): def test_local_subtensor_remove_broadcastable_index():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论