Commit e934ac7c authored by ricardoV94, committed by Ricardo Vieira

Fix bug in `local_useless_slice` rewrite

Canonical slice start and stop values depend on the sign of the step. The rewrite wrongly assumed the defaults were always `0` and `len(dim)` (which is only true for a positive step; for a negative step they are `-1` and `-len(dim) - 1`).
Parent commit: f0244adf
...@@ -342,14 +342,18 @@ def local_subtensor_of_dot(fgraph, node): ...@@ -342,14 +342,18 @@ def local_subtensor_of_dot(fgraph, node):
@node_rewriter([Subtensor]) @node_rewriter([Subtensor])
def local_useless_slice(fgraph, node): def local_useless_slice(fgraph, node):
""" """
Remove Subtensor of the form: Remove useless slice(None) of the form:
1. X[0, :] -> X[0] 1. X[0, :] -> X[0]
2. X[:] -> X 2. X[:] -> X
Also, rewrite Subtensor of the form: Also, canonicalize slices of the form:
X[0:7:1] -> X[None:None:None] X[0:7:1] -> X[None:None:None]
where X is a vector of length 7 where X is a vector of length 7
And:
X[-1:-8:-1] -> X[::-1]
where x is a vector of length 7
""" """
idxs = get_idx_list(node.inputs, node.op.idx_list) idxs = get_idx_list(node.inputs, node.op.idx_list)
x = node.inputs[0] x = node.inputs[0]
...@@ -368,32 +372,40 @@ def local_useless_slice(fgraph, node): ...@@ -368,32 +372,40 @@ def local_useless_slice(fgraph, node):
if s == slice(None): if s == slice(None):
continue continue
step = s.step
if step is None:
positive_step = True
elif isinstance(step, Constant):
step_value = step.data
positive_step = step.data > 0
if step_value == 1:
change_flag = True
step = None
else:
# We can only canonicalize start and stop if we know the sign of step
last_useful_idx = dim
continue
start = s.start start = s.start
stop = s.stop stop = s.stop
step = s.step
if ( if start is not None and extract_constant(
start is not None start, only_process_constants=True
and extract_constant(start, only_process_constants=True) == 0 ) == (0 if positive_step else -1):
):
change_flag = True change_flag = True
start = None start = None
if ( if (
stop is not None stop is not None
and x.type.shape[dim] is not None and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim] and extract_constant(stop, only_process_constants=True)
== (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
): ):
change_flag = True change_flag = True
stop = None stop = None
if ( if start is not None or stop is not None or step is not None:
step is not None
and extract_constant(step, only_process_constants=True) == 1
):
change_flag = True
step = None
if not (start is None and stop is None and step is None):
last_useful_idx = dim last_useful_idx = dim
new_idxs[dim] = slice(start, stop, step) new_idxs[dim] = slice(start, stop, step)
...@@ -402,7 +414,6 @@ def local_useless_slice(fgraph, node): ...@@ -402,7 +414,6 @@ def local_useless_slice(fgraph, node):
out = x[tuple(new_idxs[: last_useful_idx + 1])] out = x[tuple(new_idxs[: last_useful_idx + 1])]
# Copy over previous output stacktrace # Copy over previous output stacktrace
copy_stack_trace(node.outputs, out) copy_stack_trace(node.outputs, out)
return [out] return [out]
......
...@@ -2404,42 +2404,74 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc): ...@@ -2404,42 +2404,74 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
np.testing.assert_allclose(fn(test_x, test_y), expected_out) np.testing.assert_allclose(fn(test_x, test_y), expected_out)
def test_slice_canonicalize(): class TestUselessSlice:
rng = np.random.default_rng(43) def test_positive_step(self):
x = tensor(shape=(3, 5, None, 9)) # When steps are positive, default start and end are `0` and `len(dim)`
test_x = rng.normal(size=(3, 5, 8, 9)) x = tensor(shape=(3, 5, None, 9), dtype="float64")
# Test case 1 test_x = np.random.normal(size=(3, 5, 8, 9))
y = x[0:None, 0:5, 0:7, 0:9:1]
f = pytensor.function([x], y, allow_input_downcast=True) y = x[0:3:1, 1:5:2, 0:7:1, 0:9:1]
f = pytensor.function([x], y)
# Get the DeepCopy input and assert that the Op is a DeepCopy
test_y = f.maker.fgraph.outputs[0].owner.inputs[0] # Get the DeepCopy input and assert that the Op is a DeepCopy
assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp) deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)
expected_y = x[None:None:None, None:None:None, None:7:None]
rewritten_y = deep_copy_node.inputs[0]
assert equal_computations([test_y], [expected_y]) expected_y = x[None:None:None, 1:None:2, None:7:None]
assert equal_computations([rewritten_y], [expected_y])
np.testing.assert_allclose(
f(test_x), np.testing.assert_allclose(
test_x[ f(test_x),
0:None, 0:5, 0:7, 0:9:1 # Use the unoptimized slice to make sure our rewrite logic is correct
], # Use the unoptimized slice to make sure our rewrite logic is correct test_x[0:3:1, 1:5:2, 0:7:1, 0:9:1],
) )
# Test case 2 def test_negative_step(self):
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1] # When steps are negative, default start and end are `-1` and `-len(dim) - 1`
f1 = pytensor.function([x], y1, allow_input_downcast=True) x = tensor(shape=(3, 5, None, 9), dtype="float64")
test_x = np.random.normal(size=(3, 5, 8, 9))
# Get the DeepCopy input and assert that the Op is a DeepCopy y = x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None]
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0] f = pytensor.function([x], y)
assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1] # Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)
assert equal_computations([test_y1], [expected_y1]) rewritten_y = deep_copy_node.inputs[0]
expected_y = x[None:None:-1, 0:5:-2, None:-9:-1]
assert equal_computations([rewritten_y], [expected_y])
np.testing.assert_allclose( np.testing.assert_allclose(
f1(test_x), f(test_x),
test_x[0:-1, 0:5, 0:7, 0:-1:-1], test_x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None],
) )
def test_unknown_step(self):
# If step isn't known, we can't canonicalize start and stop points
step = pt.scalar("step", dtype=int)
x = tensor(shape=(3, 5, None), dtype="float64")
test_x = np.random.normal(size=(3, 5, 7))
y = x[0:3:step, -1:-6:-step, ::]
# Need this rewrite when `FAST_COMPILE` otherwise step = -1 * step instead of neg(step)
mode = get_default_mode().including("local_mul_specialize")
f = pytensor.function([x, step], y, mode=mode)
# Get the DeepCopy input and assert that the Op is a DeepCopy
deep_copy_node = f.maker.fgraph.outputs[0].owner
assert isinstance(deep_copy_node.op, DeepCopyOp)
rewritten_y = deep_copy_node.inputs[0]
expected_y = x[0:3:step, -1:-6:-step]
assert equal_computations([rewritten_y], [expected_y])
np.testing.assert_allclose(
f(test_x, 1),
test_x[0:3:1, -1:-6:-1, ::],
)
np.testing.assert_allclose(
f(test_x, -2),
test_x[0:3:-2, -1:-6:2, ::],
)
Markdown format supported
0%
You are adding 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment