提交 43cad30f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Lift Subtensor over CAReduce

上级 d5a054d1
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from pytensor import Variable from pytensor import Variable
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.graph.rewriting.basic import copy_stack_trace from pytensor.graph.rewriting.basic import copy_stack_trace
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.scalar import basic as ps from pytensor.scalar import basic as ps
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
...@@ -15,7 +16,7 @@ from pytensor.tensor.basic import ( ...@@ -15,7 +16,7 @@ from pytensor.tensor.basic import (
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import squeeze from pytensor.tensor.extra_ops import squeeze
from pytensor.tensor.math import Dot, ceil_intdiv, dot from pytensor.tensor.math import Dot, ceil_intdiv, dot
...@@ -183,6 +184,63 @@ def local_subtensor_of_elemwise(fgraph, node): ...@@ -183,6 +184,63 @@ def local_subtensor_of_elemwise(fgraph, node):
return [new_out] return [new_out]
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_reduce(fgraph, node):
    """Lift a Subtensor through a CAReduce Op.

    For now rewrite is restricted to single axis of reduction, for simplicity.

    sum(x, axis=1)[0] -> sum(x[0], axis=0)
    sum(x, axis=1)[1:] -> sum(x[1:], axis=1)
    sum(x, axis=0)[0] -> sum(x[:, 0], axis=0)
    sum(x, axis=0)[1:] -> sum(x[:, 1:], axis=0)
    """
    reduced, *idx_inputs = node.inputs
    if reduced.owner is None or not isinstance(reduced.owner.op, CAReduce):
        return None
    # Don't apply rewrite if another node requires the full reduction
    if len(fgraph.clients[reduced]) > 1:
        return None

    [x] = reduced.owner.inputs
    red_axis = reduced.owner.op.axis
    if red_axis is None:
        # axis=None means reduce over every dimension of the input
        red_axis = tuple(range(x.type.ndim))
    # TODO: Allow reduction across multiple axis
    if len(red_axis) != 1:
        return None
    [red_axis] = normalize_axis_tuple(red_axis, x.ndim)

    idx_tuple = indices_from_subtensor(idx_inputs, node.op.idx_list)

    # Index the input of the reduction instead. Indices at or beyond the
    # reduced axis must be shifted right with a full slice, since that axis
    # still exists in the reduction's input.
    shifted_idxs = list(idx_tuple)
    if red_axis < len(idx_tuple):
        shifted_idxs.insert(red_axis, slice(None))
    x_sub = x[tuple(shifted_idxs)]

    [old_out] = node.outputs
    copy_stack_trace(old_out, x_sub)

    # Integer (dimension-dropping, as opposed to slice) indices located before
    # the axis of reduction shift that axis to the left in the indexed input.
    dropped_dims = sum(
        not isinstance(idx_entry, slice) for idx_entry in idx_tuple[:red_axis]
    )
    # NOTE(review): rebuilding the op from its type with only `axis` assumes the
    # CAReduce constructor needs no other state (e.g. dtype) — confirm upstream.
    out = type(reduced.owner.op)(axis=red_axis - dropped_dims)(x_sub)
    copy_stack_trace(old_out, out)
    return [out]
@register_canonicalize("shape_unsafe") @register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe") @register_specialize("shape_unsafe")
@node_rewriter([Subtensor]) @node_rewriter([Subtensor])
......
...@@ -38,6 +38,7 @@ from pytensor.tensor import ( ...@@ -38,6 +38,7 @@ from pytensor.tensor import (
) )
from pytensor.tensor.basic import MakeVector, expand_dims, make_vector from pytensor.tensor.basic import MakeVector, expand_dims, make_vector
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.rewriting.subtensor_lift import ( from pytensor.tensor.rewriting.subtensor_lift import (
local_subtensor_make_vector, local_subtensor_make_vector,
local_subtensor_of_elemwise, local_subtensor_of_elemwise,
...@@ -176,6 +177,40 @@ class TestLocalSubtensorOfElemwise: ...@@ -176,6 +177,40 @@ class TestLocalSubtensorOfElemwise:
assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        # Indexing before axis of reduction
        (lambda v: pt_sum(v, axis=2)[0], lambda v: pt_sum(v[0], axis=1)),
        (lambda v: pt_sum(v, axis=2)[0, 1], lambda v: pt_sum(v[0, 1], axis=None)),
        (lambda v: pt_sum(v, axis=2)[1:], lambda v: pt_sum(v[1:], axis=2)),
        # Indexing "at" axis of reduction
        (lambda v: pt_sum(v, axis=0)[2], lambda v: pt_sum(v[:, 2], axis=0)),
        (lambda v: pt_sum(v, axis=0)[:-2], lambda v: pt_sum(v[:, :-2], axis=0)),
        # Index after axis of reduction
        (lambda v: pt_sum(v, axis=0)[:, 1:], lambda v: pt_sum(v[:, :, 1:], axis=0)),
        # Index before and after axis reduction
        (lambda v: pt_sum(v, axis=1)[-2, 1:], lambda v: pt_sum(v[-2, :, 1:], axis=0)),
        (lambda v: pt_sum(v, axis=1)[1:, -2], lambda v: pt_sum(v[1:, :, -2], axis=1)),
    ],
)
def test_local_subtensor_of_reduce(original_fn, expected_fn):
    """Subtensor of a single-axis reduction is lifted onto the reduction input."""
    x = pt.tensor("x", shape=(5, 3, 2))
    rng = np.random.default_rng(245)
    x_test = rng.normal(size=x.type.shape).astype(x.dtype)

    out = original_fn(x)
    expected_opt_out = expected_fn(x)
    opt_out = rewrite_graph(out)

    # The rewritten graph must match the hand-lifted reference symbolically...
    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
        [expected_opt_out, opt_out], print_type=True
    )
    # ...and agree numerically with the unrewritten graph.
    opt_res = opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE)
    unopt_res = out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE)
    np.testing.assert_allclose(opt_res, unopt_res)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"original_fn, expected_fn", "original_fn, expected_fn",
[ [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论