提交 db7b988f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Lift Subtensor over expand_dims

上级 07706002
......@@ -77,7 +77,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType, integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceConstant, SliceType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -157,19 +157,21 @@ def transform_take(a, indices, axis):
def is_full_slice(x):
    """Determine if `x` is a ``slice(None)`` or a symbolic equivalent.

    Recognized full slices are:

    - a literal ``slice(None)`` (i.e. ``start``, ``stop`` and ``step`` all None)
    - a `Constant` of `SliceType` whose ``data`` equals ``slice(None)``
    - a symbolic ``MakeSlice`` node whose inputs are all of `NoneTypeT`

    Symbolic slices spelled with explicit ``start=0`` or ``step=1`` are *not*
    detected, even though they are semantically full slices.
    """
    if isinstance(x, slice):
        return x == slice(None)
    if isinstance(x, Variable) and isinstance(x.type, SliceType):
        if x.owner is None:
            if isinstance(x, Constant):
                return x.data == slice(None)
            else:
                # Root slice variable: contents unknown, be conservative
                return False
        # Symbolic MakeSlice: full iff start/stop/step are all None.
        # Ignores start = 0, step = 1 cases
        return all(isinstance(i.type, NoneTypeT) for i in x.owner.inputs)
    return False
......
......@@ -11,10 +11,11 @@ from pytensor.tensor.basic import (
MakeVector,
alloc,
as_tensor,
expand_dims,
get_underlying_scalar_constant_value,
register_infer_shape,
)
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, ceil_intdiv, dot
from pytensor.tensor.rewriting.basic import (
......@@ -22,7 +23,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
)
from pytensor.tensor.rewriting.subtensor import register_useless
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
from pytensor.tensor.shape import (
Shape,
SpecifyShape,
......@@ -35,6 +36,7 @@ from pytensor.tensor.subtensor import (
get_canonical_form_slice,
get_constant_idx,
get_idx_list,
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType
......@@ -167,6 +169,80 @@ def local_subtensor_lift(fgraph, node):
return [ret]
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])
def local_subtensor_of_expand_dims(fgraph, node):
    """Lift a Subtensor through a DimShuffle that only expands dims.

    expand_dims(x, axis=0)[0] -> x
    expand_dims(x, axis=0)[:, 0] -> expand_dims(x[0], axis=0)
    expand_dims(x, axis=2)[0] -> expand_dims(x[0], axis=1)

    This goes beyond `local_subtensor_remove_broadcastable_index` which
    simply removes useless subtensors on broadcastable dimensions.
    """
    ds, *idx = node.inputs
    # Only fire on DimShuffles that purely insert broadcastable dimensions
    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
        return None
    ds_op = ds.owner.op
    if not ds_op.is_expand_dims:
        return None
    # Axes that the DimShuffle inserted (all length-1 / broadcastable)
    expanded_axes = ds_op.augment
    [x] = ds.owner.inputs
    # Materialize the full index tuple from the Subtensor's idx_list template
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
    # Keep indexes for the original dimensions, and drop indexes for the expanded dimensions when safe
    new_idxs = []
    for i, idx_item in enumerate(idx_tuple):
        if i in expanded_axes:
            if isinstance(idx_item, slice):
                # Slice could be keeping or dropping this dimension
                if is_full_slice(idx_item):
                    # A None slice always keeps the dimension.
                    # We skip the index, and later introduce the needed expand_dims
                    continue
                else:
                    # Other slices could keep or drop the dimension.
                    # Get out instead of trying to figure out which case it is
                    return None
            else:
                # Integer indexing can only drop the dimension (if it's a valid graph)
                # We can just drop the index and avoid expanding the dimension
                # This is why this rewrite is tagged with "shape_unsafe"
                continue
        else:
            # Keep indexes for non-expanded dimensions
            new_idxs.append(idx_item)
    [old_out] = node.outputs
    # Index the pre-expansion input directly
    out = x[tuple(new_idxs)]
    copy_stack_trace(old_out, out)
    if out.type.broadcastable != old_out.type.broadcastable:
        # Re-introduce needed new dimensions (corresponding to full slices on the original expanded dimensions)
        # If out.type.broadcastable == (False) and old_out.type.broadcastable == (True, False, True)
        # then axis = (0, 2)
        old_bcast = list(old_out.type.broadcastable)
        expanded_bcast = list(out.type.broadcastable)
        axis = []
        i = 0
        while i < len(old_bcast):
            # Insert a broadcastable dim wherever the patterns diverge (or the
            # rewritten output has run out of dims)
            if i == len(expanded_bcast) or expanded_bcast[i] != old_bcast[i]:
                expanded_bcast.insert(i, True)
                axis.append(i)
            i += 1
        out = expand_dims(out, axis=axis)
        copy_stack_trace(old_out, out)
    return [out]
