提交 e934ac7c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix bug in `local_useless_slice` rewrite

Canonical slice start and stop values depend on the sign of the step. The rewrite wrongly assumed they were always 0:len(dim)
上级 f0244adf
......@@ -342,14 +342,18 @@ def local_subtensor_of_dot(fgraph, node):
@node_rewriter([Subtensor])
def local_useless_slice(fgraph, node):
"""
Remove Subtensor of the form:
Remove useless slice(None) of the form:
1. X[0, :] -> X[0]
2. X[:] -> X
Also, rewrite Subtensor of the form:
Also, canonicalize slices of the form:
X[0:7:1] -> X[None:None:None]
where X is a vector of length 7
And:
X[-1:-8:-1] -> X[::-1]
where x is a vector of length 7
"""
idxs = get_idx_list(node.inputs, node.op.idx_list)
x = node.inputs[0]
......@@ -368,32 +372,40 @@ def local_useless_slice(fgraph, node):
if s == slice(None):
continue
step = s.step
if step is None:
positive_step = True
elif isinstance(step, Constant):
step_value = step.data
positive_step = step.data > 0
if step_value == 1:
change_flag = True
step = None
else:
# We can only canonicalize start and stop if we know the sign of step
last_useful_idx = dim
continue
start = s.start
stop = s.stop
step = s.step
if (
start is not None
and extract_constant(start, only_process_constants=True) == 0
):
if start is not None and extract_constant(
start, only_process_constants=True
) == (0 if positive_step else -1):
change_flag = True
start = None
if (
stop is not None
and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
and extract_constant(stop, only_process_constants=True)
== (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
):
change_flag = True
stop = None
if (
step is not None
and extract_constant(step, only_process_constants=True) == 1
):
change_flag = True
step = None
if not (start is None and stop is None and step is None):
if start is not None or stop is not None or step is not None:
last_useful_idx = dim
new_idxs[dim] = slice(start, stop, step)
......@@ -402,7 +414,6 @@ def local_useless_slice(fgraph, node):
out = x[tuple(new_idxs[: last_useful_idx + 1])]
# Copy over previous output stacktrace
copy_stack_trace(node.outputs, out)
return [out]
......
......@@ -2404,42 +2404,74 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
def test_slice_canonicalize():
rng = np.random.default_rng(43)
x = tensor(shape=(3, 5, None, 9))
test_x = rng.normal(size=(3, 5, 8, 9))
# Test case 1
y = x[0:None, 0:5, 0:7, 0:9:1]
f = pytensor.function([x], y, allow_input_downcast=True)
# Get the DeepCopy input and assert that the Op is a DeepCopy
test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
expected_y = x[None:None:None, None:None:None, None:7:None]
assert equal_computations([test_y], [expected_y])
np.testing.assert_allclose(
f(test_x),
test_x[
0:None, 0:5, 0:7, 0:9:1
], # Use the unoptimized slice to make sure our rewrite logic is correct
)
class TestUselessSlice:
def test_positive_step(self):
# When steps are positive, default start and end are `0` and `len(dim)`
x = tensor(shape=(3, 5, None, 9), dtype="float64")
test_x = np.random.normal(size=(3, 5, 8, 9))
y = x[0:3:1, 1:5:2, 0:7:1, 0:9:1]
f = pytensor.function([x], y)
# Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)
rewritten_y = deep_copy_node.inputs[0]
expected_y = x[None:None:None, 1:None:2, None:7:None]
assert equal_computations([rewritten_y], [expected_y])
np.testing.assert_allclose(
f(test_x),
# Use the unoptimized slice to make sure our rewrite logic is correct
test_x[0:3:1, 1:5:2, 0:7:1, 0:9:1],
)
# Test case 2
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
f1 = pytensor.function([x], y1, allow_input_downcast=True)
def test_negative_step(self):
# When steps are negative, default start and end are `-1` and `-len(dim) - 1`
x = tensor(shape=(3, 5, None, 9), dtype="float64")
test_x = np.random.normal(size=(3, 5, 8, 9))
# Get the DeepCopy input and assert that the Op is a DeepCopy
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
y = x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None]
f = pytensor.function([x], y)
expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
# Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)
assert equal_computations([test_y1], [expected_y1])
rewritten_y = deep_copy_node.inputs[0]
expected_y = x[None:None:-1, 0:5:-2, None:-9:-1]
assert equal_computations([rewritten_y], [expected_y])
np.testing.assert_allclose(
f1(test_x),
test_x[0:-1, 0:5, 0:7, 0:-1:-1],
)
np.testing.assert_allclose(
f(test_x),
test_x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None],
)
def test_unknown_step(self):
# If step isn't known, we can't canonicalize start and stop points
step = pt.scalar("step", dtype=int)
x = tensor(shape=(3, 5, None), dtype="float64")
test_x = np.random.normal(size=(3, 5, 7))
y = x[0:3:step, -1:-6:-step, ::]
# Need this rewrite when `FAST_COMPILE` otherwise step = -1 * step instead of neg(step)
mode = get_default_mode().including("local_mul_specialize")
f = pytensor.function([x, step], y, mode=mode)
# Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)
rewritten_y = deep_copy_node.inputs[0]
expected_y = x[0:3:step, -1:-6:-step]
assert equal_computations([rewritten_y], [expected_y])
np.testing.assert_allclose(
f(test_x, 1),
test_x[0:3:1, -1:-6:-1, ::],
)
np.testing.assert_allclose(
f(test_x, -2),
test_x[0:3:-2, -1:-6:2, ::],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论