提交 db7b988f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Lift Subtensor over expand_dims

上级 07706002
...@@ -77,7 +77,7 @@ from pytensor.tensor.subtensor import ( ...@@ -77,7 +77,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType, integer_dtypes from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -157,19 +157,21 @@ def transform_take(a, indices, axis): ...@@ -157,19 +157,21 @@ def transform_take(a, indices, axis):
def is_full_slice(x):
    """Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
    # Plain Python slice: full iff it equals slice(None).
    if isinstance(x, slice):
        return x == slice(None)

    if isinstance(x, Variable) and isinstance(x.type, SliceType):
        if x.owner is not None:
            # Symbolic MakeSlice: full iff every component is None.
            # (Ignores explicit start=0 / step=1 cases.)
            return all(isinstance(inp.type, NoneTypeT) for inp in x.owner.inputs)
        # Root slice variable: only a constant slice(None) qualifies.
        return isinstance(x, Constant) and x.data == slice(None)

    return False
......
...@@ -11,10 +11,11 @@ from pytensor.tensor.basic import ( ...@@ -11,10 +11,11 @@ from pytensor.tensor.basic import (
MakeVector, MakeVector,
alloc, alloc,
as_tensor, as_tensor,
expand_dims,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
) )
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, ceil_intdiv, dot from pytensor.tensor.math import Dot, ceil_intdiv, dot
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
...@@ -22,7 +23,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -22,7 +23,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
register_stabilize, register_stabilize,
) )
from pytensor.tensor.rewriting.subtensor import register_useless from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
SpecifyShape, SpecifyShape,
...@@ -35,6 +36,7 @@ from pytensor.tensor.subtensor import ( ...@@ -35,6 +36,7 @@ from pytensor.tensor.subtensor import (
get_canonical_form_slice, get_canonical_form_slice,
get_constant_idx, get_constant_idx,
get_idx_list, get_idx_list,
indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import SliceType
...@@ -167,6 +169,80 @@ def local_subtensor_lift(fgraph, node): ...@@ -167,6 +169,80 @@ def local_subtensor_lift(fgraph, node):
return [ret] return [ret]
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])
def local_subtensor_of_expand_dims(fgraph, node):
    """Lift a Subtensor through a DimShuffle that only expands dims.

    expand_dims(x, axis=0)[0] -> x
    expand_dims(x, axis=0)[:, 0] -> expand_dims(x[0], axis=0)
    expand_dims(x, axis=2)[0] -> expand_dims(x[0], axis=1)

    This goes beyond `local_subtensor_remove_broadcastable_index` which
    simply removes useless subtensors on broadcastable dimensions.
    """
    ds, *idx = node.inputs

    # Only apply when the indexed variable is produced by a DimShuffle
    # that purely inserts new broadcastable dimensions (expand_dims).
    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
        return None

    ds_op = ds.owner.op

    if not ds_op.is_expand_dims:
        return None

    # Positions of the inserted (length-1) dimensions in the DimShuffle output.
    expanded_axes = ds_op.augment
    [x] = ds.owner.inputs

    # Reconstruct the concrete index tuple (slices/scalars) from the
    # Subtensor's symbolic inputs and its static idx_list.
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    # Keep indexes for the original dimensions, and drop indexes for the expanded dimensions when safe
    new_idxs = []
    for i, idx_item in enumerate(idx_tuple):
        if i in expanded_axes:
            if isinstance(idx_item, slice):
                # Slice could be keeping or dropping this dimension
                if is_full_slice(idx_item):
                    # A None slice, always keeps the dimension.
                    # We skip the index, and later introduce the needed expand_dim
                    continue
                else:
                    # Other slices could keep or drop the dimension.
                    # Get out instead of trying to figure out which case it is
                    return None
            else:
                # Integer indexing can only drop the dimension (if it's a valid graph)
                # We can just drop the index and avoid expanding the dimension
                # This is why this rewrite is tagged with "shape_unsafe"
                continue
        else:
            # Keep indexes for non-expanded dimensions
            new_idxs.append(idx_item)

    [old_out] = node.outputs
    # Index the pre-expand_dims variable directly with the surviving indices.
    out = x[tuple(new_idxs)]
    copy_stack_trace(old_out, out)

    if out.type.broadcastable != old_out.type.broadcastable:
        # Re-introduce needed new dimensions (corresponding to full slices on the original expanded dimensions)
        # If out.type.broadcastable == (False) and old_out.type.broadcastable == (True, False, True)
        # then axis = (0, 2)
        old_bcast = list(old_out.type.broadcastable)
        expanded_bcast = list(out.type.broadcastable)
        axis = []
        i = 0
        # Walk both broadcastable patterns in lockstep, inserting a
        # broadcastable dim wherever the rewritten output is missing one.
        while i < len(old_bcast):
            if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]:
                expanded_bcast.insert(i, True)
                axis.append(i)
            i += 1
        out = expand_dims(out, axis=axis)
        copy_stack_trace(old_out, out)

    return [out]
@register_infer_shape @register_infer_shape
@register_useless @register_useless
@register_canonicalize @register_canonicalize
......
...@@ -19,7 +19,9 @@ from pytensor.graph import ( ...@@ -19,7 +19,9 @@ from pytensor.graph import (
Type, Type,
rewrite_graph, rewrite_graph,
) )
from pytensor.graph.basic import equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.printing import debugprint
from pytensor.tensor import ( from pytensor.tensor import (
add, add,
exp, exp,
...@@ -37,7 +39,7 @@ from pytensor.tensor import ( ...@@ -37,7 +39,7 @@ from pytensor.tensor import (
tensor3, tensor3,
vector, vector,
) )
from pytensor.tensor.basic import MakeVector, make_vector from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.subtensor_lift import ( from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_make_vector, local_subtensor_make_vector,
...@@ -53,6 +55,9 @@ if mode_opt == "FAST_COMPILE": ...@@ -53,6 +55,9 @@ if mode_opt == "FAST_COMPILE":
mode_opt = get_mode(mode_opt) mode_opt = get_mode(mode_opt)
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
class TestLocalSubtensorLift: class TestLocalSubtensorLift:
def test_basic(self): def test_basic(self):
# basic test that the Op works # basic test that the Op works
...@@ -134,8 +139,8 @@ class TestLocalSubtensorLift: ...@@ -134,8 +139,8 @@ class TestLocalSubtensorLift:
assert check_stack_trace(f, ops_to_check="all") assert check_stack_trace(f, ops_to_check="all")
prog = f.maker.fgraph.toposort() prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, DimShuffle) assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, Subtensor) assert isinstance(prog[1].op, DimShuffle)
assert prog[2].op == exp assert prog[2].op == exp
assert len(prog) == 3 assert len(prog) == 3
f([4, 5]) # let debugmode test something f([4, 5]) # let debugmode test something
...@@ -198,6 +203,65 @@ class TestLocalSubtensorLift: ...@@ -198,6 +203,65 @@ class TestLocalSubtensorLift:
f([1, 2, 3], 4) # let debugmode test something f([1, 2, 3], 4) # let debugmode test something
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        # Integer indexing
        (lambda x: expand_dims(x, axis=0)[0], lambda x: x),
        (
            lambda x: expand_dims(x, axis=1)[0],
            lambda x: expand_dims(x[0], axis=0),
        ),
        (
            lambda x: expand_dims(x, axis=(1, 3))[0],
            lambda x: expand_dims(x[0], axis=(0, 2)),
        ),
        # Slice indexing
        (
            lambda x: expand_dims(x, axis=1)[1:],
            lambda x: expand_dims(x[1:], axis=1),
        ),
        (
            lambda x: expand_dims(x, axis=(1, 3))[1:],
            lambda x: expand_dims(x[1:], axis=(1, 3)),
        ),
        # Not supported, slice indexing on expanded dimension
        (
            lambda x: expand_dims(x, axis=0)[1:],
            lambda x: expand_dims(x, axis=0)[1:],
        ),
        # Mixed indexing
        (
            lambda x: expand_dims(x, axis=1)[0, :, 1:],
            lambda x: expand_dims(x[0, 1:], axis=0),
        ),
        (
            lambda x: expand_dims(x, axis=1)[1:, :, 0],
            lambda x: expand_dims(x[1:, 0], axis=1),
        ),
        (
            lambda x: expand_dims(x, axis=(1, 2))[1:, :, 0],
            lambda x: expand_dims(x[1:], axis=1),
        ),
    ],
)
def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
    """Rewriting `original_fn`'s graph should yield `expected_fn`'s graph."""
    rng = np.random.default_rng(232)
    x = tensor("x", shape=(5, 3))
    test_value = rng.normal(size=x.type.shape).astype(x.dtype)

    graph = original_fn(x)
    expected_graph = expected_fn(x)

    rewritten = rewrite_graph(graph, exclude=["local_uint_constant_indices"])
    # On mismatch, debugprint both graphs to ease debugging
    assert equal_computations([rewritten], [expected_graph]), debugprint(
        [rewritten, expected_graph], print_type=True
    )

    # The rewritten graph must compute the same values as the original
    np.testing.assert_allclose(
        rewritten.eval({x: test_value}, mode=NO_OPTIMIZATION_MODE),
        graph.eval({x: test_value}, mode=NO_OPTIMIZATION_MODE),
    )
def test_local_subtensor_of_alloc(): def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong. # DebugMode should detect if something goes wrong.
# test shape combination of odd and event shape. # test shape combination of odd and event shape.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论