@register_infer_shape
@register_useless
@register_canonicalize
......
......@@ -19,7 +19,9 @@ from pytensor.graph import (
Type,
rewrite_graph,
)
from pytensor.graph.basic import equal_computations
from pytensor.graph.rewriting.basic import check_stack_trace
from pytensor.printing import debugprint
from pytensor.tensor import (
add,
exp,
......@@ -37,7 +39,7 @@ from pytensor.tensor import (
tensor3,
vector,
)
from pytensor.tensor.basic import MakeVector, make_vector
from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_make_vector,
......@@ -53,6 +55,9 @@ if mode_opt == "FAST_COMPILE":
mode_opt = get_mode(mode_opt)
NO_OPTIMIZATION_MODE = Mode(linker="py", optimizer=None)
class TestLocalSubtensorLift:
def test_basic(self):
# basic test that the Op works
......@@ -134,8 +139,8 @@ class TestLocalSubtensorLift:
assert check_stack_trace(f, ops_to_check="all")
prog = f.maker.fgraph.toposort()
assert isinstance(prog[0].op, DimShuffle)
assert isinstance(prog[1].op, Subtensor)
assert isinstance(prog[0].op, Subtensor)
assert isinstance(prog[1].op, DimShuffle)
assert prog[2].op == exp
assert len(prog) == 3
f([4, 5]) # let debugmode test something
......@@ -198,6 +203,65 @@ class TestLocalSubtensorLift:
f([1, 2, 3], 4) # let debugmode test something
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        # Integer indexing
        (lambda x: expand_dims(x, axis=0)[0], lambda x: x),
        (
            lambda x: expand_dims(x, axis=1)[0],
            lambda x: expand_dims(x[0], axis=0),
        ),
        (
            lambda x: expand_dims(x, axis=(1, 3))[0],
            lambda x: expand_dims(x[0], axis=(0, 2)),
        ),
        # Slice indexing
        (
            lambda x: expand_dims(x, axis=1)[1:],
            lambda x: expand_dims(x[1:], axis=1),
        ),
        (
            lambda x: expand_dims(x, axis=(1, 3))[1:],
            lambda x: expand_dims(x[1:], axis=(1, 3)),
        ),
        # Not supported, slice indexing on expanded dimension
        (
            lambda x: expand_dims(x, axis=0)[1:],
            lambda x: expand_dims(x, axis=0)[1:],
        ),
        # Mixed indexing
        (
            lambda x: expand_dims(x, axis=1)[0, :, 1:],
            lambda x: expand_dims(x[0, 1:], axis=0),
        ),
        (
            lambda x: expand_dims(x, axis=1)[1:, :, 0],
            lambda x: expand_dims(x[1:, 0], axis=1),
        ),
        (
            lambda x: expand_dims(x, axis=(1, 2))[1:, :, 0],
            lambda x: expand_dims(x[1:], axis=1),
        ),
    ],
)
def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
    # Graph under test and the hand-derived graph the rewrite should produce
    x = tensor("x", shape=(5, 3))
    out = original_fn(x)
    expected_opt_out = expected_fn(x)

    # Symbolic check: rewritten graph must match the expected one exactly.
    # uint index canonicalization is excluded so expected graphs can use
    # plain integer constants.
    opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"])
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [opt_out, expected_opt_out], print_type=True
    )

    # Numerical check: rewritten graph agrees with the original on random data
    rng = np.random.default_rng(232)
    x_test = rng.normal(size=x.type.shape).astype(x.dtype)
    np.testing.assert_allclose(
        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
    )
def test_local_subtensor_of_alloc():
# DebugMode should detect if something goes wrong.
# test shape combination of odd and event shape.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论