Commit 58f1fd2b, authored by Ricardo Vieira, committed by Ricardo Vieira

Lift Subtensor over Join

Parent 938bd8ef
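For context, a minimal usage sketch of the rewrite being added, adapted from the new test cases further down (variable names are illustrative):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))

# Indexing along a non-join axis: the index is pushed onto each joined
# component and the join axis is renumbered for the dropped leading dim.
out = pt.concatenate([x, y], axis=1)[1]
opt = rewrite_graph(out)  # equivalent to pt.concatenate([x[1], y[1]], axis=0)

# Indexing on the join axis itself is not lifted; that graph stays as-is.
out_on_axis = pt.concatenate([x, y], axis=0)[0]
opt_on_axis = rewrite_graph(out_on_axis)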
from collections.abc import Iterable, Sequence
from typing import cast

import numpy as np

from pytensor import Variable
-from pytensor.graph import Constant, node_rewriter
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
-from pytensor.graph.rewriting.basic import copy_stack_trace
+from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
    Alloc,
    Join,
    MakeVector,
    alloc,
    as_tensor,
    expand_dims,
    get_underlying_scalar_constant_value,
    join,
    register_infer_shape,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise

@@ -44,6 +47,7 @@ from pytensor.tensor.subtensor import (
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import SliceType
from pytensor.tensor.variable import TensorVariable


def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:

@@ -66,6 +70,41 @@ def _axis_is_indexed_by_basic_index(
    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
def _lift_subtensor_non_axis(
    local_subtensor_lift_rewrite: NodeRewriter,
    fgraph: FunctionGraph,
    variable: TensorVariable,
    idx_tuple: tuple[int | slice],
    axis: int,
    old_subtensor_variable: TensorVariable,
) -> None | list[TensorVariable]:
    # Apply generic subtensor lift rewrite along "non-axis" dimensions
    real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
    if len(real_indices) > 1 and variable.type.ndim > 1:
        # Split the subtensor
        idx_to_keep = idx_tuple[axis]
        idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

        # Lift the non-axis indexes by calling the rewrite itself
        indexed_variable = variable[idxs_to_lift]
        [indexed_variable] = cast(
            list[TensorVariable],
            local_subtensor_lift_rewrite.transform(fgraph, indexed_variable.owner),
        )
        copy_stack_trace([old_subtensor_variable, indexed_variable], indexed_variable)

        # Then reintroduce the axis index
        ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
        new_axis = axis - ndim_reduced_left
        idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
        new_out = indexed_variable[idxs_to_keep]
        copy_stack_trace(old_subtensor_variable, new_out)
        return [new_out]

    else:
        return None
@register_canonicalize
@register_stabilize
@register_specialize

@@ -297,29 +336,14 @@ def local_subtensor_of_softmax(fgraph, node):
    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # If there are more dimensions being indexed, we can split them
        # And lift the non-axis indexes while keeping the axis index
-        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
-        if len(real_indices) > 1 and sm.type.ndim > 1:
-            # Split the subtensor
-            idx_to_keep = idx_tuple[axis]
-            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
-
-            # Lift the non-axis indexes by calling the rewrite itself
-            opt_sm = sm[idxs_to_lift]
-            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
-            copy_stack_trace([old_out, sm], opt_sm)
-
-            # Then reintroduce the axis index
-            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
-                idx_tuple, axis
-            )
-            new_axis = axis - ndim_reduced_left
-            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
-            new_out = opt_sm[idxs_to_keep]
-            copy_stack_trace(old_out, new_out)
-            return [new_out]
-        else:
-            return None
+        return _lift_subtensor_non_axis(
+            local_subtensor_lift_rewrite=local_subtensor_of_softmax,
+            fgraph=fgraph,
+            variable=sm,
+            idx_tuple=idx_tuple,
+            axis=axis,
+            old_subtensor_variable=old_out,
+        )

    # Index input to softmax
    x_sub = x[idx_tuple]

@@ -646,6 +670,52 @@ def local_subtensor_make_vector(fgraph, node):
        pass
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_join(fgraph, node):
    """Lift a Subtensor through a Join.

    join(axis=1, x, y)[0] -> join(axis=0, x[0], y[0])
    join(axis=1, x, y)[:, 0, -1] -> join(axis=1, x[:, :, -1], y[:, :, -1])[:, 0]
    """
    join_var, *idx = node.inputs

    if not (join_var.owner and isinstance(join_var.owner.op, Join)):
        return None

    if len(fgraph.clients[join_var]) > 1:
        # Join involves a full_copy, so we don't want to do it twice
        return None

    join_axis, *join_components = join_var.owner.inputs

    # Rewrite only works when the join axis is a constant along a non-indexed dimension
    if not isinstance(join_axis, Constant):
        return None

    [old_out] = node.outputs
    axis = normalize_axis_index(join_axis.data, join_components[0].type.ndim)
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        return _lift_subtensor_non_axis(
            local_subtensor_lift_rewrite=local_subtensor_of_join,
            fgraph=fgraph,
            variable=join_var,
            idx_tuple=idx_tuple,
            axis=axis,
            old_subtensor_variable=old_out,
        )

    # Lift index to the Join components
    indexed_components = [component[idx_tuple] for component in join_components]
    new_axis = axis - _ndim_dropped_left_of_axis_by_basic_index(idx_tuple, axis)
    out = join(new_axis, *indexed_components)

    return [out]
@register_specialize
@register_canonicalize
@node_rewriter([Subtensor])
...
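The non-axis splitting performed by _lift_subtensor_non_axis above shows up with a mixed index, where one entry hits the join axis and another does not. A small sketch adapted from the docstring and the new test below (names are illustrative):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.matrix("x", shape=(5, 3))
y = pt.matrix("y", shape=(5, 3))

# axis=1 is the join axis; the leading 0 indexes a non-join axis and is
# lifted into each component, while the slice on the join axis stays outside.
out = pt.concatenate([x, y], axis=1)[0, 1:]
opt = rewrite_graph(out)  # equivalent to pt.concatenate([x[0], y[0]], axis=0)[1:]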
@@ -36,7 +36,7 @@ from pytensor.tensor import (
    tensor3,
    vector,
)
-from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
+from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.subtensor_lift import (

@@ -600,6 +600,60 @@ class TestLocalSubtensorMakeVector:
        assert local_subtensor_make_vector.transform(fgraph, node) == [v]
shared_axis = shared(1, "axis")


@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        (
            lambda x, y: concatenate([x, y], axis=1)[1],
            lambda x, y: concatenate([x[1], y[1]], axis=0),
        ),
        (
            lambda x, y: concatenate([x, y], axis=-1)[1:],
            lambda x, y: concatenate([x[1:], y[1:]], axis=1),
        ),
        # Indexing on both axis of concatenation and somewhere else:
        (
            lambda x, y: concatenate([x, y], axis=1)[0, 1:],
            lambda x, y: concatenate([x[0], y[0]], axis=0)[1:],
        ),
        # Not supported, indexing on axis of concatenation
        (
            lambda x, y: concatenate([x, y], axis=0)[0],
            lambda x, y: concatenate([x, y], axis=0)[0],
        ),
        (
            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
            lambda x, y: concatenate([x, y], axis=1)[:, 1:],
        ),
        # Not supported, axis of concatenation is dynamically determined
        (
            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
            lambda x, y: concatenate([x, y], axis=shared_axis)[1],
        ),
    ],
)
def test_local_subtensor_of_join(original_fn, expected_fn):
    rng = np.random.default_rng(257)
    x = pt.matrix("x", shape=(5, 3))
    y = pt.matrix("y", shape=(5, 3))
    x_test = rng.normal(size=x.type.shape).astype(x.dtype)
    y_test = rng.normal(size=y.type.shape).astype(y.dtype)

    out = original_fn(x, y)
    expected_opt_out = expected_fn(x, y)
    opt_out = rewrite_graph(out)
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )
    np.testing.assert_allclose(
        opt_out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test, y: y_test}, mode=NO_OPTIMIZATION_MODE),
    )
def test_local_subtensor_shape_constant():
    x = tensor(dtype=np.float64, shape=(1, None)).shape[0]
    (res,) = local_subtensor_shape_constant.transform(None, x.owner)
...