Commit f1db1bd6, authored by Ricardo Vieira, committed by Ricardo Vieira

Lift Subtensor over transpose

Parent commit: db7b988f
from collections.abc import Iterable from collections.abc import Iterable, Sequence
import numpy as np import numpy as np
...@@ -17,12 +17,14 @@ from pytensor.tensor.basic import ( ...@@ -17,12 +17,14 @@ from pytensor.tensor.basic import (
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import Dot, ceil_intdiv, dot from pytensor.tensor.math import Dot, ceil_intdiv, dot
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
...@@ -42,6 +44,12 @@ from pytensor.tensor.type import TensorType ...@@ -42,6 +44,12 @@ from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import SliceType
def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
# Inputs can be slice or integer indexes
# Slices keep the dimensions, integers collapse them
return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@register_specialize @register_specialize
...@@ -243,6 +251,55 @@ def local_subtensor_of_expand_dims(fgraph, node): ...@@ -243,6 +251,55 @@ def local_subtensor_of_expand_dims(fgraph, node):
return [out] return [out]
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_transpose(fgraph, node):
    """Lift a Subtensor through a DimShuffle that only transposes.

    transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2))

    Indexing the un-transposed input directly lets later rewrites work on
    a smaller intermediate and can expose further Subtensor lifts.
    Returns ``None`` (no rewrite) unless the indexed input is a pure
    transpose DimShuffle.
    """
    ds, *idx = node.inputs
    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
        return None
    ds_op = ds.owner.op
    # Only handle pure transposes: no dropped or broadcasted ("x") dims.
    if not ds_op.is_transpose:
        return None
    transposition = ds_op.transposition
    [x] = ds.owner.inputs

    # Recover the concrete (slice | scalar) indices from the Subtensor op.
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    # Apply the transposition to the indexes
    ndim = x.type.ndim
    # Trailing dims not explicitly indexed behave as full slices.
    n_implicit_idxs = ndim - len(idx_tuple)
    idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs
    # idx_tuple[j] indexes output dim j == input dim transposition[j], so
    # input dim i receives the index at position transposition.index(i)
    # (i.e. the inverse permutation is applied to the index tuple).
    new_idxs = [idx_tuple[transposition.index(i)] for i in range(ndim)]
    new_x = x[tuple(new_idxs)]

    # Reintroduce any dims dropped by indexing so the original transpose still works
    dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs)
    if dims_dropped_by_new_idx:
        new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx)
    # Apply the transpose
    new_out = ds_op(new_x)
    # Squeeze dims again now that the transpose is done
    if dims_dropped_by_new_idx:
        # Dropped dims are located by their position in the *original*
        # (output-order) index tuple, since new_out is back in that order.
        dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple)
        new_out = squeeze(new_out, axis=dims_dropped_by_original_idx)
    # Cleanup consecutive expand_dims / transpose / squeeze (if any)
    if dims_dropped_by_new_idx:
        # NOTE(review): assumes local_dimshuffle_lift always succeeds here
        # (returns a one-element list) for the expand_dims/transpose/squeeze
        # chain we just built — confirm against its implementation.
        [new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner)
    return [new_out]
@register_infer_shape @register_infer_shape
@register_useless @register_useless
@register_canonicalize @register_canonicalize
......
...@@ -252,7 +252,7 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn): ...@@ -252,7 +252,7 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
out = original_fn(x) out = original_fn(x)
expected_opt_out = expected_fn(x) expected_opt_out = expected_fn(x)
opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"]) opt_out = rewrite_graph(out)
assert equal_computations([opt_out], [expected_opt_out]), debugprint( assert equal_computations([opt_out], [expected_opt_out]), debugprint(
[opt_out, expected_opt_out], print_type=True [opt_out, expected_opt_out], print_type=True
) )
...@@ -262,6 +262,35 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn): ...@@ -262,6 +262,35 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
) )
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        (lambda x: x.transpose(2, 1, 0)[0], lambda x: x[:, :, 0].transpose(1, 0)),
        (lambda x: x.transpose(2, 1, 0)[:, :, 1:], lambda x: x[1:].transpose(2, 1, 0)),
        (
            lambda x: x.transpose(2, 1, 0)[0, :1, 1:],
            lambda x: x[1:, :1, 0].transpose(1, 0),
        ),
        (lambda x: x.transpose(2, 1, 0)[0, :1, 1], lambda x: x[1, :1, 0]),
    ],
)
def test_local_subtensor_of_transpose(original_fn, expected_fn):
    """Subtensor should be lifted through a transposing DimShuffle."""
    rng = np.random.default_rng(232)
    x = tensor("x", shape=(7, 5, 3))
    test_value = rng.normal(size=x.type.shape).astype(x.dtype)

    graph_out = original_fn(x)
    expected_out = expected_fn(x)

    rewritten_out = rewrite_graph(graph_out)
    assert equal_computations([rewritten_out], [expected_out]), debugprint(
        [expected_out, rewritten_out], print_type=True
    )

    # The rewritten graph must still compute the same values.
    np.testing.assert_allclose(
        rewritten_out.eval({x: test_value}, mode=NO_OPTIMIZATION_MODE),
        graph_out.eval({x: test_value}, mode=NO_OPTIMIZATION_MODE),
    )
def test_local_subtensor_of_alloc(): def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong. # DebugMode should detect if something goes wrong.
# test shape combination of odd and event shape. # test shape combination of odd and event shape.
......
Markdown is supported
0%
You are about to add 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment