Commit 938bd8ef authored by Ricardo Vieira, committed by Ricardo Vieira

Lift Subtensor over Softmax

Parent 43cad30f
......@@ -5,7 +5,7 @@ import numpy as np
from pytensor import Variable
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
Alloc,
......@@ -32,6 +32,7 @@ from pytensor.tensor.shape import (
SpecifyShape,
specify_shape,
)
from pytensor.tensor.special import Softmax, softmax
from pytensor.tensor.subtensor import (
AdvancedSubtensor1,
Subtensor,
......@@ -51,6 +52,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
def _ndim_dropped_left_of_axis_by_basic_index(
idxs: Sequence[slice | int], axis: int
) -> int:
return len(_dims_dropped_by_basic_index(idxs[:axis]))
def _axis_is_indexed_by_basic_index(
    idxs: Sequence[slice | int], axis: int | Sequence[int]
) -> bool:
    """Whether any of the given axes is touched by a non-trivial basic index.

    An axis counts as indexed when an entry exists for it in `idxs` and that
    entry is not a full ``:`` slice. Axes beyond ``len(idxs)`` are implicitly
    full slices and therefore not indexed.
    """
    axes = (axis,) if isinstance(axis, int) else axis
    n_idxs = len(idxs)
    for ax in axes:
        if ax < n_idxs and not is_full_slice(idxs[ax]):
            return True
    return False
@register_canonicalize
@register_stabilize
@register_specialize
......@@ -241,6 +256,84 @@ def local_subtensor_of_reduce(fgraph, node):
return [out]
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_softmax(fgraph, node):
    """Lift a Subtensor through a Softmax.

    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)

    If part of the indexing acts on the axis of reduction, we split it
    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[0]

    Returns a one-element list with the replacement output, or None when
    the rewrite does not apply.
    """
    sm, *idx = node.inputs
    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
        return None

    # If the softmax output has other clients, the full softmax must be
    # computed anyway, so lifting the index would duplicate work.
    if len(fgraph.clients[sm]) > 1:
        return None

    [x] = sm.owner.inputs
    axis = sm.owner.op.axis
    if axis is None:
        # axis=None softmaxes over all elements; only equivalent to a single
        # axis when the input is 1d.
        if x.type.ndim == 1:
            axis = 0
        else:
            # All dimensions are mixed, we can't lift the subtensor
            return None
    else:
        # Softmax currently only allows None or a single integer axis
        # Unlike CAReduce it does not normalize negative indices
        axis = normalize_axis_index(axis, sm.ndim)

    [old_out] = node.outputs
    # Recover the concrete index tuple (ints/slices) from the Subtensor op.
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
        # If there are more dimensions being indexed, we can split them
        # And lift the non-axis indexes while keeping the axis index
        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
        if len(real_indices) > 1 and sm.type.ndim > 1:
            # Split the subtensor: keep the index on the softmax axis,
            # replace it with a full slice in the part to be lifted.
            idx_to_keep = idx_tuple[axis]
            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])

            # Lift the non-axis indexes by calling the rewrite itself
            opt_sm = sm[idxs_to_lift]
            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
            copy_stack_trace([old_out, sm], opt_sm)

            # Then reintroduce the axis index, shifted left by however many
            # dimensions the lifted integer indices dropped.
            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
                idx_tuple, axis
            )
            new_axis = axis - ndim_reduced_left
            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
            new_out = opt_sm[idxs_to_keep]
            copy_stack_trace(old_out, new_out)
            return [new_out]

        else:
            # The only real index is on the softmax axis itself: nothing can
            # be lifted without changing the result.
            return None

    # Index input to softmax
    x_sub = x[idx_tuple]

    # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
    axis -= len(
        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
    )

    out = softmax(x_sub, axis=axis)
    copy_stack_trace(old_out, out)
    return [out]
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])
......
......@@ -45,6 +45,7 @@ from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_shape_constant,
)
from pytensor.tensor.shape import SpecifyShape, _shape
from pytensor.tensor.special import softmax
from pytensor.tensor.subtensor import Subtensor
......@@ -211,6 +212,44 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):
)
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        # Lift single index that does not overlap with axis of softmax
        (lambda x: softmax(x, axis=1)[0], lambda x: softmax(x[0], axis=0)),
        (lambda x: softmax(x, axis=1)[1:], lambda x: softmax(x[1:], axis=1)),
        (lambda x: softmax(x, axis=0)[:, 0], lambda x: softmax(x[:, 0], axis=0)),
        (lambda x: softmax(x, axis=0)[:, 1:], lambda x: softmax(x[:, 1:], axis=0)),
        # Do nothing to single index over axis of softmax
        (lambda x: softmax(x, axis=0)[0], lambda x: softmax(x, axis=0)[0]),
        (lambda x: softmax(x, axis=1)[:, 1:], lambda x: softmax(x, axis=1)[:, 1:]),
        # Split indexing on axis of softmax
        (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]),
        (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]),
        (
            lambda x: softmax(x, axis=0)[0, :5:2],
            lambda x: softmax(x[:, :5:2], axis=0)[0],
        ),
        (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]),
    ],
)
def test_local_subtensor_of_softmax(original_fn, expected_fn):
    """Check that `local_subtensor_of_softmax` rewrites to the expected graph
    and that the rewritten graph is numerically equivalent to the original."""
    rng = np.random.default_rng(230)
    x = pt.matrix("x", shape=(5, 3))
    x_test = rng.normal(size=x.type.shape).astype(x.dtype)

    out = original_fn(x)
    expected_opt_out = expected_fn(x)
    opt_out = rewrite_graph(out)
    # Structural check: the rewritten graph must match the expected one exactly.
    # On failure, debugprint gives a readable dump of both graphs.
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )
    # Numerical check: evaluate both graphs without further optimization so the
    # rewrite under test is the only transformation being exercised.
    np.testing.assert_allclose(
        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
    )
@pytest.mark.parametrize(
"original_fn, expected_fn",
[
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment