提交 c822a8e6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cleanup Scan symbolic buffer size graph

Graph was being broken by Scalar/Tensor conversions that prevented fusion
上级 03b62a33
......@@ -33,7 +33,9 @@ from pytensor.tensor.basic import (
alloc,
get_scalar_constant_value,
nonzero,
scalar_from_tensor,
)
from pytensor.tensor.basic import (
constant as tensor_constant,
)
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
......@@ -256,20 +258,20 @@ def get_idx_list(inputs, idx_list):
def get_canonical_form_slice(
theslice: slice,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[slice, int | ScalarConstant]: ...
) -> tuple[slice, int | TensorVariable]: ...
@overload
def get_canonical_form_slice(
theslice: int | np.integer | ScalarVariable | TensorVariable,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[ScalarVariable, int]: ...
) -> tuple[TensorVariable, int]: ...
def get_canonical_form_slice(
theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
length: int | np.integer | ScalarVariable | TensorVariable,
) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
) -> tuple[slice | TensorVariable, int | TensorVariable]:
"""Convert indices or slices to canonical form.
Scalar integer indices or python Slices with Scalar/None attributes
......@@ -296,30 +298,56 @@ def get_canonical_form_slice(
"""
from pytensor.tensor import ge, lt, sign, switch
# Other non-slice types are the scalar indexing case
if not isinstance(theslice, slice):
if isinstance(theslice, int | np.integer | ScalarVariable) or (
isinstance(theslice, TensorVariable) and theslice.ndim == 0
):
cano = switch(lt(theslice, 0), (theslice + length), theslice)
return scalar_from_tensor(cano), 1
raise ValueError(f"Slice {theslice} is not a supported slice type.")
def undo_scalarization(x):
"""Undo scalarization of a variable.
# At this point we have a slice object. Possibly with symbolic inputs.
PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
But reasoning symbolically about the result of multiple indexing operations, we usually
want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
"""
if isinstance(x, ScalarVariable):
if isinstance(x, ScalarConstant):
return tensor_constant(x.data, dtype=x.dtype)
elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
return x.owner.inputs[0]
else:
return as_tensor_variable(x)
return x
def analyze(x):
try:
x_constant = as_index_literal(x)
is_constant = True
except NotScalarConstantError:
x_constant = x
x_constant = undo_scalarization(x)
is_constant = False
return x_constant, is_constant
length, is_length_constant = analyze(length)
# Other non-slice types are the scalar indexing case
if not isinstance(theslice, slice):
if not (
isinstance(theslice, int | np.integer | ScalarVariable)
or (isinstance(theslice, TensorVariable) and theslice.ndim == 0)
):
raise ValueError(f"Slice {theslice} is not a supported slice type.")
idx, is_index_constant = analyze(theslice)
if is_index_constant:
if idx >= 0:
return idx, 1
else:
return idx + length, 1
else:
return switch(lt(idx, 0), idx + length, idx), 1
# At this point we have a slice object. Possibly with symbolic inputs.
start, is_start_constant = analyze(theslice.start)
stop, is_stop_constant = analyze(theslice.stop)
step, is_step_constant = analyze(theslice.step)
length, is_length_constant = analyze(length)
if (
is_start_constant
......
......@@ -16,6 +16,7 @@ from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph import Constant
from pytensor.graph.basic import equal_computations
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.printing import pprint
......@@ -23,7 +24,7 @@ from pytensor.scalar.basic import as_scalar, int16
from pytensor.tensor import as_tensor, get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf
from pytensor.tensor.math import exp, isinf, lt, switch
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.subtensor import (
......@@ -136,30 +137,41 @@ class TestGetCanonicalFormSlice:
def test_scalar_constant(self):
a = as_scalar(0)
length = lscalar()
res = get_canonical_form_slice(a, length)
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1
res, direction = get_canonical_form_slice(a, length)
assert res == 0
assert direction == 1
b = as_scalar(-1)
res, direction = get_canonical_form_slice(b, length)
assert equal_computations([res], [as_tensor(-1) + length])
assert direction == 1
def test_tensor_constant(self):
a = as_tensor(0)
length = lscalar()
res = get_canonical_form_slice(a, length)
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1
res, direction = get_canonical_form_slice(a, length)
assert equal_computations([res], [a])
assert direction == 1
b = as_tensor(-1)
res, direction = get_canonical_form_slice(b, length)
assert equal_computations([res], [b + length])
assert direction == 1
def test_symbolic_scalar(self):
a = int16()
length = lscalar()
res = get_canonical_form_slice(a, length)
assert res[0].owner.op, ptb.switch
assert res[1] == 1
res, direction = get_canonical_form_slice(a, length)
a_t = as_tensor(a)
assert equal_computations([res], [switch(lt(a_t, 0), a_t + length, a_t)])
assert direction == 1
def test_symbolic_tensor(self):
a = lscalar()
length = lscalar()
res = get_canonical_form_slice(a, length)
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
assert res[1] == 1
res, direction = get_canonical_form_slice(a, length)
assert equal_computations([res], [switch(lt(a, 0), a + length, a)])
assert direction == 1
@pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar])
def test_all_integer(self, int_fn):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论