Commit c822a8e6 authored by Ricardo Vieira, committed by Ricardo Vieira

Cleanup Scan symbolic buffer size graph

Graph was being broken by Scalar/Tensor conversions that prevented fusion
Parent 03b62a33
@@ -33,7 +33,9 @@ from pytensor.tensor.basic import (
     alloc,
     get_scalar_constant_value,
     nonzero,
-    scalar_from_tensor,
 )
+from pytensor.tensor.basic import (
+    constant as tensor_constant,
+)
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
@@ -256,20 +258,20 @@ def get_idx_list(inputs, idx_list):
 def get_canonical_form_slice(
     theslice: slice,
     length: int | np.integer | ScalarVariable | TensorVariable,
-) -> tuple[slice, int | ScalarConstant]: ...
+) -> tuple[slice, int | TensorVariable]: ...


 @overload
 def get_canonical_form_slice(
     theslice: int | np.integer | ScalarVariable | TensorVariable,
     length: int | np.integer | ScalarVariable | TensorVariable,
-) -> tuple[ScalarVariable, int]: ...
+) -> tuple[TensorVariable, int]: ...


 def get_canonical_form_slice(
     theslice: slice | int | np.integer | ScalarVariable | TensorVariable,
     length: int | np.integer | ScalarVariable | TensorVariable,
-) -> tuple[slice | ScalarVariable, int | ScalarConstant]:
+) -> tuple[slice | TensorVariable, int | TensorVariable]:
     """Convert indices or slices to canonical form.

     Scalar integer indices or python Slices with Scalar/None attributes
@@ -296,30 +298,56 @@ def get_canonical_form_slice(
     """
     from pytensor.tensor import ge, lt, sign, switch

-    # Other non-slice types are the scalar indexing case
-    if not isinstance(theslice, slice):
-        if isinstance(theslice, int | np.integer | ScalarVariable) or (
-            isinstance(theslice, TensorVariable) and theslice.ndim == 0
-        ):
-            cano = switch(lt(theslice, 0), (theslice + length), theslice)
-            return scalar_from_tensor(cano), 1
-        raise ValueError(f"Slice {theslice} is not a supported slice type.")
-
-    # At this point we have a slice object. Possibly with symbolic inputs.
+    def undo_scalarization(x):
+        """Undo scalarization of a variable.
+
+        PyTensor Basic index operations use ScalarVariables for the indices/slice arguments.
+        But reasoning symbolically about the result of multiple indexing operations, we usually
+        want to work on TensorVariables, since rewrites work on those and not ScalarVariables.
+
+        This function undoes ScalarFromTensor operation or converts ScalarConstants to TensorConstants.
+        """
+        if isinstance(x, ScalarVariable):
+            if isinstance(x, ScalarConstant):
+                return tensor_constant(x.data, dtype=x.dtype)
+            elif x.owner is not None and isinstance(x.owner.op, ScalarFromTensor):
+                return x.owner.inputs[0]
+            else:
+                return as_tensor_variable(x)
+        return x

     def analyze(x):
         try:
             x_constant = as_index_literal(x)
             is_constant = True
         except NotScalarConstantError:
-            x_constant = x
+            x_constant = undo_scalarization(x)
             is_constant = False
         return x_constant, is_constant

+    length, is_length_constant = analyze(length)
+
+    # Other non-slice types are the scalar indexing case
+    if not isinstance(theslice, slice):
+        if not (
+            isinstance(theslice, int | np.integer | ScalarVariable)
+            or (isinstance(theslice, TensorVariable) and theslice.ndim == 0)
+        ):
+            raise ValueError(f"Slice {theslice} is not a supported slice type.")
+
+        idx, is_index_constant = analyze(theslice)
+        if is_index_constant:
+            if idx >= 0:
+                return idx, 1
+            else:
+                return idx + length, 1
+        else:
+            return switch(lt(idx, 0), idx + length, idx), 1
+
+    # At this point we have a slice object. Possibly with symbolic inputs.
+
     start, is_start_constant = analyze(theslice.start)
     stop, is_stop_constant = analyze(theslice.stop)
     step, is_step_constant = analyze(theslice.step)
-    length, is_length_constant = analyze(length)

     if (
         is_start_constant
 ......
@@ -16,6 +16,7 @@ from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
 from pytensor.graph import Constant
+from pytensor.graph.basic import equal_computations
 from pytensor.graph.op import get_test_value
 from pytensor.graph.rewriting.utils import is_same_graph
 from pytensor.printing import pprint
@@ -23,7 +24,7 @@ from pytensor.scalar.basic import as_scalar, int16
 from pytensor.tensor import as_tensor, get_vector_length, vectorize
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import exp, isinf
+from pytensor.tensor.math import exp, isinf, lt, switch
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.subtensor import (
@@ -136,30 +137,41 @@ class TestGetCanonicalFormSlice:
     def test_scalar_constant(self):
         a = as_scalar(0)
         length = lscalar()
-        res = get_canonical_form_slice(a, length)
-        assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
-        assert res[1] == 1
+        res, direction = get_canonical_form_slice(a, length)
+        assert res == 0
+        assert direction == 1
+
+        b = as_scalar(-1)
+        res, direction = get_canonical_form_slice(b, length)
+        assert equal_computations([res], [as_tensor(-1) + length])
+        assert direction == 1

     def test_tensor_constant(self):
         a = as_tensor(0)
         length = lscalar()
-        res = get_canonical_form_slice(a, length)
-        assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
-        assert res[1] == 1
+        res, direction = get_canonical_form_slice(a, length)
+        assert equal_computations([res], [a])
+        assert direction == 1
+
+        b = as_tensor(-1)
+        res, direction = get_canonical_form_slice(b, length)
+        assert equal_computations([res], [b + length])
+        assert direction == 1

     def test_symbolic_scalar(self):
         a = int16()
         length = lscalar()
-        res = get_canonical_form_slice(a, length)
-        assert res[0].owner.op, ptb.switch
-        assert res[1] == 1
+        res, direction = get_canonical_form_slice(a, length)
+        a_t = as_tensor(a)
+        assert equal_computations([res], [switch(lt(a_t, 0), a_t + length, a_t)])
+        assert direction == 1

     def test_symbolic_tensor(self):
         a = lscalar()
         length = lscalar()
-        res = get_canonical_form_slice(a, length)
-        assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
-        assert res[1] == 1
+        res, direction = get_canonical_form_slice(a, length)
+        assert equal_computations([res], [switch(lt(a, 0), a + length, a)])
+        assert direction == 1

     @pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar])
     def test_all_integer(self, int_fn):
 ......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment