Commit db7fa079, authored by Jaan Erik Pihel, committed by Ricardo Vieira

Refactor AdvancedSubtensor

- newaxis is handled as an explicit DimShuffle on the inputs
- slices are encoded internally, so the Ops only take numerical inputs

Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
Parent 87470065
......@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
}
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", []
app.op, "destroyhandler_tolerate_aliased", ()
)
assert isinstance(tolerate_aliased, list)
assert isinstance(tolerate_aliased, tuple | list)
ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
}
......
......@@ -8,7 +8,6 @@ from pytensor.tensor.subtensor import (
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice
BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
......@@ -35,10 +34,8 @@ slice length.
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -48,10 +45,9 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y):
......@@ -62,7 +58,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -73,29 +69,3 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return jax_fn(x, indices, y)
return incsubtensor
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
return advancedincsubtensor
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
......@@ -10,15 +10,14 @@ from pytensor.tensor.subtensor import (
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice
@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists):
indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
indices = indices_from_subtensor(
[int(element) for element in ilists], op.idx_list
)
if len(indices) == 1:
indices = indices[0]
......@@ -30,10 +29,8 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -45,8 +42,6 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False):
def mlx_fn(x, indices, y):
......@@ -63,7 +58,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x[indices] += y
return x
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -95,11 +90,3 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return mlx_fn(x, ilist, y)
return advancedincsubtensor
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
......@@ -10,18 +10,17 @@ from numba import types
from numba.core.pythonapi import box
import pytensor.link.numba.dispatch.basic as numba_basic
from pytensor.graph import Type
from pytensor.graph import Variable
from pytensor.link.numba.cache import (
compile_numba_function_src,
)
from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.tensor import TensorType
from pytensor.tensor import TensorType, TensorVariable
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -29,8 +28,8 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, NoneTypeT
def slice_new(self, start, stop, step):
......@@ -118,15 +117,6 @@ def numba_deepcopy_slice(x):
return deepcopy_slice
@register_funcify_default_op_cache_key(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_basic.numba_njit
def makeslice(*x):
return slice(*x)
return makeslice
def subtensor_op_cache_key(op, **extra_fields):
key_parts = [type(op), tuple(extra_fields.items())]
if hasattr(op, "idx_list"):
......@@ -156,35 +146,36 @@ def subtensor_op_cache_key(op, **extra_fields):
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array."""
def convert_indices(indice_names, entry):
if indice_names and isinstance(entry, Type):
return next(indice_names)
def convert_indices(indices_iterator, entry):
if isinstance(entry, int):
name, var = next(indices_iterator)
if var.ndim == 0 and isinstance(var.type, TensorType):
return f"{name}.item()"
return name
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indice_names, entry.start)}, "
f"{convert_indices(indice_names, entry.stop)}, "
f"{convert_indices(indice_names, entry.step)})"
f"slice({convert_indices(indices_iterator, entry.start)}, "
f"{convert_indices(indices_iterator, entry.stop)}, "
f"{convert_indices(indices_iterator, entry.step)})"
)
elif isinstance(entry, type(None)):
return "None"
else:
raise ValueError()
raise ValueError(f"Unknown index type: {entry}")
set_or_inc = isinstance(
op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
)
index_start_idx = 1 + int(set_or_inc)
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(op, "idx_list", None)
idx_list = op.idx_list
idx_names = [f"idx_{i}" for i in range(len(op_indices))]
input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names]
idx_names_iterator = iter(idx_names)
indices_creation_src = (
tuple(convert_indices(idx_names_iterator, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
indices_iterator = iter(zip(idx_names, op_indices))
indices_creation_src = tuple(
convert_indices(indices_iterator, idx) for idx in idx_list
)
if len(indices_creation_src) == 1:
......@@ -240,20 +231,24 @@ def {function_name}({", ".join(input_names)}):
@register_funcify_and_cache_key(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(op, AdvancedSubtensor):
_x, _y, idxs = node.inputs[0], None, node.inputs[1:]
_x, *index_variables = node.inputs
else:
_x, _y, *idxs = node.inputs
adv_idxs = [
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
for i, idx in enumerate(idxs)
if isinstance(idx.type, TensorType)
]
_x, _y, *index_variables = node.inputs
reconstructed_indices = indices_from_subtensor(index_variables, op.idx_list)
adv_idxs = []
for i, idx in enumerate(reconstructed_indices):
if isinstance(idx, TensorVariable):
# This is an advanced tensor index
adv_idxs.append(
{
"axis": i,
"dtype": idx.type.dtype,
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
)
must_ignore_duplicates = (
isinstance(op, AdvancedIncSubtensor)
......@@ -265,13 +260,10 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
)
)
# Special implementation for integer indices that respects duplicates
if (
not must_ignore_duplicates
and len(adv_idxs) >= 1
and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs)
# Implementation does not support newaxis
and not any(isinstance(idx.type, NoneTypeT) for idx in idxs)
):
return vector_integer_advanced_indexing(op, node, **kwargs)
......@@ -399,7 +391,6 @@ def vector_integer_advanced_indexing(
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
# Index over tuples of raveled advanced indices and update buffer
......@@ -460,45 +451,90 @@ def vector_integer_advanced_indexing(
return x
"""
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
x, *idxs = node.inputs
x, *index_variables = node.inputs
else:
x, y, *idxs = node.inputs
x, y, *index_variables = node.inputs
[out] = node.outputs
reconstructed_indices = indices_from_subtensor(index_variables, op.idx_list)
idx_args = [f"idx{i}" for i in range(len(index_variables))]
var_to_arg = dict(zip(index_variables, idx_args))
idxs = []
def get_idx_str(val, is_slice_component=False):
if val is None:
return "None"
if isinstance(val, Variable) and val in var_to_arg:
arg = var_to_arg[val]
if val.ndim == 0 and is_slice_component:
return f"{arg}.item()"
return arg
raise ValueError(f"Unexpected index value: {val}")
for idx in reconstructed_indices:
if isinstance(idx, slice):
start = get_idx_str(idx.start, is_slice_component=True)
stop = get_idx_str(idx.stop, is_slice_component=True)
step = get_idx_str(idx.step, is_slice_component=True)
idxs.append(f"slice({start}, {stop}, {step})")
else:
# It's a direct index variable
idxs.append(get_idx_str(idx, is_slice_component=False))
adv_indices_pos = tuple(
i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
i for i, idx in enumerate(reconstructed_indices) if not isinstance(idx, slice)
)
assert adv_indices_pos # Otherwise it's just basic indexing
basic_indices_pos = tuple(
i for i, idx in enumerate(idxs) if not isinstance(idx.type, TensorType)
i for i, idx in enumerate(reconstructed_indices) if isinstance(idx, slice)
)
explicit_basic_indices_pos = (*basic_indices_pos, *range(len(idxs), x.type.ndim))
# Create index signature and split them among basic and advanced
idx_signature = ", ".join(f"idx{i}" for i in range(len(idxs)))
adv_indices = [f"idx{i}" for i in adv_indices_pos]
basic_indices = [f"idx{i}" for i in basic_indices_pos]
# Create index signature for generated function: "idx0, idx1, idx2, ..."
idx_signature = ", ".join(idx_args)
# Define transpose axis so that advanced indexing dims are on the front
adv_axis_front_order = (*adv_indices_pos, *explicit_basic_indices_pos)
adv_axis_front_transpose_needed = adv_axis_front_order != tuple(range(x.ndim))
adv_idx_ndim = max(idxs[i].ndim for i in adv_indices_pos)
# String representations of advanced and basic indices for codegen
adv_indices = [idxs[i] for i in adv_indices_pos]
basic_indices = [idxs[i] for i in basic_indices_pos]
# Helper needed for basic indexing after moving advanced indices to the front
basic_indices_with_none_slices = ", ".join(
(*((":",) * len(adv_indices)), *basic_indices)
)
to_tuple = create_tuple_string # alias to make code more readable below
# Position of the first advanced index dimension after indexing the array
if (np.diff(adv_indices_pos) > 1).any():
# If not consecutive, it's always at the front
out_adv_axis_pos = 0
# Compute number of dimensions in advanced indices (after broadcasting)
if len(adv_indices_pos) == 1:
adv_idx = reconstructed_indices[adv_indices_pos[0]]
adv_idx_ndim = adv_idx.ndim # type: ignore[union-attr]
else:
# Otherwise wherever the first advanced index is located
# Multiple advanced indices - use max ndim (broadcast result ndim)
adv_idx_ndim = max(reconstructed_indices[i].ndim for i in adv_indices_pos) # type: ignore[union-attr]
# Determine output position of advanced indexed dimensions
# If advanced indices are consecutive, they go in the first advanced index position
# Otherwise they go at the beginning
if adv_indices_pos == tuple(range(adv_indices_pos[0], adv_indices_pos[-1] + 1)):
# Consecutive - advanced dims will be at position of first advanced index
out_adv_axis_pos = adv_indices_pos[0]
else:
# Non-consecutive - advanced dims go at the front
out_adv_axis_pos = 0
to_tuple = create_tuple_string # alias to make code more readable below
# Include trailing dimensions not covered by explicit indices
explicit_basic_indices_pos = (
*basic_indices_pos,
*range(len(reconstructed_indices), x.type.ndim),
)
# Compute transpose to move advanced indexed dims to the front
adv_axis_front_order = (*adv_indices_pos, *explicit_basic_indices_pos)
adv_axis_front_transpose_needed = adv_axis_front_order != tuple(range(x.type.ndim))
# Compute basic indices with "None" slices for dimensions that will be indexed by advanced indices
basic_indices_with_none_slices = ", ".join(
":" for _ in range(len(adv_indices_pos))
) + (", " + ", ".join(basic_indices) if basic_indices else "")
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
# Define transpose axis on the output to restore original meaning
......@@ -557,7 +593,8 @@ def vector_integer_advanced_indexing(
else:
# Make implicit dims of y explicit to simplify code
# Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis
indexed_ndim = x[tuple(idxs)].type.ndim
indexed_ndim = x[tuple(reconstructed_indices)].type.ndim
y_expand_dims = [":"] * y.type.ndim
y_implicit_dims = range(indexed_ndim - y.type.ndim)
for axis in y_implicit_dims:
......
......@@ -9,7 +9,6 @@ from pytensor.tensor.subtensor import (
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
def check_negative_steps(indices):
......@@ -47,23 +46,11 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return subtensor
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(start, stop, step):
# Torch does not like numpy integers in indexing slices
return slice(
None if start is None else int(start),
None if stop is None else int(stop),
None if step is None else int(step),
)
return makeslice
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
indices = indices_from_subtensor(indices, op.idx_list)
check_negative_steps(indices)
return x[indices]
......@@ -102,12 +89,14 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)
if op.set_instead_of_inc:
def adv_set_subtensor(x, y, *indices):
def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
......@@ -120,7 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif ignore_duplicates:
def adv_inc_subtensor_no_duplicates(x, y, *indices):
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
......@@ -132,13 +122,14 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return adv_inc_subtensor_no_duplicates
else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
if any(isinstance(entry, slice) for entry in idx_list):
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)
def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
def adv_inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
if not inplace:
x = x.clone()
......
......@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import (
IncSubtensor,
Subtensor,
basic_subtensor,
get_canonical_form_slice,
get_idx_list,
get_slice_elements,
set_subtensor,
)
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if not (
isinstance(op, IncSubtensor)
and op.set_instead_of_inc
and op.idx_list == [slice(None, ps.int64)]
and op.idx_list == (slice(None, 0),)
):
return False
......@@ -1389,12 +1389,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else:
# 2.3.1 extract idx list of subtensor
this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps = None
slices[i] = None
break
# 2.3.2 extract the begin/end of the first dimension
if i >= op_info.n_mit_mot:
......@@ -1487,9 +1481,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break
else:
this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
store_steps[i] = 0
break
if isinstance(this_slice[0], slice):
start = this_slice[0].start
......@@ -1711,16 +1702,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
else:
fslice = sanitize(cnf_slice[0])
nw_slice = (fslice, *old_slices[1:])
nw_pos = inv_compress_map[idx]
subtens = Subtensor(nw_slice)
# slice inputs
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
new_o = basic_subtensor(new_outs[nw_pos], fslice, *old_slices[1:])
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx)
......@@ -1771,11 +1755,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
nw_slice = (sanitize(position), *old_slices[1:])
subtens = Subtensor(nw_slice)
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
new_o = basic_subtensor(new_outs[nw_pos], *nw_slice)
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)]
......
......@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type
from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
......@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
......@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and len(v.owner.op.idx_list) == 1
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
......@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op = owner.op
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, Type):
if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
......
......@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor,
)
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.type_other import NoneTypeT
def is_rv_used_in_graph(base_rv, node, fgraph):
......@@ -237,20 +237,15 @@ def local_subtensor_rv_lift(fgraph, node):
return False
# Parse indices
if isinstance(subtensor_op, Subtensor):
if isinstance(subtensor_op, Subtensor | AdvancedSubtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else:
indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any(
is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices
):
return False
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
# (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates)
if any(is_nd_advanced_idx(idx, integer_dtypes) for idx in indices):
return False
# Check that indexing does not act on support dims
batch_ndims = rv_op.batch_ndim(rv_node)
......@@ -268,10 +263,7 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices[batch_ndims:],
)
for idx in supp_indices:
if not (
isinstance(idx.type, SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
):
if idx != slice(None):
return False
n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs]
......@@ -331,7 +323,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim
if curr_dim in bcast_param_dims:
# Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx, slice) or isinstance(idx.type, SliceType):
if isinstance(idx, slice):
batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing
else:
......
......@@ -17,7 +17,6 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
from pytensor.scalar import ScalarType
from pytensor.tensor.basic import (
MakeVector,
as_tensor_variable,
......@@ -842,13 +841,16 @@ def _is_shape_i_of_x(
if isinstance(var.owner.op, Shape_i):
return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore
# Match Subtensor((ScalarType,))(Shape(input), i)
# Match Subtensor((int,))(Shape(input), i) - single integer index into shape
if isinstance(var.owner.op, Subtensor):
idx_entry = (
var.owner.op.idx_list[0] if len(var.owner.op.idx_list) == 1 else None
)
return (
# Check we have integer indexing operation
# (and not slice or multiple indexing)
len(var.owner.op.idx_list) == 1
and isinstance(var.owner.op.idx_list[0], ScalarType)
and isinstance(idx_entry, int)
# Check we are indexing on the shape of x
and var.owner.inputs[0].owner is not None
and isinstance(var.owner.inputs[0].owner.op, Shape)
......
import itertools
import sys
import warnings
import numpy as np
......@@ -15,7 +16,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter,
)
from pytensor.raise_op import Assert
from pytensor.scalar import Add, ScalarConstant, ScalarType
from pytensor.scalar import Add, ScalarConstant
from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import (
Alloc,
......@@ -31,6 +32,7 @@ from pytensor.tensor.basic import (
full,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
moveaxis,
register_infer_shape,
switch,
)
......@@ -72,10 +74,11 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
_non_consecutive_adv_indexing,
advanced_inc_subtensor1,
advanced_subtensor,
advanced_subtensor1,
as_index_constant,
basic_subtensor,
get_canonical_form_slice,
get_constant_idx,
get_idx_list,
......@@ -84,7 +87,6 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -154,8 +156,10 @@ def transform_take(a, indices, axis):
if len(shape_parts) > 1:
shape = pytensor.tensor.concatenate(shape_parts)
else:
elif len(shape_parts) == 1:
shape = shape_parts[0]
else:
shape = ()
ndim = a.ndim + indices.ndim - 1
......@@ -163,23 +167,11 @@ def transform_take(a, indices, axis):
def is_full_slice(x):
"""Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
if isinstance(x, slice):
return x == slice(None)
if isinstance(x, Variable) and isinstance(x.type, SliceType):
if x.owner is None:
if isinstance(x, Constant):
return x.data == slice(None)
else:
# Root slice variable
return False
# Symbolic MakeSlice
# Ignores start = 0, step = 1 cases
return all(isinstance(i.type, NoneTypeT) for i in x.owner.inputs)
return False
warnings.warn(
"The function is deprecated, use x==slice(None) instead.",
DeprecationWarning,
)
return x == slice(None)
def get_advsubtensor_axis(indices):
......@@ -194,13 +186,13 @@ def get_advsubtensor_axis(indices):
found_idx = False
axis = 0
for idx in indices:
if not found_idx and is_full_slice(idx):
if not found_idx and idx == slice(None):
# Preceding full slices
axis += 1
elif found_idx and not is_full_slice(idx):
elif found_idx and not idx == slice(None):
# We don't handle multiple indices
return
elif found_idx and is_full_slice(idx):
elif found_idx and idx == slice(None):
# Trailing full slices
continue
else:
......@@ -227,9 +219,8 @@ def local_replace_AdvancedSubtensor(fgraph, node):
if not isinstance(node.op, AdvancedSubtensor):
return
indexed_var = node.inputs[0]
indices = node.inputs[1:]
indexed_var, *index_variables = node.inputs
indices = indices_from_subtensor(index_variables, node.op.idx_list)
axis = get_advsubtensor_axis(indices)
if axis is None or indices[axis].dtype == "bool":
......@@ -253,9 +244,8 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
# `AdvancedIncSubtensor1` does not ignore duplicate index values
return
res = node.inputs[0]
val = node.inputs[1]
indices = node.inputs[2:]
res, val, *index_variables = node.inputs
indices = indices_from_subtensor(index_variables, node.op.idx_list)
axis = get_advsubtensor_axis(indices)
......@@ -428,11 +418,7 @@ def local_subtensor_merge(fgraph, node):
merged_slices += slices1[pos_1:]
merged_slices = tuple(as_index_constant(s) for s in merged_slices)
subtens = Subtensor(merged_slices)
sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable))
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
out = basic_subtensor(x, *merged_slices)
# Copy over previous output stacktrace
# and stacktrace from previous slicing operation.
......@@ -463,9 +449,8 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
remove_dim = []
node_inputs_idx = 1
for dim, elem in enumerate(idx):
if isinstance(elem, ScalarType):
# The idx is a ScalarType, ie a Type. This means the actual index
# is contained in node.inputs[1]
if isinstance(elem, int):
# The idx is a integer position.
dim_index = node.inputs[node_inputs_idx]
if isinstance(dim_index, ScalarConstant):
dim_index = dim_index.value
......@@ -477,9 +462,6 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
elif isinstance(elem, slice):
if elem != slice(None):
return
elif isinstance(elem, int | np.integer):
if elem in (0, -1) and node.inputs[0].broadcastable[dim]:
remove_dim.append(dim)
else:
raise TypeError("case not expected")
......@@ -506,26 +488,29 @@ def local_subtensor_inc_subtensor(fgraph, node):
if not x.owner.op.set_instead_of_inc:
return
if x.owner.inputs[2:] == node.inputs[1:] and tuple(
x.owner.op.idx_list
) == tuple(node.op.idx_list):
x_inc, y_inc, *inc_index_variables = x.owner.inputs
_sub_x, *sub_index_variables = node.inputs
if (
inc_index_variables == sub_index_variables
and x.owner.op.idx_list == node.op.idx_list
):
out = node.outputs[0]
y = x.owner.inputs[1]
# If the dtypes differ, cast y into x.dtype
if x.dtype != y.dtype:
y = y.astype(x.dtype)
if x.dtype != y_inc.dtype:
y_inc = y_inc.astype(x.dtype)
if (
out.type.dtype == y.type.dtype
and out.type.broadcastable == y.type.broadcastable
out.type.dtype == y_inc.type.dtype
and out.type.broadcastable == y_inc.type.broadcastable
):
# if x[idx] and y have the same type, directly return y
return [y]
return [y_inc]
else:
# The difference is related to broadcasting pattern
assert out.broadcastable != y.broadcastable
assert out.broadcastable != y_inc.broadcastable
# We have to alloc y to the shape of x[idx]
x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
return [alloc(y, *x_subtensor.shape)]
x_subtensor = node.op(x_inc, *inc_index_variables)
return [alloc(y_inc, *x_subtensor.shape)]
else:
return
......@@ -829,9 +814,9 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
raise ValueError("slice1 should be of type `slice`")
# Simple case where one of the slices is useless
if is_full_slice(slice1):
if slice1 == slice(None):
return slice2
elif is_full_slice(slice2):
elif slice2 == slice(None):
return slice1
sl1, reverse1 = get_canonical_form_slice(slice1, len1)
......@@ -1090,6 +1075,7 @@ compile.optdb.register(
def local_inplace_AdvancedIncSubtensor(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
new_op = type(node.op)(
node.op.idx_list,
inplace=True,
set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=node.op.ignore_duplicates,
......@@ -1276,9 +1262,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
"""
if isinstance(node.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
x = node.inputs[0]
y = node.inputs[1]
i = node.inputs[2:]
x, y, *index_variables = node.inputs
if y.owner is not None and isinstance(y.owner.op, Alloc):
# `z` is the input of the Alloc op, i.e. at.alloc(z, <shape>)
......@@ -1297,11 +1281,11 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
# Get the subtensor of `x` indexed by `i` in order to compare
# shapes later.
if isinstance(node.op, IncSubtensor):
xi = Subtensor(node.op.idx_list)(x, *i)
xi = Subtensor(node.op.idx_list)(x, *index_variables)
elif isinstance(node.op, AdvancedIncSubtensor):
xi = advanced_subtensor(x, *i)
xi = AdvancedSubtensor(node.op.idx_list)(x, *index_variables)
elif isinstance(node.op, AdvancedIncSubtensor1):
xi = advanced_subtensor1(x, *i)
xi = advanced_subtensor1(x, *index_variables)
else:
raise Exception("Should never happen!")
......@@ -1361,7 +1345,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
msg = "`x[i]` and `y` do not have the same shape."
z = Assert(msg)(z, *cond)
r = node.op(x, z, *i)
r = node.op(x, z, *index_variables)
# Copy over stacktrace from previous output, since
# we don't expect problems when removing the intermediate
# alloc operation and so we still want to point at the line
......@@ -1493,8 +1477,7 @@ def local_uint_constant_indices(fgraph, node):
x, *indices = node.inputs
y = None
idx_list = getattr(node.op, "idx_list", None)
new_indices = list(indices_from_subtensor(indices, idx_list))
new_indices = list(indices_from_subtensor(indices, node.op.idx_list))
has_new_index = False
for i, index in enumerate(new_indices):
......@@ -1544,14 +1527,7 @@ def local_uint_constant_indices(fgraph, node):
if not has_new_index:
return False
if isinstance(op, Subtensor | IncSubtensor):
# Basic index Ops contain information about the dtype of the indices, so wee have to recreate them
props = op._props_dict()
props["idx_list"] = new_indices
op = type(op)(**props)
# Basic index Ops don't expect slices, but the respective start/step/stop
new_indices = get_slice_elements(new_indices)
new_indices = get_slice_elements(new_indices)
new_args = (x, *new_indices) if y is None else (x, y, *new_indices)
new_out = op(*new_args)
copy_stack_trace(node.outputs[0], new_out)
......@@ -1611,27 +1587,18 @@ def local_blockwise_inc_subtensor(fgraph, node):
core_op = node.op.core_op
x, y, *idxs = node.inputs
[out] = node.outputs
if isinstance(core_op, AdvancedIncSubtensor):
if any(
(
# Blockwise requires all inputs to be tensors so it is not possible
# to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
# If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
# are separated by basic indices
isinstance(idx, SliceType | NoneTypeT)
# Also get out if we have boolean indices as they cross dimension boundaries
# / can't be safely broadcasted depending on their runtime content
or (idx.type.dtype == "bool")
)
for idx in idxs
):
return None
advanced = isinstance(core_op, AdvancedIncSubtensor)
if advanced and any(idx.type.dtype == "bool" for idx in idxs):
# Get out if we have boolean indices as they cross dimension boundaries
# / can't be safely broadcasted depending on their runtime content
return None
batch_ndim = node.op.batch_ndim(node)
idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
max_idx_core_ndim = max(idxs_core_ndim, default=0)
# Step 1. Broadcast buffer to batch_shape
# Broadcast buffer to batch_shape
if x.type.broadcastable != out.type.broadcastable:
batch_shape = [1] * batch_ndim
for inp in node.inputs:
......@@ -1648,58 +1615,61 @@ def local_blockwise_inc_subtensor(fgraph, node):
x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
assert x.type.broadcastable == out.type.broadcastable
# Step 2. Massage indices so they respect blockwise semantics
if isinstance(core_op, IncSubtensor):
# For basic IncSubtensor there are two cases:
# 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
# 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
# in case we can end up with a basic IncSubtensor again
core_idxs = []
counter = 0
for idx in core_op.idx_list:
if isinstance(idx, slice):
# Squeeze away dummy dimensions so we can convert to slice
new_entries = [None, None, None]
for i, entry in enumerate((idx.start, idx.stop, idx.step)):
if entry is None:
continue
else:
new_entries[i] = new_entry = idxs[counter].squeeze()
counter += 1
if new_entry.ndim > 0:
# If the slice entry has dimensions after the squeeze we can't convert it to a slice
# We could try to convert to equivalent integer indices, but nothing guarantees
# that the slice is "square".
return None
core_idxs.append(slice(*new_entries))
# Massage indices so they respect blockwise semantics while using regular indexing
core_idxs = []
for idx_entry in core_op.idx_list:
if isinstance(idx_entry, slice):
# Squeeze away dummy dimensions so we can convert to slice
new_entries = [None, None, None]
for i, slice_idx_entry in enumerate(
(idx_entry.start, idx_entry.stop, idx_entry.step)
):
if slice_idx_entry is None:
continue
else:
new_entries[i] = new_entry = idxs[slice_idx_entry].squeeze()
if new_entry.ndim > 0:
# If the slice entry has dimensions after the squeeze we can't convert it to a slice
# We could try to convert to equivalent integer indices, but nothing guarantees
# that the slice is "square".
return None
squeezed_index = slice(*new_entries)
else:
if advanced:
# For AdvancedIncSubtensor we have tensor integer indices,
# We need to expand batch indexes on the right, so they don't interact with core index dimensions
# We still squeeze on the left in case that allows us to use simpler indices
squeezed_index = _squeeze_left(
shape_padright(
idxs[idx_entry], max_idx_core_ndim - idxs_core_ndim[idx_entry]
),
stop_at_dim=batch_ndim,
)
else:
core_idxs.append(_squeeze_left(idxs[counter]))
counter += 1
else:
# For AdvancedIncSubtensor we have tensor integer indices,
# We need to expand batch indexes on the right, so they don't interact with core index dimensions
# We still squeeze on the left in case that allows us to use simpler indices
core_idxs = [
_squeeze_left(
shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
stop_at_dim=batch_ndim,
)
for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
]
# For basic IncSubtensor integers indices can be used as is, but we try to squeeze away dummy
# batch dimensions in case we can end up with a basic IncSubtensor again
squeezed_index = _squeeze_left(idxs[idx_entry])
core_idxs.append(squeezed_index)
# Step 3. Create new indices for the new batch dimension of x
if not all(
# Create new indices for the batch dimensions
has_batched_indices = not all(
all(idx.type.broadcastable[:batch_ndim])
for idx in idxs
if not isinstance(idx, slice)
):
# If indices have batch dimensions in the indices, they will interact with the new dimensions of x
# We build vectorized indexing with new arange indices that do not interact with core indices or each other
# (i.e., they broadcast)
# Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
# we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
# even if not all batch dimensions have corresponding batch indices.
)
if has_batched_indices:
# If indices have batch dimensions, we need to align them element-wise with the respective batch dimensions of x
# We achieve this by creating `arange` indices and adding expand_dims for correct broadcasting.
# Example:
# x = pt.zeros(5); idx = [0, 1, 0]; out = x[idx].set(y)
# batch_x = pt.zeros((2, 5)); batch_idx = [[0, 1, 0], [1, 1, 2]]
# batch_out = batch_x[[0, 1][:, None], batch_idx].set(y)
# If instead batch_x = pt.zeros((2, 2, 5))
# batch_out = batch_x[[0, 1][:, None, None], [0, 1][None, 1, None], batch_idx]
# Note: For simplicity we use arange for all batch dimensions of x,
# even if not all may have corresponding batch index dimensions
batch_slices = [
shape_padright(arange(x_batch_shape, dtype="int64"), n)
for (x_batch_shape, n) in zip(
......@@ -1715,29 +1685,49 @@ def local_blockwise_inc_subtensor(fgraph, node):
new_idxs = (*batch_slices, *core_idxs)
x_view = x[new_idxs]
# Step 4. Introduce any implicit expand_dims on core dimension of y
# Introduce any implicit expand_dims on core dimension of y
missing_y_core_ndim = x_view.type.ndim - y.type.ndim
implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
if isinstance(core_op, IncSubtensor):
# Check if we can still use a basic IncSubtensor
if isinstance(x_view.owner.op, Subtensor):
new_props = core_op._props_dict()
new_props["idx_list"] = x_view.owner.op.idx_list
new_core_op = type(core_op)(**new_props)
symbolic_idxs = x_view.owner.inputs[1:]
new_out = new_core_op(x, y, *symbolic_idxs)
else:
# We need to use AdvancedSet/IncSubtensor
if core_op.set_instead_of_inc:
new_out = x[new_idxs].set(y)
else:
new_out = x[new_idxs].inc(y)
y = expand_dims(y, implicit_axes)
# Transpose y if needed
if has_batched_indices:
# By introducing arange slices we may caused a transposition of the advanced group to the front
# If this was not already happening in the core graph, we'll need to transpose y to align it correctly
if max_idx_core_ndim and not (
advanced and _non_consecutive_adv_indexing(core_idxs)
):
integer_pos = [
i for i, entry in enumerate(core_op.idx_list) if isinstance(entry, int)
]
slice_pos = [
i
for i, entry in enumerate(core_op.idx_list)
if isinstance(entry, slice)
]
if slice_pos and integer_pos and (slice_pos[0] < integer_pos[-1]):
y = moveaxis(
y,
[batch_ndim + integer_pos[0] + i for i in range(max_idx_core_ndim)],
[batch_ndim + i for i in range(max_idx_core_ndim)],
)
else:
# AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
symbolic_idxs = x_view.owner.inputs[1:]
new_out = core_op(x, y, *symbolic_idxs)
# Conversely if we tried to use `slice(None)` for the batch dimensions but there was already transposition
# in the core case, we'll need to move the batch slices of y to after the advanced indexing group
if advanced and _non_consecutive_adv_indexing(core_idxs):
y = moveaxis(
y,
[i for i in range(batch_ndim)], # noqa: C416
[max_idx_core_ndim + i for i in range(batch_ndim)],
)
# Remove useless left-batch dimensions of y (if any)
y = _squeeze_left(y, stop_at_dim=batch_ndim)
if core_op.set_instead_of_inc:
new_out = x[new_idxs].set(y)
else:
new_out = x[new_idxs].inc(y)
copy_stack_trace(out, new_out)
return [new_out]
......@@ -1754,10 +1744,12 @@ def bool_idx_to_nonzero(fgraph, node):
else:
x, y, *idxs = node.inputs
idxs = indices_from_subtensor(idxs, node.op.idx_list)
bool_pos = {
i
for i, idx in enumerate(idxs)
if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
if isinstance(idx, TensorVariable) and idx.dtype == "bool"
}
if not bool_pos:
......@@ -1771,9 +1763,13 @@ def bool_idx_to_nonzero(fgraph, node):
new_idxs.append(idx)
if isinstance(node.op, AdvancedSubtensor):
new_out = node.op(x, *new_idxs)
new_out = x[tuple(new_idxs)]
else:
new_out = node.op(x, y, *new_idxs)
new_out = (
x[tuple(new_idxs)].set(y)
if node.op.set_instead_of_inc
else x[tuple(new_idxs)].inc(y)
)
return [copy_stack_trace(node.outputs[0], new_out)]
......@@ -1822,7 +1818,8 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node):
):
return None
x, y, *idxs = diag_x.owner.inputs
x, y, *idx_variables = diag_x.owner.inputs
idxs = indices_from_subtensor(idx_variables, diag_x.owner.op.idx_list)
if not (
x.type.ndim >= 2
......@@ -1838,7 +1835,7 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node):
# Check all non-axis indices are full slices
axis = {op.axis1, op.axis2}
if not all(is_full_slice(idx) for i, idx in enumerate(idxs) if i not in axis):
if not all(idx == slice(None) for i, idx in enumerate(idxs) if i not in axis):
return None
# Check axis indices are arange we would expect from setting on the diagonal
......
......@@ -8,7 +8,6 @@ from pytensor import Variable
from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
Alloc,
Join,
......@@ -31,7 +30,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize,
)
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
from pytensor.tensor.rewriting.subtensor import register_useless
from pytensor.tensor.shape import (
Shape,
SpecifyShape,
......@@ -50,7 +49,6 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable
......@@ -71,7 +69,7 @@ def _axis_is_indexed_by_basic_index(
) -> bool:
if isinstance(axis, int):
axis = (axis,)
return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
return any(ax < len(idxs) and not idxs[ax] == slice(None) for ax in axis)
def _lift_subtensor_non_axis(
......@@ -83,7 +81,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable: TensorVariable,
) -> None | list[TensorVariable]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
real_indices = [idx for idx in idx_tuple if not idx == slice(None)]
if len(real_indices) > 1 and variable.type.ndim > 1:
# Split the subtensor
idx_to_keep = idx_tuple[axis]
......@@ -206,7 +204,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if len(idx_tuple) > batch_ndim:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
if all(is_full_slice(idx) for idx in batch_indices):
if all(idx == slice(None) for idx in batch_indices):
# No batch indices, nothing to do
return None
elem_with_batch_indices = elem[batch_indices]
......@@ -240,7 +238,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict=False,
)
):
if is_full_slice(dim_idx):
if dim_idx == slice(None):
# Full slice can be safely applied to all inputs
continue
......@@ -429,7 +427,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if i in expanded_axes:
if isinstance(idx_item, slice):
# Slice could be keeping or dropping this dimension
if is_full_slice(idx_item):
if idx_item == slice(None):
# A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim
continue
......@@ -648,10 +646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices = get_idx_list(node.inputs, node.op.idx_list)
if any(
isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
for index in indices
):
if any(isinstance(index, slice) for index in indices):
return False
new_obj_arg = obj_arg[indices]
......@@ -702,15 +697,12 @@ def local_subtensor_make_vector(fgraph, node):
(idx,) = idxs
if isinstance(idx, ps.ScalarType | TensorType):
old_idx, idx = idx, node.inputs[1]
assert idx.type.is_super(old_idx)
if isinstance(idx, int):
idx = node.inputs[1]
elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1]
if isinstance(idx, int | np.integer):
return [x.owner.inputs[idx]]
elif isinstance(idx, Variable):
if isinstance(idx, Variable):
if idx.ndim == 0:
try:
v = get_underlying_scalar_constant_value(
......@@ -833,8 +825,6 @@ def local_subtensor_shape_constant(fgraph, node):
except NotScalarConstantError:
return False
assert idx_val != np.newaxis
if not isinstance(shape_arg.type, TensorType):
return False
......@@ -871,22 +861,24 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return None
x, *adv_idxs = adv_subtensor.owner.inputs
x, *adv_index_vars = adv_subtensor.owner.inputs
adv_idxs = indices_from_subtensor(adv_index_vars, adv_subtensor.owner.op.idx_list)
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if any(
(
isinstance(adv_idx.type, NoneTypeT)
or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
if (
not all(
(
(isinstance(adv_idx, TensorVariable) and adv_idx.type.dtype != "bool")
or (isinstance(adv_idx, slice) and adv_idx == slice(None))
)
for adv_idx in adv_idxs
)
for adv_idx in adv_idxs
) or _non_consecutive_adv_indexing(adv_idxs):
return None
for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
# We already made sure there were only None slices besides integer indexes
if isinstance(adv_idx.type, TensorType):
if isinstance(adv_idx, TensorVariable):
break
else: # no-break
# Not sure if this should ever happen, but better safe than sorry
......@@ -909,7 +901,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
x_after_index_lift = expand_dims(x_indexed, dropped_dims)
x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_index_vars)
copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
......
......@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
@register_uncanonicalize
......@@ -193,60 +193,42 @@ def local_dimshuffle_subtensor(fgraph, node):
if not all(broadcastable[i] for i in missing_dims):
return False
# create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list
# (check in slice!).
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]]
# create a new index tuple for a new Subtensor
# Reconstruct the full indices from the subtensor node, then replace
# dimensions that are being dropped by dimshuffle with scalar index 0
x = input_.owner.inputs[0]
indices = list(
indices_from_subtensor(
input_.owner.inputs[1:], input_.owner.op.idx_list
)
)
zero = constant(0)
j = 0
slice_i = -1
subtensor_removed_dims = 0
for i, idx in enumerate(input_.owner.op.idx_list):
# Track which output dimension each index corresponds to
# Scalar indices remove dimensions, slices keep them
output_dim = 0
for i, idx in enumerate(indices):
if isinstance(idx, slice):
slice_i += 1
if slice_i in missing_dims:
# Missing dim is a slice(None), remove by indexing by 0
# This slice produces an output dimension
if output_dim in missing_dims:
# This output dimension is being dropped, so replace slice with scalar
if idx == slice(None):
new_idx_list[i] = zero
new_inputs += [zero]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
indices[i] = zero
else:
if idx.start is None:
start = zero
else:
start = input_.owner.inputs[1 + j]
j += 1
new_idx_list[i] = start
new_inputs += [start]
# Ignore useless stop and step input if there is one
for slice_attr in ("stop", "step"):
if getattr(idx, slice_attr) is not None:
j += 1
# Keep non-dropped slice inputs
else:
for slice_attr in ("start", "stop", "step"):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# Keep non-dropped non-slice inputs
# Use the start of the slice (or 0 if None)
indices[i] = idx.start if idx.start is not None else zero
output_dim += 1
# Scalar indices don't contribute to output dimensions
# Handle trailing dimensions that weren't explicitly indexed
for input_dim in range(len(indices), x.ndim):
if output_dim in missing_dims:
# This unindexed dimension is being dropped, index with 0
indices.append(zero)
else:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
subtensor_removed_dims += 1
# Verify the trailing dimensions the subtensor didn't look at.
for idx in range(len(input_.owner.op.idx_list), new_inputs[0].ndim):
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)]
# This unindexed dimension is kept, index with slice(None)
indices.append(slice(None))
output_dim += 1
return [x[tuple(indices)]]
return False
import logging
import sys
import warnings
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Sequence
from itertools import chain, groupby, zip_longest
from typing import cast, overload
from typing import TypeVar, cast, overload
import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple
......@@ -15,7 +15,6 @@ from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
......@@ -38,117 +37,114 @@ from pytensor.tensor.basic import (
)
from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, clip
from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
from pytensor.tensor.shape import (
Reshape,
Shape_i,
specify_broadcastable,
)
from pytensor.tensor.type import (
TensorType,
bscalar,
complex_dtypes,
cscalar,
discrete_dtypes,
dscalar,
fscalar,
integer_dtypes,
iscalar,
lscalar,
tensor,
ubscalar,
uiscalar,
ulscalar,
uwscalar,
wscalar,
zscalar,
)
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneSliceConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import unzip
# Module-level logger for this module.
_logger = logging.getLogger("pytensor.tensor.subtensor")

# Type registries consulted by `index_vars_to_types` when validating index
# entries for a basic `Subtensor`'s `idx_list`.

# Floating-point scalar types: rejected as indices (raise `TypeError`).
invalid_scal_types = (ps.float64, ps.float32, ps.float16)
# Integer scalar types: accepted as index entries.
scal_types = (
    ps.int64,
    ps.int32,
    ps.int16,
    ps.int8,
    ps.uint64,
    ps.uint32,
    ps.uint16,
    ps.uint8,
)
# Fully-broadcastable (0-d) integer tensor types: accepted and converted to
# the equivalent scalar type by `index_vars_to_types`.
tensor_types = (
    lscalar,
    iscalar,
    wscalar,
    bscalar,
    ulscalar,
    uiscalar,
    uwscalar,
    ubscalar,
)
# Float/complex tensor scalar types: rejected as indices (raise `TypeError`).
invalid_tensor_types = (
    fscalar,
    dscalar,
    cscalar,
    zscalar,
)
T = TypeVar("T")
def flatten_index_variables(
idx_vars: Sequence[T | None | slice],
) -> tuple[list[int | slice], list[T]]:
counter = 0
idx_list: list[int | slice] = []
flat_vars = []
for idx_var in idx_vars:
if isinstance(idx_var, slice):
slice_idx_list: list[None | int] = []
for arg_entry in (idx_var.start, idx_var.stop, idx_var.step):
if arg_entry is None or (
isinstance(arg_entry, Variable)
and isinstance(arg_entry.type, NoneTypeT)
):
slice_idx_list.append(None)
else:
flat_vars.append(arg_entry)
slice_idx_list.append(counter)
counter += 1
idx_list.append(slice(*slice_idx_list))
else:
flat_vars.append(idx_var)
idx_list.append(counter)
counter += 1
return idx_list, flat_vars
def unflatten_index_variables(
    flat_indices: "Sequence[T]",
    idx_list: "Sequence[slice | int]",
) -> "tuple[slice | T, ...]":
    """Reverse :func:`flatten_index_variables`.

    Parameters
    ----------
    flat_indices
        The flat list of index entries (e.g. a node's index inputs).
    idx_list
        Positional encoding: each entry is either an ``int`` position into
        `flat_indices`, or a ``slice`` whose start/stop/step components are
        ``None`` or such positions.

    Returns
    -------
    tuple
        The reconstructed index tuple, mixing entries of `flat_indices` and
        ``slice`` objects built from them.

    Raises
    ------
    ValueError
        If an `idx_list` entry is neither an ``int`` nor a ``slice``.
    """
    indices: "list[T | slice]" = []
    for idx_entry in idx_list:
        if isinstance(idx_entry, int):
            indices.append(flat_indices[idx_entry])
        elif isinstance(idx_entry, slice):
            start, stop, step = idx_entry.start, idx_entry.stop, idx_entry.step
            # Use the unpacked components directly (instead of re-reading the
            # slice attributes) when translating positions back to entries.
            indices.append(
                slice(
                    None if start is None else flat_indices[start],
                    None if stop is None else flat_indices[stop],
                    None if step is None else flat_indices[step],
                )
            )
        else:
            raise ValueError(f"idx_entry must be int or slice, got {type(idx_entry)}")
    return tuple(indices)
def indices_from_subtensor(
    op_indices: "Sequence[Variable]",
    idx_list: "tuple[slice | int, ...]",
) -> "tuple[slice | Variable, ...]":
    """Recreate the index tuple from which a ``*Subtensor*`` ``Op`` was created.

    NOTE(review): the diffed text interleaved the old and new versions of this
    function (duplicate parameters, dead ``convert_indices`` helper); this is
    the reconstructed new-side implementation, which delegates to
    ``unflatten_index_variables``.

    Parameters
    ----------
    op_indices
        The flattened indices obtained from ``x.inputs``, when ``x`` is a
        ``*Subtensor*`` node.
    idx_list
        The values describing each dimension's index.  This is obtained from
        ``op.idx_list``.  Entries can be:

        - integer positions (indices into `op_indices`)
        - ``slice`` objects with int/None components

    Returns
    -------
    tuple[slice | Variable, ...]
        A tuple containing a mix of ``slice`` objects and ``Variable``
        objects.  Each element corresponds to one indexing dimension:

        - ``slice`` objects for slice-based indexing (e.g., ``x[1:3]``)
        - ``Variable`` objects for scalar or array-based indexing

        Callers should handle both types when iterating over the result.

    Example
    -------
        array, *op_indices = subtensor_node.inputs
        indices = indices_from_subtensor(op_indices, subtensor_node.op.idx_list)
    """
    return unflatten_index_variables(op_indices, idx_list)
def as_index_constant(
......@@ -182,7 +178,7 @@ def as_index_literal(idx: None) -> None: ...
@overload
def as_index_literal(idx: slice | SliceConstant) -> slice: ...
def as_index_literal(idx: slice) -> slice: ...
@overload
......@@ -194,14 +190,7 @@ def as_index_literal(idx: Variable): ...
def as_index_literal(
idx: None
| int
| np.integer
| slice
| SliceConstant
| ScalarConstant
| TensorConstant
| Variable,
idx: None | int | np.integer | slice | ScalarConstant | TensorConstant | Variable,
) -> int | np.integer | slice | None:
"""Convert a symbolic index element to its Python equivalent.
......@@ -224,9 +213,6 @@ def as_index_literal(
if not isinstance(idx, Variable):
raise TypeError(f"Not an index element: {idx}")
if isinstance(idx.type, NoneTypeT):
return None
if isinstance(idx, ScalarConstant):
return cast(int, idx.data)
......@@ -240,13 +226,6 @@ def as_index_literal(
if isinstance(idx, TensorConstant):
return cast(int, idx.data.item())
if isinstance(idx, SliceConstant):
return cast(slice, idx.data)
if isinstance(idx.type, SliceType):
assert idx.owner is not None
return slice(*map(as_index_literal, idx.owner.inputs))
# Other kinds of variables are not supported
raise NotScalarConstantError()
......@@ -275,10 +254,8 @@ def get_canonical_form_slice(
) -> tuple[slice | TensorVariable, int | TensorVariable]:
"""Convert indices or slices to canonical form.
Scalar integer indices or python Slices with Scalar/None attributes
used in basic Subtensor Ops are supported.
Symbolic slices (of SliceType) or vector indices
used in advanced Subtensor Ops are not supported.
Handles slice objects with ScalarVariable (including ScalarConstant) or None components.
Vector indices and advanced indexing operations are handled separately by AdvancedSubtensor.
Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy.
......@@ -492,16 +469,20 @@ def get_canonical_form_slice(
return slice(nw_start, nw_stop, nw_step), 1
def range_len(slc):
"""Length of a `range` object.
def slice_len(slc, n):
"""Compute the length of a slice for an array of a given length.
We're essentially computing `len(range(*slc.indices(n)))`.
Adapted from CPython.
"""
from pytensor.tensor import and_, gt, lt, switch
# TODO: Do we need to do this or should we expect `slc` to already be canonicalized?
canon_slc, _ = get_canonical_form_slice(slc, n)
start, stop, step = tuple(
as_index_constant(a) for a in [slc.start, slc.stop, slc.step]
as_index_constant(a) for a in [canon_slc.start, canon_slc.stop, canon_slc.step]
)
return switch(
and_(gt(step, 0), lt(start, stop)),
......@@ -514,31 +495,6 @@ def range_len(slc):
)
def slice_len(slc, n):
    """Compute the length of a slice applied to an array of length ``n``.

    We're essentially computing ``len(range(*slc.indices(n)))``.
    """
    # TODO: Do we need to do this or should we expect `slc` to already be
    # canonicalized?
    canonical_slc, _direction = get_canonical_form_slice(slc, n)
    return range_len(canonical_slc)
def is_basic_idx(idx):
    """Determine if an index is of the NumPy basic type.

    XXX: This only checks a single index, so an integer is *not* considered a
    basic index, because--depending on the other indices its used with--an
    integer can indicate advanced indexing.
    """
    # Literal slices and `None` (newaxis) are always basic.
    if isinstance(idx, slice) or idx is None:
        return True
    # Symbolic variables are basic when their type is a slice or NoneType.
    idx_type = getattr(idx, "type", None)
    return isinstance(idx_type, SliceType | NoneTypeT)
def basic_shape(shape, indices):
r"""Computes the shape resulting from basic NumPy indexing.
......@@ -557,25 +513,8 @@ def basic_shape(shape, indices):
for n, idx in zip(shape[: len(indices)], indices, strict=True):
if isinstance(idx, slice):
res_shape += (slice_len(idx, n),)
elif isinstance(getattr(idx, "type", None), SliceType):
if idx.owner is None:
if not isinstance(idx, Constant):
# This is an input slice, we can't reason symbolically on it.
# We don't even know if we will get None entries or integers
res_shape += (None,)
continue
else:
sl: slice = idx.data
slice_inputs = (sl.start, sl.stop, sl.step)
elif isinstance(idx.owner.op, MakeSlice):
slice_inputs = idx.owner.inputs
else:
raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
res_shape += (slice_len(slice(*slice_inputs), n),)
elif idx is None:
res_shape += (ps.ScalarConstant(ps.int64, 1),)
elif isinstance(getattr(idx, "type", None), NoneTypeT):
res_shape += (ps.ScalarConstant(ps.int64, 1),)
else:
raise ValueError(f"Invalid index type: {idx}")
return res_shape
......@@ -593,14 +532,12 @@ def group_indices(indices):
"""
idx_groups = []
dim_num = -1
for basic, grp_indices in groupby(indices, key=is_basic_idx):
for basic, grp_indices in groupby(indices, key=lambda x: isinstance(x, slice)):
enum_grp_indices = []
for idx in grp_indices:
# We "zip" the dimension number to each index, which means we can't
# count indices that add new axes
if (idx is not None) and not isinstance(
getattr(idx, "type", None), NoneTypeT
):
if idx is not None:
dim_num += 1
enum_grp_indices.append((dim_num, idx))
......@@ -647,7 +584,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
idx_groups = sorted(idx_groups, key=lambda x: x[0])
idx_groups = groupby(
chain.from_iterable(d_idx for _, d_idx in idx_groups),
key=lambda x: is_basic_idx(x[1]),
key=lambda x: isinstance(x[1], slice),
)
for basic, grp_dim_indices in idx_groups:
......@@ -707,72 +644,6 @@ def get_slice_elements(
return ret
def index_vars_to_types(entry, slice_ok=True):
    r"""Change references to `Variable`s into references to `Type`s.

    The `Subtensor.idx_list` field is unique to each `Subtensor` instance. It
    is not unique to each `Apply` node, so it should not refer to specific
    `Variable`s.

    Parameters
    ----------
    entry
        An index entry: a scalar `Variable`/`Type`, a 0-d tensor
        `Variable`/`Type`, or (when `slice_ok`) a `slice` whose components
        are such entries or None.
    slice_ok
        Whether a `slice` is acceptable; disabled when recursing into a
        slice's own start/stop/step, since slices cannot nest.

    TODO WRITEME: This function also accepts an `entry` already being a `Type`;
    when would that happen?
    """
    # Boolean arrays/variables are advanced indices; they cannot appear in a
    # basic-Subtensor idx_list.
    if (
        isinstance(entry, np.ndarray | Variable)
        and hasattr(entry, "dtype")
        and entry.dtype == "bool"
    ):
        raise AdvancedIndexingError("Invalid index type or slice for Subtensor")

    if isinstance(entry, Variable) and (
        entry.type in invalid_scal_types or entry.type in invalid_tensor_types
    ):
        raise TypeError("Expected an integer")

    if isinstance(entry, Variable) and entry.type in scal_types:
        return entry.type
    elif isinstance(entry, Type) and entry in scal_types:
        return entry

    # 0-d tensors (all dims broadcastable) are accepted and mapped to the
    # corresponding scalar type.
    if (
        isinstance(entry, Variable)
        and entry.type in tensor_types
        and all(entry.type.broadcastable)
    ):
        return ps.get_scalar_type(entry.type.dtype)
    elif isinstance(entry, Type) and entry in tensor_types and all(entry.broadcastable):
        return ps.get_scalar_type(entry.dtype)
    elif slice_ok and isinstance(entry, slice):
        a = entry.start
        b = entry.stop
        c = entry.step

        if a is not None:
            slice_a = index_vars_to_types(a, False)
        else:
            slice_a = None

        if b is not None and b != sys.maxsize:
            # The special "maxsize" case is probably not needed here,
            # as slices containing maxsize are not generated by
            # __getslice__ anymore.
            slice_b = index_vars_to_types(b, False)
        else:
            slice_b = None

        if c is not None:
            slice_c = index_vars_to_types(c, False)
        else:
            slice_c = None

        return slice(slice_a, slice_b, slice_c)
    elif isinstance(entry, int | np.integer):
        # Raw Python/NumPy ints are rejected here — NOTE(review): presumably
        # constant indices are converted before reaching this point; confirm
        # with callers.
        raise TypeError()
    else:
        raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
def get_constant_idx(
idx_list, inputs, allow_partial=False, only_process_constants=False, elemwise=True
):
......@@ -803,7 +674,7 @@ def get_constant_idx(
>>> a = matrix("a")
>>> b = a[v, 1:3]
>>> b.owner.op.idx_list
(ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
(0, slice(1, 2, None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(np.int64(1), np.int64(3), None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
......@@ -835,15 +706,11 @@ def get_constant_idx(
return list(map(conv, real_idx))
def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
    """Coerce `a` into a `ScalarType` variable.

    `ps.as_scalar` cannot handle tensor types (teaching it about them would
    create a circular import), so tensor variables are routed through
    `scalar_from_tensor` first; anything else goes straight to
    `ps.as_scalar`.
    """
    is_tensor_variable = isinstance(a, Variable) and isinstance(a.type, TensorType)
    if is_tensor_variable:
        return pytensor.tensor.scalar_from_tensor(a)
    return ps.as_scalar(a)
def as_scalar_index_variable(idx) -> ps.ScalarVariable:
    """Convert `idx` to a scalar variable, enforcing an integer dtype.

    Raises
    ------
    TypeError
        If the converted scalar's dtype is not an integer dtype.
    """
    scalar_idx = ps.as_scalar(idx)
    if scalar_idx.type.dtype in integer_dtypes:
        return scalar_idx  # type: ignore[no-any-return]
    raise TypeError("basic indices must be integers")
def slice_static_length(slc, dim_length):
......@@ -864,17 +731,71 @@ def slice_static_length(slc, dim_length):
return len(range(*slice(*entries).indices(dim_length)))
class Subtensor(COp):
class BaseSubtensor:
"""Base class for Subtensor operations that handles idx_list and hash/equality."""
def __init__(self, idx_list: Sequence[int | slice]):
index_counter = -1
for idx_entry in idx_list:
if isinstance(idx_entry, int):
if idx_entry != (index_counter + 1):
raise ValueError(
f"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter = idx_entry
elif isinstance(idx_entry, slice):
for slice_idx_entry in (
idx_entry.start,
idx_entry.stop,
idx_entry.step,
):
if slice_idx_entry is not None:
if not isinstance(slice_idx_entry, int):
raise ValueError(
f"idx_list slice entries must be None or integer, got {slice_idx_entry} in {idx_entry}"
)
if slice_idx_entry != (index_counter + 1):
raise ValueError(
f"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter = slice_idx_entry
else:
raise ValueError(
f"idx_list entries must be int or slice, got {idx_entry}"
)
self.n_index_vars = index_counter + 1
self.idx_list = tuple(idx_list)
def _hashable_idx_list(self):
"""Return a hashable version of idx_list (slices converted to tuples).
Slices are not hashable in Python < 3.12, so we convert them to tuples.
"""
return tuple(
(slice, entry.start, entry.stop, entry.step)
if isinstance(entry, slice)
else entry
for entry in self.idx_list
)
def __hash__(self):
# Temporary workaround: slices are hashable in Python 3.12+
props_values = tuple(
self._hashable_idx_list() if prop == "idx_list" else getattr(self, prop)
for prop in self.__props__
)
return hash((type(self), props_values))
class Subtensor(BaseSubtensor, COp):
"""Basic NumPy indexing operator."""
check_input = False
view_map = {0: [0]}
_f16_ok = True
__props__ = ("idx_list",)
def __init__(self, idx_list):
# TODO: Provide the type of `self.idx_list`
self.idx_list = tuple(map(index_vars_to_types, idx_list))
__hash__ = BaseSubtensor.__hash__
def make_node(self, x, *inputs):
"""
......@@ -887,23 +808,16 @@ class Subtensor(COp):
"""
x = as_tensor_variable(x)
inputs = tuple(as_nontensor_scalar(a) for a in inputs)
inputs = tuple(as_scalar_index_variable(a) for a in inputs)
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, Type)
input_positions = get_slice_elements(
idx_list, lambda entry: isinstance(entry, int)
)
assert len(inputs) == len(input_types)
for input, expected_type in zip(inputs, input_types, strict=True):
if not expected_type.is_super(input.type):
raise TypeError(
f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
)
assert len(inputs) == len(input_positions)
padded = [
*indices_from_subtensor(inputs, self.idx_list),
......@@ -924,13 +838,10 @@ class Subtensor(COp):
def perform(self, node, inputs, out_):
(out,) = out_
x = inputs[0]
cdata = get_idx_list(inputs, self.idx_list)
if len(cdata) == 1:
cdata = cdata[0]
x, *index_variables = inputs
out[0] = np.asarray(x.__getitem__(cdata))
cdata = unflatten_index_variables(index_variables, self.idx_list)
out[0] = np.asarray(x.__getitem__(tuple(cdata)))
def infer_shape(self, fgraph, node, shapes):
def _is_constant(const, x):
......@@ -978,8 +889,7 @@ class Subtensor(COp):
def grad(self, inputs, grads):
(gz,) = grads
x = inputs[0]
rest = inputs[1:]
x, *index_variables = inputs
if x.dtype in discrete_dtypes:
first = x.zeros_like(dtype=config.floatX)
else:
......@@ -988,43 +898,28 @@ class Subtensor(COp):
# We have an optimization that will convert this to a
# set subtensor here at:
# pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest)
return [first, *(disconnected_type() for _ in range(len(rest)))]
first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *index_variables)
return [first, *(disconnected_type() for _ in range(len(index_variables)))]
def connection_pattern(self, node):
rval = [[True], *([False] for _ in node.inputs[1:])]
_x, *index_variables = node.inputs
rval = [[True], *([False] for _ in index_variables)]
return rval
def __hash__(self):
    # Build a hashable rendering of idx_list: slices (unhashable before
    # Python 3.12) are replaced by their (start, stop, step) triple.
    msg = []
    for entry in self.idx_list:
        if isinstance(entry, slice):
            msg += [(entry.start, entry.stop, entry.step)]
        else:
            msg += [entry]

    idx_list = tuple(msg)
    # backport
    # idx_list = tuple((entry.start, entry.stop, entry.step)
    # if isinstance(entry, slice)
    # else entry
    # for entry in self.idx_list)
    return hash(idx_list)
@staticmethod
def str_from_slice(entry):
if entry.step:
if entry.step is not None:
return ":".join(
(
"start" if entry.start else "",
"stop" if entry.stop else "",
"start" if entry.start is not None else "",
"stop" if entry.stop is not None else "",
"step",
)
)
if entry.stop:
return f"{'start' if entry.start else ''}:stop"
if entry.start:
if entry.stop is not None:
return f"{'start' if entry.start is not None else ''}:stop"
if entry.start is not None:
return "start:"
return ":"
......@@ -1107,12 +1002,7 @@ class Subtensor(COp):
return pos[1]
def init_entry(entry, depth=0):
if isinstance(entry, np.integer | int):
init_cmds.append(f"subtensor_spec[{spec_pos()}] = {entry};")
inc_spec_pos(1)
if depth == 0:
is_slice.append(0)
elif isinstance(entry, Type):
if isinstance(entry, int):
init_cmds.append(
f"subtensor_spec[{spec_pos()}] = {inputs[input_pos()]};"
)
......@@ -1375,7 +1265,58 @@ class Subtensor(COp):
# (they should be defaulted to zeros_like by the global R_op)
if eval_points[0] is None:
return [None]
return self(eval_points[0], *inputs[1:], return_list=True)
_x, *index_variables = inputs
return self(eval_points[0], *index_variables, return_list=True)
def basic_subtensor(x, *index_variables):
    """Build a basic-indexing graph ``x[index_variables]``.

    The symbolic indices are split by `flatten_index_variables` into a
    static ``idx_list`` template plus a flat sequence of scalar index
    inputs, which together parametrize a `Subtensor` Op.
    """
    idx_list, flat_index_vars = flatten_index_variables(index_variables)
    return Subtensor(idx_list)(x, *flat_index_vars)
@_get_vector_length.register(Subtensor)  # type: ignore
def _get_vector_length_Subtensor(op, var):
    """Return the static length of a vector obtained by slicing, if knowable.

    Raises
    ------
    ValueError
        If the slice bounds cannot be resolved to constants, so the length
        cannot be determined at graph-construction time.
    """
    # If we take a slice, we know how many elements it will result in
    # TODO: We can cover more `*Subtensor` cases.
    try:
        # Reconstruct the concrete indices from the Op's template + inputs.
        indices = get_idx_list(var.owner.inputs, var.owner.op.idx_list)
        # Resolve each bound of the (first) slice to a Python constant;
        # a None bound stays None.
        start = (
            None
            if indices[0].start is None
            else get_scalar_constant_value(indices[0].start)
        )
        stop = (
            None
            if indices[0].stop is None
            else get_scalar_constant_value(indices[0].stop)
        )
        step = (
            None
            if indices[0].step is None
            else get_scalar_constant_value(indices[0].step)
        )

        # NOTE(review): when both bounds are None (a full `[:]` slice) this
        # compares None == None and returns 0 — confirm such slices cannot
        # reach this point.
        if start == stop:
            return 0

        arg_len = get_vector_length(var.owner.inputs[0])
        return len(range(*slice(start, stop, step).indices(arg_len)))

    except (ValueError, NotScalarConstantError):
        raise ValueError(f"Length of {var} cannot be determined")
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
    """Vectorize a `Subtensor` whose index inputs are all unbatched.

    Batched leading dimensions of `x` are passed through untouched by
    prepending one full slice (``:``) per batch dimension to the Op's
    ``idx_list``. If any index input is itself batched, defer to the
    generic fallback.
    """
    # TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
    has_batched_index = any(idx.type.ndim > 0 for idx in batch_idxs)
    if has_batched_index:
        return vectorize_node_fallback(op, node, batch_x, *batch_idxs)

    core_x = node.inputs[0]
    n_batch_dims = batch_x.type.ndim - core_x.type.ndim
    batched_idx_list = (slice(None),) * n_batch_dims + op.idx_list
    return Subtensor(batched_idx_list).make_node(batch_x, *batch_idxs)
class SubtensorPrinter(Printer):
......@@ -1387,25 +1328,28 @@ class SubtensorPrinter(Printer):
input = inputs.pop(0)
sidxs = []
getattr(pstate, "precedence", None)
def process_slice_component(comp):
"""Process a slice component, returning string representation."""
if comp is None:
return ""
elif isinstance(comp, int):
with set_precedence(pstate):
return pstate.pprinter.process(inputs.pop(0))
else:
return str(comp)
for entry in idxs:
if isinstance(entry, ps.ScalarType):
if isinstance(entry, int):
with set_precedence(pstate):
sidxs.append(pstate.pprinter.process(inputs.pop()))
sidxs.append(pstate.pprinter.process(inputs.pop(0)))
elif isinstance(entry, slice):
if entry.start is None or entry.start == 0:
msg1 = ""
else:
msg1 = entry.start
if entry.stop is None or entry.stop == sys.maxsize:
msg2 = ""
else:
msg2 = entry.stop
msg1 = process_slice_component(entry.start)
msg2 = process_slice_component(entry.stop)
if entry.step is None:
msg3 = ""
else:
msg3 = f":{entry.step}"
msg3 = f":{process_slice_component(entry.step)}"
sidxs.append(f"{msg1}:{msg2}{msg3}")
......@@ -1418,322 +1362,83 @@ class SubtensorPrinter(Printer):
pprint.assign(Subtensor, SubtensorPrinter())
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
# TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs):
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
class IncSubtensor(BaseSubtensor, COp):
"""
Increment a subtensor.
old_x, *_ = node.inputs
batch_ndims = batch_x.type.ndim - old_x.type.ndim
new_idx_list = (slice(None),) * batch_ndims + op.idx_list
return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs)
This is like numpy's
x[i,j,k] += y
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
"""
Return x with the given subtensor overwritten by y.
It is used internally to implement the gradient on SubTensor.
Parameters
----------
x
Symbolic variable for the lvalue of = operation.
y
Symbolic variable for the rvalue of = operation.
tolerate_inplace_aliasing
See inc_subtensor for documentation.
Examples
--------
To replicate the numpy expression ``r[10:] = 5``, type
set_instead_of_inc
If True set the subtensor to the value instead of incrementing it by
that value.
.. code-block:: python
"""
from pytensor.tensor import set_subtensor, vector
check_input = False
__props__ = (
"idx_list",
"inplace",
"set_instead_of_inc",
"destroyhandler_tolerate_aliased",
)
__hash__ = BaseSubtensor.__hash__
r = vector("r")
new_r = set_subtensor(r[10:], 5)
def __init__(
    self,
    idx_list,
    inplace=False,
    set_instead_of_inc=False,
    destroyhandler_tolerate_aliased=None,
):
    """
    Parameters
    ----------
    idx_list
        Static index template (ints and slices of ints/None), validated and
        stored by `BaseSubtensor.__init__`.
    inplace
        If True, the output reuses (destroys) the first input's storage.
    set_instead_of_inc
        If True, behave like ``x[idx] = y`` instead of ``x[idx] += y``.
    destroyhandler_tolerate_aliased
        Optional sequence of ``(idx0, idx1)`` input-index pairs the destroy
        handler should tolerate being aliased; defaults to empty.
    """
    if destroyhandler_tolerate_aliased is None:
        destroyhandler_tolerate_aliased = ()
    super().__init__(idx_list)
    self.inplace = inplace
    if inplace:
        # Output 0 overwrites input 0.
        self.destroy_map = {0: [0]}
    # Normalize to a tuple of tuples: this attribute participates in
    # __props__ and thus in `BaseSubtensor.__hash__`, so the inner pairs
    # must themselves be hashable (callers may pass lists, e.g. [[0, 1]],
    # which would otherwise make the Op unhashable).
    self.destroyhandler_tolerate_aliased = tuple(
        tuple(pair) for pair in destroyhandler_tolerate_aliased
    )
    self.set_instead_of_inc = set_instead_of_inc
Consider using :meth:`pytensor.tensor.variable.TensorVariable.set` instead.
def __str__(self):
    """Render as ``SetSubtensor{...}`` or ``IncSubtensor{...}``."""
    if self.set_instead_of_inc:
        name = "SetSubtensor"
    else:
        name = "IncSubtensor"
    return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
"""
return inc_subtensor(
x,
y,
inplace,
set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
)
def make_node(self, x, y, *inputs):
"""
Parameters
----------
x
The tensor to increment.
y
The value to increment by.
inputs
The indeces/slices list to increment in combination with idx_list.
E.g. self._idx_list = (0, slice(1, None, None), 2, slice(3, None, 4))
tell to use inputs[0] as the first dim.
"""
x, y = map(as_tensor_variable, [x, y])
if y.ndim > x.ndim:
raise ValueError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs = tuple(map(as_scalar_index_variable, inputs))
def inc_subtensor(
x,
y,
inplace=False,
set_instead_of_inc=False,
tolerate_inplace_aliasing=False,
ignore_duplicates=False,
):
"""Update the value of an indexed array by a given amount.
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
depending on the value of `ignore_duplicates`.
Parameters
----------
x
The symbolic result of a Subtensor operation.
y
The amount by which to increment the array.
inplace
Don't use. PyTensor will do in-place operations itself, when possible.
set_instead_of_inc
If True, do a set_subtensor instead.
tolerate_inplace_aliasing:
Allow `x` and `y` to be views of a single underlying array even while
working in-place. For correct results, `x` and `y` must not be overlapping
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``.
Examples
--------
To replicate the expression ``r[10:] += 5``:
.. code-block:: python
from pytensor.tensor import ivector, inc_subtensor
r = ivector("r")
new_r = inc_subtensor(r[10:], 5)
To replicate the expression ``r[[0, 1, 0]] += 5``:
.. code-block:: python
r = ivector("r")
new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.inc` instead.
"""
# First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable.
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if y.ndim > x.ndim:
raise TypeError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
dim_offset = x.ndim - y.ndim
for dim in range(y.ndim):
if x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]:
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a SpecifyShape Op to make sure it is the case.
y = specify_broadcastable(y, dim)
if x.owner is None:
raise TypeError("x must be the result of a subtensor operation")
# retrieve idx_list from x.owner
if isinstance(x.owner.op, Subtensor):
if tolerate_inplace_aliasing:
destroyhandler_tolerate_aliased = [[0, 1]]
else:
destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(
x.owner.op.idx_list,
inplace,
set_instead_of_inc,
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased,
)
real_x = x.owner.inputs[0]
real_idxargs = x.owner.inputs[1:]
return the_op(real_x, y, *real_idxargs)
elif isinstance(x.owner.op, AdvancedSubtensor1):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1]
if ignore_duplicates:
the_op = AdvancedIncSubtensor(
inplace, set_instead_of_inc=set_instead_of_inc, ignore_duplicates=True
)
else:
the_op = AdvancedIncSubtensor1(
inplace, set_instead_of_inc=set_instead_of_inc
)
return the_op(real_x, y, ilist)
elif isinstance(x.owner.op, AdvancedSubtensor):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1:]
the_op = AdvancedIncSubtensor(
inplace,
set_instead_of_inc=set_instead_of_inc,
ignore_duplicates=ignore_duplicates,
)
return the_op(real_x, y, *ilist)
elif isinstance(x.owner.op, DimShuffle):
inner_x = x.owner.inputs[0]
# In the dimshuffle case, there are in fact two dimshuffles:
# one to make the indexed dimension the last one,
# and one to put it back where it was. So, in the case where we have
# inc_subtensor(x[:,i], y), the graph is actually
# inc_subtensor((x.T)[i].T, y).
# We could get all the way to x, and then get rid of the dimshuffles
# completely, but the problem is that advanced_inc_subtensor1 can only
# work on the first (outer-most, left-most) dimension of x,
# just like advanced_subtensor1.
# So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
# transpose y if it is not a scalar or a vector), but then we need to
# return something that has the same shape as x, not as x.T (inner_x).
# So re-apply the outer dimshuffle on the new inc_subtensor,
# and return advanced_inc_subtensor1(x.T, i, y.T).T.
# Get the dimshuffle pattern to apply to y.
x_order = x.owner.op.new_order
y_order = ["x"] * x.ndim
for i, v in enumerate(x_order):
if v != "x" and (v - dim_offset) >= 0:
y_order[v - dim_offset] = i
inner_incsubtensor = inc_subtensor(
inner_x,
y.dimshuffle(y_order),
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
# The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here,
# instead of reusing x.owner.op().
return inner_incsubtensor.dimshuffle(x.owner.op.new_order)
elif isinstance(x.owner.op, Reshape):
# This case happens when the indices are not arranged as a vector, but
# as a higher-dimensional array. This is handled by the subtensor
# by flattening this list, taking the subtensor, then reshaping the
# result.
inner_x = x.owner.inputs[0]
# Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want.
# We also explicitly duplicate y to its broadcasted shape
# before we partially flatten it to inner_x dimension. This is
# not strictly needed in all cases, but it is easier this way.
if y.ndim > 0:
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y = alloc(y, *[x.shape[i] for i in range(x.ndim)])
flattened_y = expanded_y.reshape(inner_x.shape)
else:
flattened_y = y
inner_incsubtensor = inc_subtensor(
inner_x,
flattened_y,
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
return inner_incsubtensor
else:
raise TypeError("x must be the result of a subtensor operation")
class IncSubtensor(COp):
"""
Increment a subtensor.
This is like numpy's
x[i,j,k] += y
It is used internally to implement the gradient on SubTensor.
Parameters
----------
set_instead_of_inc
If True set the subtensor to the value instead of incrementing it by
that value.
"""
check_input = False
__props__ = ("idx_list", "inplace", "set_instead_of_inc")
def __init__(
self,
idx_list,
inplace=False,
set_instead_of_inc=False,
destroyhandler_tolerate_aliased=None,
):
if destroyhandler_tolerate_aliased is None:
destroyhandler_tolerate_aliased = []
self.idx_list = list(map(index_vars_to_types, idx_list))
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
self.destroyhandler_tolerate_aliased = list(destroyhandler_tolerate_aliased)
self.set_instead_of_inc = set_instead_of_inc
def __hash__(self):
idx_list = tuple(
(entry.start, entry.stop, entry.step) if isinstance(entry, slice) else entry
for entry in self.idx_list
)
return hash((type(self), idx_list, self.inplace, self.set_instead_of_inc))
def __str__(self):
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
def make_node(self, x, y, *inputs):
"""
Parameters
----------
x
The tensor to increment.
y
The value to increment by.
inputs: TODO WRITEME
"""
x, y = map(as_tensor_variable, [x, y])
if y.ndim > x.ndim:
if len(inputs) != self.n_index_vars:
raise ValueError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs = tuple(map(as_nontensor_scalar, inputs))
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, Type)
)
if len(inputs) != len(input_types):
raise IndexError(
"Not enough inputs to fill in the Subtensor template.", inputs, idx_list
)
for input, expected_type in zip(inputs, input_types, strict=True):
if not expected_type.is_super(input.type):
raise TypeError(
f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}."
)
return Apply(self, (x, y, *inputs), [x.type()])
......@@ -1747,7 +1452,7 @@ class IncSubtensor(COp):
indices = tuple(
(
next(flat_indices_iterator)
if isinstance(entry, Type)
if isinstance(entry, int)
else slice(
None if entry.start is None else next(flat_indices_iterator),
None if entry.stop is None else next(flat_indices_iterator),
......@@ -1992,17 +1697,18 @@ class IncSubtensor(COp):
return [None]
# Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those
return self(eval_points[0], eval_points[1], *inputs[2:], return_list=True)
_x, _y, *index_variables = inputs
return self(eval_points[0], eval_points[1], *index_variables, return_list=True)
def connection_pattern(self, node):
rval = [[True], [True], *([False] for _ in node.inputs[2:])]
_x, _y, *index_variables = node.inputs
rval = [[True], [True], *([False] for _ in index_variables)]
return rval
def grad(self, inputs, grads):
(g_output,) = grads
x, y = inputs[:2]
idx_list = inputs[2:]
x, y, *index_variables = inputs
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
......@@ -2016,25 +1722,25 @@ class IncSubtensor(COp):
else:
if self.set_instead_of_inc:
gx = set_subtensor(
Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
Subtensor(idx_list=self.idx_list)(g_output, *index_variables),
pytensor.tensor.zeros_like(y),
)
else:
gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
gy = Subtensor(idx_list=self.idx_list)(g_output, *index_variables)
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]
return [gx, gy, *(disconnected_type() for _ in range(len(index_variables)))]
class IncSubtensorPrinter(SubtensorPrinter):
def process(self, r, pstate):
x, _y, *idx_args = r.owner.inputs
x, y, *index_variables = r.owner.inputs
res = self._process(r.owner.op.idx_list, [x, *idx_args], pstate)
res = self._process(r.owner.op.idx_list, [x, *index_variables], pstate)
with set_precedence(pstate, 1000):
y_str = pstate.pprinter.process(r.owner.inputs[1], pstate)
y_str = pstate.pprinter.process(y, pstate)
if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})"
......@@ -2095,9 +1801,13 @@ class AdvancedSubtensor1(COp):
# sparse_grad doesn't go in here since it only affects the output
# of the grad() method.
__props__ = ()
idx_list = (0,)
_f16_ok = True
check_input = False
def __hash__(self):
return hash(type(self))
def __init__(self, sparse_grad=False):
self.sparse_grad = sparse_grad
......@@ -2121,7 +1831,8 @@ class AdvancedSubtensor1(COp):
output_storage[0][0] = x.take(i, axis=0, out=None)
def connection_pattern(self, node):
rval = [[True], *([False] for _ in node.inputs[1:])]
_x, *index_variables = node.inputs
rval = [[True], *([False] for _ in index_variables)]
return rval
......@@ -2151,7 +1862,8 @@ class AdvancedSubtensor1(COp):
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs
_x, *index_variables = inputs
return self.make_node(eval_points[0], *index_variables).outputs
def infer_shape(self, fgraph, node, ishapes):
x, ilist = ishapes
......@@ -2245,13 +1957,17 @@ class AdvancedSubtensor1(COp):
advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(COp):
class AdvancedIncSubtensor1(BaseSubtensor, COp):
"""
Increments a subtensor using advanced slicing (list of index).
"""
__props__ = ("inplace", "set_instead_of_inc")
__props__ = (
"inplace",
"set_instead_of_inc",
)
idx_list = (0,)
check_input = False
params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)
......@@ -2267,8 +1983,20 @@ class AdvancedIncSubtensor1(COp):
if inplace:
self.destroy_map = {0: [0]}
def __hash__(self):
    # Identity is fully determined by the concrete type plus the two flags.
    return hash((type(self), self.inplace, self.set_instead_of_inc))
def clone_inplace(self):
return self.__class__(inplace=True, set_instead_of_inc=self.set_instead_of_inc)
return self.__class__(
inplace=True,
set_instead_of_inc=self.set_instead_of_inc,
)
def __str__(self):
if self.inplace:
......@@ -2494,7 +2222,8 @@ class AdvancedIncSubtensor1(COp):
def R_op(self, inputs, eval_points):
if None in eval_points[:2]:
return [None]
return self.make_node(eval_points[0], eval_points[1], *inputs[2:]).outputs
_x, _y, *index_variables = inputs
return self.make_node(eval_points[0], eval_points[1], *index_variables).outputs
def connection_pattern(self, node):
rval = [[True], [True], [False]]
......@@ -2527,15 +2256,8 @@ advanced_inc_subtensor1 = AdvancedIncSubtensor1()
advanced_set_subtensor1 = AdvancedIncSubtensor1(set_instead_of_inc=True)
def as_index_variable(idx):
if idx is None:
return NoneConst.clone()
if isinstance(idx, slice):
return make_slice(idx)
if isinstance(idx, Variable) and isinstance(idx.type, SliceType):
return idx
if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT):
return idx
def as_tensor_index_variable(idx):
"""Convert index to Variable form for advanced indexing."""
idx = as_tensor_variable(idx)
if idx.type.dtype not in discrete_dtypes:
raise TypeError("index must be integers or a boolean mask")
......@@ -2547,53 +2269,45 @@ def as_index_variable(idx):
return idx
def check_advanced_indexing_dimensions(input, idx_list):
    """Validate the boolean masks contained in an advanced-index list.

    Every boolean mask must match the shape of the `input` dimensions it
    covers (enforced by NumPy 0.13.0 and newer, but not by earlier
    versions). ``np.newaxis`` entries consume no input dimension; a boolean
    mask consumes one per mask dimension; every other entry consumes
    exactly one.

    Raises
    ------
    IndexError
        If a boolean mask's shape disagrees with the corresponding
        dimensions of `input`.
    """
    dim_seen = 0
    for index in idx_list:
        if index is np.newaxis:
            # newaxis adds an output dimension without consuming an input one
            continue
        if isinstance(index, np.ndarray) and index.dtype == "bool":
            for offset, mask_len in enumerate(index.shape):
                axis = dim_seen + offset
                if mask_len != input.shape[axis]:
                    raise IndexError(
                        "boolean index did not match indexed array "
                        f"along dimension {int(axis)}; dimension is "
                        f"{int(input.shape[axis])} but "
                        f"corresponding boolean dimension is {mask_len}"
                    )
            dim_seen += index.ndim
        else:
            dim_seen += 1
class AdvancedSubtensor(Op):
class AdvancedSubtensor(BaseSubtensor, COp):
"""Implements NumPy's advanced indexing."""
__props__ = ()
def make_node(self, x, *indices):
x = as_tensor_variable(x)
indices = tuple(map(as_index_variable, indices))
__props__ = ("idx_list",)
__hash__ = BaseSubtensor.__hash__
def c_code_cache_version(self):
    """C-code cache version, tied to the shared `Subtensor` helper version.

    An empty tuple (meaning "do not cache") is returned when the helper
    reports no version.
    """
    helper_version = Subtensor.helper_c_code_cache_version()
    return (3, helper_version) if helper_version else ()
def make_node(self, x, *index_variables):
if len(index_variables) != self.n_index_vars:
raise ValueError(
f"Expected {self.n_index_vars} inputs, got {len(index_variables)}"
)
x = as_tensor_variable(x)
index_variables = tuple(as_tensor_index_variable(a) for a in index_variables)
idx_list = self.idx_list
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
reconstructed_indices = unflatten_index_variables(index_variables, idx_list)
explicit_indices = []
new_axes = []
for idx in indices:
if isinstance(idx.type, TensorType) and idx.dtype == "bool":
for idx in reconstructed_indices:
if isinstance(idx, slice):
explicit_indices.append(idx)
elif hasattr(idx, "dtype") and idx.dtype == "bool":
if idx.type.ndim == 0:
raise NotImplementedError(
"Indexing with scalar booleans not supported"
)
# Check static shape aligned
axis = len(explicit_indices) - len(new_axes)
axis = len(explicit_indices)
indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
for j, (indexed_length, indexer_length) in enumerate(
zip(indexed_shape, idx.type.shape)
......@@ -2611,48 +2325,27 @@ class AdvancedSubtensor(Op):
if isinstance(idx, Constant):
nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
else:
# Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
# and seeing that other integer indices cannot possible match it
nonzero_indices = idx.nonzero()
explicit_indices.extend(nonzero_indices)
else:
if isinstance(idx.type, NoneTypeT):
new_axes.append(len(explicit_indices))
explicit_indices.append(idx)
if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
if len(explicit_indices) > x.type.ndim:
raise IndexError(
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed"
)
# Perform basic and advanced indexing shape inference separately
# Perform basic and advanced indexing shape inference separately (no newaxis)
basic_group_shape = []
advanced_indices = []
adv_group_axis = None
last_adv_group_axis = None
if new_axes:
expanded_x_shape_list = list(x.type.shape)
for new_axis in new_axes:
expanded_x_shape_list.insert(new_axis, 1)
expanded_x_shape = tuple(expanded_x_shape_list)
else:
expanded_x_shape = x.type.shape
for i, (idx, dim_length) in enumerate(
zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None))
):
if isinstance(idx.type, NoneTypeT):
basic_group_shape.append(1) # New-axis
elif isinstance(idx.type, SliceType):
if isinstance(idx, Constant):
basic_group_shape.append(slice_static_length(idx.data, dim_length))
elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
basic_group_shape.append(
slice_static_length(slice(*idx.owner.inputs), dim_length)
)
else:
# Symbolic root slice (owner is None), or slice operation we don't understand
basic_group_shape.append(None)
else: # TensorType
if isinstance(idx, slice):
basic_group_shape.append(slice_static_length(idx, dim_length))
else: # TensorType (advanced index)
# Keep track of advanced group axis
if adv_group_axis is None:
# First time we see an advanced index
......@@ -2687,14 +2380,15 @@ class AdvancedSubtensor(Op):
return Apply(
self,
[x, *indices],
[x, *index_variables],
[tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
)
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self.make_node(eval_points[0], *inputs[1:]).outputs
_x, *index_variables = inputs
return self.make_node(eval_points[0], *index_variables).outputs
def infer_shape(self, fgraph, node, ishapes):
def is_bool_index(idx):
......@@ -2703,30 +2397,32 @@ class AdvancedSubtensor(Op):
or getattr(idx, "dtype", None) == "bool"
)
indices = node.inputs[1:]
_x, *index_variables = node.inputs
full_indices = unflatten_index_variables(index_variables, self.idx_list)
index_shapes = []
for idx, ishape in zip(indices, ishapes[1:], strict=True):
# Mixed bool indexes are converted to nonzero entries
shape0_op = Shape_i(0)
if is_bool_index(idx):
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
elif isinstance(getattr(idx, "type", None), SliceType):
for idx in full_indices:
if isinstance(idx, slice):
index_shapes.append(idx)
else:
index_shapes.append(ishape)
shape0_op = Shape_i(0)
if is_bool_index(idx):
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
else:
input_shape_idx = (
index_variables.index(idx) + 1
) # +1 because ishapes[0] is x
index_shapes.append(ishapes[input_shape_idx])
res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
)
for i, res_dim_length in enumerate(res_shape):
if res_dim_length is None:
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
# We must compute the Op to find its shape
res_shape[i] = Shape_i(i)(node.out)
adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
adv_indices = [idx for idx in full_indices if not isinstance(idx, slice)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
# Special logic when the only advanced index group is of bool type.
......@@ -2737,7 +2433,7 @@ class AdvancedSubtensor(Op):
# Because there are no more advanced index groups, there is exactly
# one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing.
start_dim = indices.index(bool_index)
start_dim = full_indices.index(bool_index)
res_shape[start_dim] = bool_index.sum()
assert node.outputs[0].ndim == len(res_shape)
......@@ -2745,25 +2441,31 @@ class AdvancedSubtensor(Op):
def perform(self, node, inputs, out_):
(out,) = out_
check_advanced_indexing_dimensions(inputs[0], inputs[1:])
rval = inputs[0].__getitem__(tuple(inputs[1:]))
x, *index_variables = inputs
full_indices = unflatten_index_variables(index_variables, self.idx_list)
rval = x.__getitem__(tuple(full_indices))
# When there are no arrays, we are not actually doing advanced
# indexing, so __getitem__ will not return a copy.
# Since no view_map is set, we need to copy the returned value
if not any(
isinstance(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:]
isinstance(idx, np.ndarray) and idx.ndim > 0 for idx in full_indices
):
rval = rval.copy()
out[0] = rval
def connection_pattern(self, node):
rval = [[True], *([False] for _ in node.inputs[1:])]
_x, *index_variables = node.inputs
rval = [[True], *([False] for _ in index_variables)]
return rval
def grad(self, inputs, grads):
(gz,) = grads
x = inputs[0]
x, *index_variables = inputs
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
gx = x.zeros_like(dtype=config.floatX)
......@@ -2771,10 +2473,10 @@ class AdvancedSubtensor(Op):
raise NotImplementedError("No support for complex grad yet")
else:
gx = x.zeros_like()
rest = inputs[1:]
return [
advanced_inc_subtensor(gx, gz, *rest),
*(disconnected_type() for _ in range(len(rest))),
AdvancedIncSubtensor(self.idx_list)(gx, gz, *index_variables),
*(disconnected_type() for _ in range(len(index_variables))),
]
@staticmethod
......@@ -2791,7 +2493,7 @@ class AdvancedSubtensor(Op):
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their opriginal position.
output array, regardless of their original position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
......@@ -2806,11 +2508,21 @@ class AdvancedSubtensor(Op):
bool
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, *idxs = node.inputs
return _non_consecutive_adv_indexing(idxs)
indices = indices_from_subtensor(node.inputs[1:], node.op.idx_list)
return _non_consecutive_adv_indexing(indices)
class AdvancedSubtensorPrinter(SubtensorPrinter):
    """Pretty-printer for `AdvancedSubtensor` nodes.

    Reuses `SubtensorPrinter._process`, feeding it the op's encoded
    ``idx_list`` together with all node inputs (``x`` plus the flat
    numerical index variables) so the printed form shows ``x[indices]``.
    """

    def process(self, r, pstate):
        node = r.owner
        encoded_indices = node.op.idx_list
        return self._process(encoded_indices, node.inputs, pstate)
pprint.assign(AdvancedSubtensor, AdvancedSubtensorPrinter())
advanced_subtensor = AdvancedSubtensor()
def advanced_subtensor(x, *index_variables):
idx_list, flat_index_vars = flatten_index_variables(index_variables)
return AdvancedSubtensor(idx_list)(x, *flat_index_vars)
@_vectorize_node.register(AdvancedSubtensor)
......@@ -2830,30 +2542,33 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
# which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex
# Blockwise doesn't accept None or Slices types so we raise informative error here
# TODO: Implement these internally, so Blockwise is always a safe fallback
if any(not isinstance(idx, TensorVariable) for idx in idxs):
raise NotImplementedError(
"Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
"and slices or newaxis is currently not supported."
)
else:
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
# Otherwise we just need to add None slices for every new batch dim
x_batch_ndim = batch_x.type.ndim - x.type.ndim
empty_slices = (slice(None),) * x_batch_ndim
return op.make_node(batch_x, *empty_slices, *batch_idxs)
new_idx_list = (slice(None),) * x_batch_ndim + op.idx_list
return type(op)(new_idx_list).make_node(batch_x, *batch_idxs)
class AdvancedIncSubtensor(Op):
class AdvancedIncSubtensor(BaseSubtensor, Op):
"""Increments a subtensor using advanced indexing."""
__props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates")
__props__ = (
"idx_list",
"inplace",
"set_instead_of_inc",
"ignore_duplicates",
)
__hash__ = BaseSubtensor.__hash__
def __init__(
self, inplace=False, set_instead_of_inc=False, ignore_duplicates=False
self,
idx_list,
inplace=False,
set_instead_of_inc=False,
ignore_duplicates=False,
):
super().__init__(idx_list)
self.set_instead_of_inc = set_instead_of_inc
self.inplace = inplace
if inplace:
......@@ -2867,25 +2582,27 @@ class AdvancedIncSubtensor(Op):
else "AdvancedIncSubtensor"
)
def make_node(self, x, y, *inputs):
def make_node(self, x, y, *index_variables):
if len(index_variables) != self.n_index_vars:
raise ValueError(
f"Expected {self.n_index_vars} tensor inputs but got {len(index_variables)}"
)
index_variables = tuple(
as_tensor_index_variable(idx) for idx in index_variables
)
x = as_tensor_variable(x)
y = as_tensor_variable(y)
new_inputs = []
for inp in inputs:
if isinstance(inp, list | tuple):
inp = as_tensor_variable(inp)
new_inputs.append(inp)
return Apply(
self,
(x, y, *new_inputs),
[x, y, *index_variables],
[x.type()],
)
def perform(self, node, inputs, out_):
x, y, *indices = inputs
x, y, *index_variables = inputs
check_advanced_indexing_dimensions(x, indices)
full_indices = unflatten_index_variables(index_variables, self.idx_list)
(out,) = out_
if not self.inplace:
......@@ -2894,28 +2611,29 @@ class AdvancedIncSubtensor(Op):
out[0] = x
if self.set_instead_of_inc:
out[0][tuple(indices)] = y
out[0][tuple(full_indices)] = y
elif self.ignore_duplicates:
out[0][tuple(indices)] += y
out[0][tuple(full_indices)] += y
else:
np.add.at(out[0], tuple(indices), y)
np.add.at(out[0], tuple(full_indices), y)
    def infer_shape(self, fgraph, node, ishapes):
        # An inc/set-subtensor always returns a tensor with the shape of the
        # tensor being updated (`x`, the first input); the indices and `y`
        # only determine *where* values are written, not the output shape.
        return [ishapes[0]]
def connection_pattern(self, node):
rval = [[True], [True], *([False] for _ in node.inputs[2:])]
_x, _y, *index_variables = node.inputs
rval = [[True], [True], *([False] for _ in index_variables)]
return rval
def R_op(self, inputs, eval_points):
if None in eval_points[:2]:
return [None]
return self.make_node(eval_points[0], eval_points[1], *inputs[2:]).outputs
_x, _y, *index_variables = inputs
return self.make_node(eval_points[0], eval_points[1], *index_variables).outputs
def grad(self, inpt, output_gradients):
x, y = inpt[:2]
idxs = inpt[2:]
x, y, *index_variables = inpt
(outgrad,) = output_gradients
if x.dtype in discrete_dtypes:
# The output dtype is the same as x
......@@ -2928,21 +2646,22 @@ class AdvancedIncSubtensor(Op):
raise NotImplementedError("No support for complex grad yet")
else:
if self.set_instead_of_inc:
gx = advanced_set_subtensor(outgrad, y.zeros_like(), *idxs)
gx = (
type(self)(self.idx_list, set_instead_of_inc=True)
.make_node(outgrad, y.zeros_like(), *index_variables)
.outputs[0]
)
else:
gx = outgrad
gy = advanced_subtensor(outgrad, *idxs)
gy = (
AdvancedSubtensor(self.idx_list)
.make_node(outgrad, *index_variables)
.outputs[0]
)
# Make sure to sum gy over the dimensions of y that have been
# added or broadcasted
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, *(disconnected_type() for _ in range(len(idxs)))]
@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)
return [gx, gy, *(disconnected_type() for _ in range(len(index_variables)))]
@staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool:
......@@ -2951,7 +2670,7 @@ class AdvancedIncSubtensor(Op):
This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their opriginal position.
output array, regardless of their original position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
......@@ -2966,16 +2685,257 @@ class AdvancedIncSubtensor(Op):
bool
True if the advanced indexing is non-consecutive, False otherwise.
"""
_, _, *idxs = node.inputs
return _non_consecutive_adv_indexing(idxs)
indices = indices_from_subtensor(node.inputs[2:], node.op.idx_list)
return _non_consecutive_adv_indexing(indices)
advanced_inc_subtensor = AdvancedIncSubtensor()
advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
advanced_set_subtensor_nodup = AdvancedIncSubtensor(
set_instead_of_inc=True, ignore_duplicates=True
)
def advanced_inc_subtensor(x, y, *args, **kwargs):
    """Increment `x` at the advanced indices `args` by `y` (``x[args] += y``).

    Slices/newaxis in `args` are encoded into the op's static ``idx_list``;
    only the tensor index variables become graph inputs. Extra keyword
    arguments (e.g. ``set_instead_of_inc``, ``ignore_duplicates``) are
    forwarded to the `AdvancedIncSubtensor` constructor.
    """
    encoded_idx_list, tensor_indices = flatten_index_variables(args)
    op = AdvancedIncSubtensor(encoded_idx_list, **kwargs)
    return op(x, y, *tensor_indices)
def advanced_set_subtensor(x, y, *args, **kwargs):
    """Set `x` at the advanced indices `args` to `y` (``x[args] = y``).

    Thin wrapper over `advanced_inc_subtensor` with ``set_instead_of_inc=True``.
    """
    return advanced_inc_subtensor(x, y, *args, set_instead_of_inc=True, **kwargs)
class AdvancedIncSubtensorPrinter(SubtensorPrinter):
    """Pretty-printer for `AdvancedIncSubtensor` nodes.

    Renders the node as ``set_subtensor(x[indices], y)`` or
    ``inc_subtensor(x[indices], y)`` depending on the op's
    ``set_instead_of_inc`` flag.
    """

    def process(self, r, pstate):
        node = r.owner
        x, y, *index_variables = node.inputs
        # Print the indexed expression `x[indices]` via the base printer.
        indexed_repr = self._process(node.op.idx_list, [x, *index_variables], pstate)
        # High precedence so `y` is printed without spurious parentheses issues.
        with set_precedence(pstate, 1000):
            value_repr = pstate.pprinter.process(y, pstate)
        if node.op.set_instead_of_inc:
            return f"set_subtensor({indexed_repr}, {value_repr})"
        return f"inc_subtensor({indexed_repr}, {value_repr})"
pprint.assign(AdvancedIncSubtensor, AdvancedIncSubtensorPrinter())
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
    """
    Return x with the given subtensor overwritten by y.

    Parameters
    ----------
    x
        Symbolic variable for the lvalue of = operation.
    y
        Symbolic variable for the rvalue of = operation.
    tolerate_inplace_aliasing
        See inc_subtensor for documentation.

    Examples
    --------
    To replicate the numpy expression ``r[10:] = 5``, type

    .. code-block:: python

        from pytensor.tensor import set_subtensor, vector

        r = vector("r")
        new_r = set_subtensor(r[10:], 5)

    Consider using :meth:`pytensor.tensor.variable.TensorVariable.set` instead.

    """
    # Setting is incrementing with `set_instead_of_inc=True`; all the
    # index-extraction machinery lives in `inc_subtensor`.
    return inc_subtensor(
        x,
        y,
        inplace=inplace,
        set_instead_of_inc=True,
        tolerate_inplace_aliasing=tolerate_inplace_aliasing,
    )
def inc_subtensor(
    x,
    y,
    inplace=False,
    set_instead_of_inc=False,
    tolerate_inplace_aliasing=False,
    ignore_duplicates=False,
):
    """Update the value of an indexed array by a given amount.

    This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
    depending on the value of `ignore_duplicates`.

    Parameters
    ----------
    x
        The symbolic result of a Subtensor operation.
    y
        The amount by which to increment the array.
    inplace
        Don't use. PyTensor will do in-place operations itself, when possible.
    set_instead_of_inc
        If True, do a set_subtensor instead.
    tolerate_inplace_aliasing:
        Allow `x` and `y` to be views of a single underlying array even while
        working in-place. For correct results, `x` and `y` must not be overlapping
        views; if they overlap, the result of this `Op` will generally be
        incorrect. This value has no effect if ``inplace=False``.
    ignore_duplicates
        This determines whether ``x[indices] += y`` is used or
        ``np.add.at(x, indices, y)``.

    Examples
    --------
    To replicate the expression ``r[10:] += 5``:

    .. code-block:: python

        from pytensor.tensor import ivector, inc_subtensor

        r = ivector("r")
        new_r = inc_subtensor(r[10:], 5)

    To replicate the expression ``r[[0, 1, 0]] += 5``:

    .. code-block:: python

        r = ivector("r")
        new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)

    Consider using :meth:`pytensor.tensor.variable.TensorVariable.inc` instead.

    """
    # First of all, y cannot have a higher dimension than x,
    # nor have non-broadcastable dimensions where x is broadcastable.
    x = as_tensor_variable(x)
    y = as_tensor_variable(y)
    if y.ndim > x.ndim:
        raise TypeError(
            f"Trying to increment a {int(x.ndim)}-dimensional "
            f"subtensor with a {int(y.ndim)}-dimensional value."
        )
    # `y` broadcasts against the *trailing* dims of the indexed region.
    dim_offset = x.ndim - y.ndim
    for dim in range(y.ndim):
        if x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]:
            # It is acceptable to try to increment a subtensor with a
            # broadcastable dim with a tensor that is not broadcastable
            # on that dimension. However, its length must then be 1.
            # We insert a SpecifyShape Op to make sure it is the case.
            y = specify_broadcastable(y, dim)
    if x.owner is None:
        raise TypeError("x must be the result of a subtensor operation")

    # retrieve idx_list from x.owner
    # Dispatch on the op that produced `x` so the matching inc/set op
    # can be rebuilt over the *original* (un-indexed) tensor.
    if isinstance(x.owner.op, Subtensor):
        if tolerate_inplace_aliasing:
            # Tell the destroy handler that inputs 0 and 1 may alias.
            destroyhandler_tolerate_aliased = [[0, 1]]
        else:
            destroyhandler_tolerate_aliased = []
        the_op = IncSubtensor(
            x.owner.op.idx_list,
            inplace,
            set_instead_of_inc,
            destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased,
        )
        real_x, *index_variables = x.owner.inputs
        return the_op(real_x, y, *index_variables)
    elif isinstance(x.owner.op, AdvancedSubtensor1):
        real_x = x.owner.inputs[0]
        ilist = x.owner.inputs[1]
        if ignore_duplicates:
            # AdvancedIncSubtensor1 has no `ignore_duplicates` option, so fall
            # back to the generic op; `(0,)` encodes "input 0 is the index".
            the_op = AdvancedIncSubtensor(
                (0,),
                inplace,
                set_instead_of_inc=set_instead_of_inc,
                ignore_duplicates=True,
            )
        else:
            the_op = AdvancedIncSubtensor1(
                inplace, set_instead_of_inc=set_instead_of_inc
            )
        return the_op(real_x, y, ilist)
    elif isinstance(x.owner.op, AdvancedSubtensor):
        real_x, *index_variables = x.owner.inputs
        # Reuse the encoded idx_list of the subtensor that produced `x`.
        the_op = AdvancedIncSubtensor(
            x.owner.op.idx_list,
            inplace,
            set_instead_of_inc=set_instead_of_inc,
            ignore_duplicates=ignore_duplicates,
        )
        return the_op(real_x, y, *index_variables)
    elif isinstance(x.owner.op, DimShuffle):
        inner_x = x.owner.inputs[0]
        # In the dimshuffle case, there are in fact two dimshuffles:
        # one to make the indexed dimension the last one,
        # and one to put it back where it was. So, in the case where we have
        # inc_subtensor(x[:,i], y), the graph is actually
        # inc_subtensor((x.T)[i].T, y).
        # We could get all the way to x, and then get rid of the dimshuffles
        # completely, but the problem is that advanced_inc_subtensor1 can only
        # work on the first (outer-most, left-most) dimension of x,
        # just like advanced_subtensor1.
        # So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
        # transpose y if it is not a scalar or a vector), but then we need to
        # return something that has the same shape as x, not as x.T (inner_x).
        # So re-apply the outer dimshuffle on the new inc_subtensor,
        # and return advanced_inc_subtensor1(x.T, i, y.T).T.

        # Get the dimshuffle pattern to apply to y.
        x_order = x.owner.op.new_order
        y_order = ["x"] * x.ndim
        for i, v in enumerate(x_order):
            if v != "x" and (v - dim_offset) >= 0:
                y_order[v - dim_offset] = i

        # Recurse: increment the pre-dimshuffle tensor with a matching
        # permutation of `y`, then undo the dimshuffle on the result.
        inner_incsubtensor = inc_subtensor(
            inner_x,
            y.dimshuffle(y_order),
            inplace=inplace,
            set_instead_of_inc=set_instead_of_inc,
            tolerate_inplace_aliasing=tolerate_inplace_aliasing,
            ignore_duplicates=ignore_duplicates,
        )
        # The broadcastable pattern of inner_x may not be the same as
        # the one of x, so we have to build a new dimshuffle here,
        # instead of reusing x.owner.op().
        return inner_incsubtensor.dimshuffle(x.owner.op.new_order)

    elif isinstance(x.owner.op, Reshape):
        # This case happens when the indices are not arranged as a vector, but
        # as a higher-dimensional array. This is handled by the subtensor
        # by flattening this list, taking the subtensor, then reshaping the
        # result.
        inner_x = x.owner.inputs[0]
        # Try to apply inc_subtensor on inner_x.
        # If it works, there is no need to reshape, as the inc_subtensor
        # will have the same shape as inner_x, which is what we want.
        # We also explicitly duplicate y to its broadcasted shape
        # before we partially flatten it to inner_x dimension. This is
        # not strictly needed in all cases, but it is easier this way.
        if y.ndim > 0:
            # This if is needed to prevent some useless warning about
            # old code bug.
            expanded_y = alloc(y, *[x.shape[i] for i in range(x.ndim)])
            flattened_y = expanded_y.reshape(inner_x.shape)
        else:
            flattened_y = y

        inner_incsubtensor = inc_subtensor(
            inner_x,
            flattened_y,
            inplace=inplace,
            set_instead_of_inc=set_instead_of_inc,
            tolerate_inplace_aliasing=tolerate_inplace_aliasing,
            ignore_duplicates=ignore_duplicates,
        )
        return inner_incsubtensor
    else:
        raise TypeError("x must be the result of a subtensor operation")
def take(a, indices, axis=None, mode="raise"):
......@@ -3021,39 +2981,6 @@ def take(a, indices, axis=None, mode="raise"):
return a[full_indices]
@_get_vector_length.register(Subtensor)  # type: ignore
def _get_vector_length_Subtensor(op, var):
    # If we take a slice, we know how many elements it will result in
    # TODO: We can cover more `*Subtensor` cases.
    try:
        # Reconstruct the concrete slice from the op's encoded idx_list and
        # the node's symbolic inputs; only constant bounds can be resolved.
        indices = pytensor.tensor.subtensor.get_idx_list(
            var.owner.inputs, var.owner.op.idx_list
        )
        start = (
            None
            if indices[0].start is None
            else get_scalar_constant_value(indices[0].start)
        )
        stop = (
            None
            if indices[0].stop is None
            else get_scalar_constant_value(indices[0].stop)
        )
        step = (
            None
            if indices[0].step is None
            else get_scalar_constant_value(indices[0].step)
        )

        # Empty slice regardless of the input's length.
        if start == stop:
            return 0

        # Use Python's own slice arithmetic to count the selected elements.
        arg_len = get_vector_length(var.owner.inputs[0])
        return len(range(*slice(start, stop, step).indices(arg_len)))

    except (ValueError, NotScalarConstantError):
        # Non-constant bounds, or input length unknown: length undecidable.
        raise ValueError(f"Length of {var} cannot be determined")
def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
"""
Construct tuple of slices to slice an array in the given dimension.
......
......@@ -15,9 +15,8 @@ from pytensor.scalar import (
ComplexError,
)
from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import hash_from_ndarray
......@@ -455,15 +454,14 @@ class _tensor_py_operators:
elif not isinstance(args, tuple):
args = (args,)
# Count the dimensions, check for bools and find ellipses.
ellipses = []
index_dim_count = 0
for i, arg in enumerate(args):
if arg is np.newaxis or arg is NoneConst:
# no increase in index_dim_count
if arg is None or (
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
pass
elif arg is Ellipsis:
# no increase in index_dim_count
ellipses.append(i)
elif (
isinstance(arg, np.ndarray | Variable)
......@@ -505,6 +503,41 @@ class _tensor_py_operators:
self.ndim - index_dim_count
)
if any(
arg is None
or (isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT))
for arg in args
):
expansion_axes = []
new_args = []
# Track dims consumed by args and inserted `None`s after ellipsis
counter = 0
nones = 0
for arg in args:
if arg is None or (
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
expansion_axes.append(counter + nones) # Expand here
nones += 1
new_args.append(slice(None))
else:
new_args.append(arg)
consumed = 1
if hasattr(arg, "dtype") and arg.dtype == "bool":
consumed = arg.ndim
counter += consumed
expanded = pt.expand_dims(self, expansion_axes)
if all(
isinstance(arg, slice)
and arg.start is None
and arg.stop is None
and arg.step is None
for arg in new_args
):
return expanded
return expanded[tuple(new_args)]
def is_empty_array(val):
return (isinstance(val, tuple | list) and len(val) == 0) or (
isinstance(val, np.ndarray) and val.size == 0
......@@ -520,74 +553,16 @@ class _tensor_py_operators:
for inp in args
)
# Determine if advanced indexing is needed or not. The logic is
# already in `index_vars_to_types`: if it succeeds, standard indexing is
# used; if it fails with `AdvancedIndexingError`, advanced indexing is
# used
advanced = False
for i, arg in enumerate(args):
if includes_bool(arg):
advanced = True
break
if arg is not np.newaxis and arg is not NoneConst:
try:
pt.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError:
if advanced:
break
else:
advanced = True
if advanced:
return pt.subtensor.advanced_subtensor(self, *args)
if all(
(
isinstance(arg, slice | int | float | np.number)
or (hasattr(arg, "ndim") and arg.ndim == 0 and arg.dtype != "bool")
)
for arg in args
):
return pt.subtensor.basic_subtensor(self, *args)
else:
if np.newaxis in args or NoneConst in args:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter = 0
pattern = []
new_args = []
for arg in args:
if arg is np.newaxis or arg is NoneConst:
pattern.append("x")
new_args.append(slice(None, None, None))
else:
pattern.append(counter)
counter += 1
new_args.append(arg)
pattern.extend(list(range(counter, self.ndim)))
view = self.dimshuffle(pattern)
full_slices = True
for arg in new_args:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if not (
isinstance(arg, slice)
and (arg.start is None or arg.start is NoneConst)
and (arg.stop is None or arg.stop is NoneConst)
and (arg.step is None or arg.step is NoneConst)
):
full_slices = False
if full_slices:
return view
else:
return view.__getitem__(tuple(new_args))
else:
return pt.subtensor.Subtensor(args)(
self,
*pt.subtensor.get_slice_elements(
args, lambda entry: isinstance(entry, Variable)
),
)
return pt.subtensor.advanced_subtensor(self, *args)
def __setitem__(self, key, value):
raise TypeError(
......
......@@ -2,9 +2,10 @@ from itertools import zip_longest
from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, arange, specify_shape
from pytensor.tensor import arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index, IndexUpdate, index
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
......@@ -106,7 +107,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs.append(idx)
aligned_idxs.append(to_basic_idx(idx))
basic_idx_axis.append(out_dims.index(x_dim))
else:
# Otherwise we need to convert the basic index into an equivalent advanced indexing
......@@ -131,7 +132,7 @@ def _lower_index(node):
if basic_idx_axis:
aligned_idxs = [
idx.squeeze(axis=basic_idx_axis)
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
if (isinstance(idx, TensorVariable) and idx.type.ndim > 0)
else idx
for idx in aligned_idxs
]
......
......@@ -26,9 +26,7 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import (
MyOp,
MyType,
......@@ -629,21 +627,6 @@ def test_pre_constant_merge():
assert res == [o2]
assert o2.owner.inputs[2] is c2
# What is this supposed to test?
ms = MakeSlice()(1)
res = pre_constant_merge(empty_fgraph, [ms])
assert res == [ms]
const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))
assert isinstance(const_slice, Constant)
adv = AdvancedSubtensor()(matrix(), [2, 3], const_slice)
res = pre_constant_merge(empty_fgraph, adv)
assert res == [adv]
def test_pre_greedy_node_rewriter():
empty_fgraph = FunctionGraph([], [])
......@@ -679,15 +662,6 @@ def test_pre_greedy_node_rewriter():
assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test?
ms = MakeSlice()(1)
cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant)
# Make sure constant of slice signature is hashable.
assert isinstance(hash(cst.signature()), int)
@pytest.mark.parametrize("tracks", [True, False])
@pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0])
......
......@@ -225,6 +225,37 @@ def test_jax_IncSubtensor():
compare_jax_and_py([], [out_pt], [])
@pytest.mark.parametrize(
    "func", (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1)
)
def test_jax_AdvancedIncSubtensor1_runtime_broadcast(func):
    """Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.

    JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
    requires explicit broadcastable dimensions. This test ensures we raise the same
    error as the Python/C backend when runtime broadcasting would occur.
    """
    from pytensor import function

    y = pt.matrix("y", dtype="float64", shape=(None, None))
    x = pt.zeros((10, 5))
    idxs = np.repeat(np.arange(10), 2)  # 20 indices
    out = func(x, y, idxs)

    f = function([y], out, mode="JAX")

    # Should work with correctly sized y (one row per index).
    f(np.ones((20, 5)))

    # Should raise for runtime broadcasting on first dimension
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        f(np.ones((1, 5)))

    # Should raise for runtime broadcasting on second dimension
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        f(np.ones((20, 1)))
def test_jax_IncSubtensor_boolean_indexing_reexpressible():
"""Setting or incrementing values with boolean indexing.
......
......@@ -187,27 +187,6 @@ def test_mlx_inplace_variants():
compare_mlx_and_py([], [out_pt], [])
@pytest.mark.xfail(
    reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
    """Test MakeSlice operation.

    Expected to fail: the MLX backend cannot evaluate slices whose bounds
    are symbolic scalars (see the xfail reason above).
    """
    # Symbolic slice bounds.
    start = pt.iscalar("start")
    stop = pt.iscalar("stop")
    step = pt.iscalar("step")

    # Create a slice using MakeSlice
    slice_op = pt_subtensor.MakeSlice()
    slice_pt = slice_op(start, stop, step)

    # Use simple constant array instead of arange
    x_pt = pt.constant(np.arange(10, dtype=np.float32))
    out_pt = x_pt[slice_pt]

    compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2])
def test_mlx_subtensor_edge_cases():
"""Test edge cases and boundary conditions."""
# Empty slices - use constant array
......
......@@ -3,9 +3,7 @@ import contextlib
import numpy as np
import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor import Mode, as_symbolic
from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -20,51 +18,16 @@ from pytensor.tensor.subtensor import (
inc_subtensor,
set_subtensor,
)
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
from tests.link.numba.test_basic import (
compare_numba_and_py,
numba_inplace_mode,
numba_mode,
)
rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
def test_slice(start, stop, step):
    """Check that symbolic slices round-trip through the Numba linker.

    The sentinel string "x" stands for a symbolic scalar bound, which is
    evaluated with the value -5; literal/None bounds must pass through
    unchanged (modulo Numba's None -> 0 / None -> 1 normalizations).
    """
    x = ps.int64("x")
    sym_slice = as_symbolic(
        slice(
            x if start == "x" else start,
            x if stop == "x" else stop,
            x if step == "x" else step,
        )
    )

    # Disable rewrites so the slice structure reaches the linker as-is.
    no_opt_mode = Mode(linker="numba", optimizer=None)
    evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
    assert isinstance(evaled_slice, slice)
    if start == "x":
        assert evaled_slice.start == -5
    elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
        # Numba can convert to 0 (and sometimes does) in this case
        assert evaled_slice.start in (None, 0)
    else:
        assert evaled_slice.start == start

    if stop == "x":
        assert evaled_slice.stop == -5
    else:
        assert evaled_slice.stop == stop

    if step == "x":
        assert evaled_slice.step == -5
    elif step is None:
        # Numba can convert to 1 (and sometimes does) in this case
        assert evaled_slice.step in (None, 1)
    else:
        assert evaled_slice.step == step
@pytest.mark.parametrize(
"x, indices",
[
......@@ -182,6 +145,11 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
),
# Newaxis with vector indexing
(
as_tensor(np.arange(4 * 4).reshape((4, 4))),
(None, [0, 1, 2], [0, 1, 2]),
),
],
)
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
......@@ -447,6 +415,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False,
False,
),
(
np.arange(4 * 4).reshape((4, 4)),
np.array(5), # Broadcasted scalar value
(None, [0, 1, 2], [0, 1, 2]), # Newaxis with vector indexing
False,
False,
),
],
)
@pytest.mark.parametrize("inplace", (False, True))
......@@ -460,7 +435,9 @@ def test_AdvancedIncSubtensor(
inplace,
):
# Need rewrite to support certain forms of advanced indexing without object mode
mode = numba_mode.including("specialize")
# Use inplace_mode when testing inplace operations to preserve inplace flag
base_mode = numba_inplace_mode if inplace else numba_mode
mode = base_mode.including("specialize")
x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y")
......@@ -514,22 +491,3 @@ def test_AdvancedIncSubtensor(
x_orig = x.copy()
fn(x, y)
assert not np.all(x == x_orig)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
    # This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
    # After which we can add these parametrizations to the relevant tests above
    x = pt.matrix("x")
    x_val = np.random.normal(size=(4, 4))

    # Both the read and the increment flavors of newaxis-plus-advanced
    # indexing currently fall back to object mode under Numba; each case
    # pairs the graph with the warning it is expected to raise.
    cases = (
        (
            x[None, [0, 1, 2], [0, 1, 2]],
            r"Numba will use object mode to run AdvancedSubtensor's perform method",
        ),
        (
            x[None, [0, 1, 2], [0, 1, 2]].inc(5),
            r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
        ),
    )
    for graph, warn_pattern in cases:
        with pytest.warns(UserWarning, match=warn_pattern):
            compare_numba_and_py([x], [graph], [x_val])
......@@ -1642,9 +1642,15 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph(fgraph, include=("inplace",))
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
with pytest.warns(
FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
# Save original value to restore later
original_value = pytensor.config.tensor__insert_inplace_optimizer_validate_nb
try:
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
with pytest.warns(
FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
finally:
# Restore original value to avoid affecting other tests
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = original_value
......@@ -52,7 +52,6 @@ from pytensor.tensor.type import (
tensor4,
vector,
)
from pytensor.tensor.type_other import make_slice
from tests import unittest_tools as utt
from tests.unittest_tools import create_pytensor_param
......@@ -1701,11 +1700,11 @@ def test_local_uint_constant_indices():
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# `AdvancedSubtensor`, two indices, one symbolic slice, convert
# `AdvancedSubtensor`, two indices, one slice, convert
x = pt.matrix("x")
indices = (
pt.as_tensor_variable(np.array(1, np.int64)),
make_slice(slice(None, 10)),
pt.as_tensor_variable(np.array([1], np.int64)),
slice(None, 10),
)
z = x[indices]
......@@ -1792,7 +1791,7 @@ def test_local_uint_constant_indices():
z_fn = pytensor.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor)
assert isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1))
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
......@@ -1843,7 +1842,6 @@ class TestBlockwiseIncSubtensor:
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
......@@ -1948,15 +1946,7 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize(
"basic_idx",
[
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
[True, False],
ids=["basic_idx", "adv_idx"],
)
@pytest.mark.parametrize(
......@@ -1973,7 +1963,7 @@ class TestBlockwiseIncSubtensor:
core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,))
# The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with an new arange slice on the batched dimensions.
# once it is paired with a new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out = core_a[0, :, core_idx].set(core_v)
......
......@@ -32,7 +32,6 @@ from pytensor.tensor import (
lscalars,
matrix,
shape,
slicetype,
specify_shape,
tensor,
tensor3,
......@@ -557,7 +556,7 @@ class TestLocalSubtensorSpecifyShapeLift:
(
matrix(),
(iscalar(), iscalar()),
(slicetype(),),
(slice(iscalar(), iscalar(), iscalar()),),
),
(
matrix(),
......@@ -789,12 +788,12 @@ def test_local_subtensor_shape_constant():
(lambda x: x[:, [0, 1]][0], True),
(lambda x: x[:, [0, 1], [0, 0]][1:], True),
(lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
(lambda x: x[:, None, [0, 1]][0], True),
# Not supported, basic indexing on advanced indexing dim
(lambda x: x[[0, 1]][0], False),
# Not implemented, basic indexing on the right of advanced indexing
# Not supported, basic indexing on the right of advanced indexing
(lambda x: x[[0, 1]][:, 0], False),
# Not implemented, complex flavors of advanced indexing
(lambda x: x[:, None, [0, 1]][0], False),
(lambda x: x[:, 5:, [0, 1]][0], False),
(lambda x: x[:, :, np.array([True, False, False])][0], False),
(lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
......
......@@ -31,6 +31,8 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback,
)
from pytensor.tensor.nlinalg import MatrixInverse, eig
from pytensor.tensor.random import normal
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import (
......@@ -114,16 +116,18 @@ def test_vectorize_blockwise():
def test_vectorize_node_fallback_unsupported_type():
    """The generic vectorize fallback must reject node inputs whose types
    cannot be batched (here: a ``RandomGeneratorType`` input).

    Fix: removed a leftover dead assignment pair (``x = tensor(...)`` and a
    ``node`` that was immediately reassigned) left behind by the refactor.
    """
    rng = default_rng()
    node = normal(rng=rng).owner
    with pytest.raises(
        NotImplementedError,
        match=re.escape(
            'Cannot vectorize node normal_rv{"(),()->()"}('
            "DefaultGeneratorMakerOp.0, NoneConst{None}, 0.0, 1.0)"
            " with input DefaultGeneratorMakerOp.0 of type RandomGeneratorType"
        ),
    ):
        vectorize_node_fallback(node.op, node, *node.inputs)
def check_blockwise_runtime_broadcasting(mode):
......
......@@ -11,20 +11,19 @@ from numpy.testing import assert_array_equal
import pytensor
import pytensor.scalar as scal
import pytensor.tensor.basic as ptb
from pytensor import function
from pytensor.compile import DeepCopyOp, shared
from pytensor import function, shared
from pytensor.compile import DeepCopyOp
from pytensor.compile.io import In
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph import Constant
from pytensor.graph.basic import equal_computations
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.link.numba import NumbaLinker
from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar, int16
from pytensor.tensor import as_tensor, constant, get_vector_length, vectorize
from pytensor.tensor import as_tensor, constant, get_vector_length, ivector, vectorize
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf, lt, switch
......@@ -33,7 +32,6 @@ from pytensor.tensor.shape import specify_broadcastable, specify_shape
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedIndexingError,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
......@@ -49,7 +47,6 @@ from pytensor.tensor.subtensor import (
flip,
get_canonical_form_slice,
inc_subtensor,
index_vars_to_types,
indexed_result_shape,
set_subtensor,
slice_at_axis,
......@@ -80,13 +77,7 @@ from pytensor.tensor.type import (
tensor5,
vector,
)
from pytensor.tensor.type_other import (
NoneConst,
SliceConstant,
as_symbolic_slice,
make_slice,
slicetype,
)
from pytensor.tensor.type_other import NoneConst
from tests import unittest_tools as utt
from tests.tensor.utils import inplace_func, integers_ranged, random
......@@ -106,20 +97,12 @@ def test_as_index_literal():
assert res == slice(1, None)
res = as_index_literal(slice(None, None, ptb.as_tensor(2)))
assert res == slice(None, None, 2)
res = as_index_literal(SliceConstant(slicetype, slice(None)))
assert res == slice(None)
res = as_index_literal(make_slice(None, ptb.as_tensor(1)))
assert res == slice(None, 1)
res = as_index_literal(ptb.as_tensor(2))
assert res == 2
res = as_index_literal(np.newaxis)
assert res is np.newaxis
res = as_index_literal(NoneConst)
assert res is np.newaxis
res = as_index_literal(NoneConst.clone())
assert res is np.newaxis
class TestGetCanonicalFormSlice:
......@@ -128,8 +111,6 @@ class TestGetCanonicalFormSlice:
[
NoneConst,
None,
as_symbolic_slice(slice(3, 7, 2)),
as_symbolic_slice(slice(3, int16(), 2)),
vector(),
],
)
......@@ -137,6 +118,19 @@ class TestGetCanonicalFormSlice:
with pytest.raises(ValueError, match="not a supported slice"):
get_canonical_form_slice(idx, 5)
@pytest.mark.parametrize(
    "idx,expected_direction",
    [
        (slice(3, 7, 2), 1),
        (slice(None, None), 1),
        (slice(None, None, -1), -1),
    ],
)
def test_python_slice_support(self, idx, expected_direction):
    """Plain Python slices are canonicalized into (slice, direction) pairs."""
    canonical, direction = get_canonical_form_slice(idx, 10)
    assert isinstance(canonical, slice)
    assert direction == expected_direction
def test_scalar_constant(self):
a = as_scalar(0)
length = lscalar()
......@@ -408,7 +402,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
f = inplace_func([], t, mode=mode)
topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op, DeepCopyOp)]
assert len(topo_) == length
assert len(topo_) == length, f.dprint()
if length == 1:
assert isinstance(topo_[0].op, op_type)
tval = f()
......@@ -623,7 +617,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
(3, DimShuffle, np.index_exp[..., [0, 2, 3]]),
(1, DimShuffle, np.index_exp[np.newaxis, ...]),
(
1,
4 if config.mode == "FAST_COMPILE" else 3,
AdvancedSubtensor,
np.index_exp[..., np.newaxis, [1, 2]],
),
......@@ -1967,7 +1961,7 @@ class TestAdvancedSubtensor:
x = self.shared(x_val, name="x")
y = tensor(dtype="float32", shape=(None,) * len(y_val.shape), name="y")
sym_idx = [ptb.as_tensor_variable(ix) for ix in idx]
expr = AdvancedIncSubtensor(inplace=inplace)(x, y, *sym_idx)
expr = advanced_inc_subtensor(x, y, *sym_idx, inplace=inplace)
f = pytensor.function(
[y], expr, mode=self.mode.excluding("inplace"), accept_inplace=inplace
)
......@@ -2303,20 +2297,29 @@ class TestAdvancedSubtensor:
def test_adv_sub_slice(self):
    # Reported in https://github.com/Theano/Theano/issues/5898
    # Fix: removed stale lines from the pre-refactor version that used the
    # removed `slicetype()` API and the undefined locals `slc`/`s`, which
    # shadowed/duplicated the scalar-bounds version below.
    var = self.shared(np.zeros([3, 3], dtype=config.floatX))

    # Symbolic scalar variables stand in for the slice boundaries.
    start = lscalar("start")
    stop = lscalar("stop")
    sliced = var[start:stop]

    # Slicing rows [1, 3) of a (3, 3) matrix yields a (2, 3) result.
    f = pytensor.function([start, stop], sliced, mode=self.mode)
    assert f(1, 3).shape == (2, 3)

    f_shape0 = pytensor.function([start, stop], sliced.shape[0], mode=self.mode)
    assert f_shape0(1, 3) == 2

    # The shape of an untouched dimension must be computable without
    # executing any AdvancedSubtensor node at all.
    f_shape1 = pytensor.function([start, stop], sliced.shape[1], mode=self.mode)
    assert not any(
        isinstance(node.op, AdvancedSubtensor)
        for node in f_shape1.maker.fgraph.toposort()
    )
    assert f_shape1(1, 3) == 3
def test_adv_grouped(self):
# Reported in https://github.com/Theano/Theano/issues/6152
......@@ -2798,8 +2801,8 @@ class TestInferShape(utt.InferShapeTester):
def test_advanced_subtensor_constant_slice(self):
    """A plain Python slice mixed with advanced indices infers the right shape.

    Fix: removed the stale pre-refactor lines that built the slice via
    ``pytensor.as_symbolic(slice(...))`` and asserted it was a ``Constant``;
    slices are now passed to ``advanced_subtensor`` directly.
    """
    x = dmatrix("x")
    adv_indices = ptb.constant(np.zeros((2, 3)), dtype="int")
    # Rows 1: of a (10, 10) input -> 9; the advanced index contributes (2, 3).
    y = advanced_subtensor(x, slice(1, None, None), adv_indices)
    assert tuple(y.shape.eval({x: np.zeros((10, 10))})) == (9, 2, 3)
......@@ -2808,7 +2811,7 @@ class TestInferShape(utt.InferShapeTester):
@config.change_flags(compute_test_value="raise")
def test_basic_shape():
    """`basic_shape` accepts plain Python slices in the index tuple."""
    shape_in = (5, 4)
    # Slicing [1, 3) of the first dimension leaves a single length-2 axis.
    res = basic_shape(shape_in, (slice(1, 3, None),))
    assert get_test_value(res) == (2,)
......@@ -2846,18 +2849,6 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), *test_idx[:1]),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None, *test_idx[1:2]),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(np.array(1), slice(None, None), None),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None, np.array(1)),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(*test_idx[:1], slice(None, None), *test_idx[1:2]),
......@@ -2866,10 +2857,6 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(*test_idx[:1], slice(None, None), *test_idx[1:2], slice(None, None)),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(*test_idx[:1], None, *test_idx[1:2]),
),
(np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
(np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
(
......@@ -2929,12 +2916,11 @@ def test_get_vector_length():
"indices, exp_res",
[
((0,), "x[0]"),
# TODO: The numbers should be printed
((slice(None, 2),), "x[:int64]"),
((slice(0, None),), "x[int64:]"),
((slice(0, 2),), "x[int64:int64]"),
((slice(0, 2, 2),), "x[int64:int64:int64]"),
((slice(0, 2), 0, slice(0, 2)), "x[int64:int64, 2, int64:int64]"),
((slice(None, 2),), "x[:2]"),
((slice(0, None),), "x[0:]"),
((slice(0, 2),), "x[0:2]"),
((slice(0, 2, 2),), "x[0:2:2]"),
((slice(0, 2), 0, slice(0, 2)), "x[0:2, 0, 0:2]"),
],
)
def test_pprint_Subtensor(indices, exp_res):
......@@ -2948,7 +2934,7 @@ def test_pprint_Subtensor(indices, exp_res):
[
((0,), False, "inc_subtensor(x[0], z)"),
((0,), True, "set_subtensor(x[0], z)"),
((slice(0, 2),), True, "set_subtensor(x[int64:int64], z)"),
((slice(0, 2),), True, "set_subtensor(x[0:2], z)"),
],
)
def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
......@@ -2958,22 +2944,38 @@ def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
assert pprint(y) == exp_res
def test_index_vars_to_types():
x = ptb.as_tensor_variable(np.array([True, False]))
with pytest.raises(AdvancedIndexingError):
index_vars_to_types(x)
with pytest.raises(TypeError):
index_vars_to_types(1)
@pytest.mark.parametrize(
    "indices, exp_res",
    [
        # Vector index
        ((ivector("idx"),), "x[idx]"),
        # Two vector indices
        ((ivector("idx"), ivector("idx2")), "x[idx, idx2]"),
        # Vector index with scalar (triggers advanced indexing)
        ((ivector("idx"), 0), "x[idx, 0]"),
        # Vector index with constant slice
        ((ivector("idx"), slice(0, 5)), "x[idx, 0:5]"),
    ],
)
def test_pprint_AdvancedSubtensor(indices, exp_res):
    """Advanced-subtensor graphs pretty-print as NumPy-style indexing."""
    base = tensor4("x")
    graph = advanced_subtensor(base, *indices)
    assert pprint(graph) == exp_res
res = index_vars_to_types(iscalar)
assert isinstance(res, scal.ScalarType)
x = scal.constant(1, dtype=np.uint8)
assert isinstance(x.type, scal.ScalarType)
res = index_vars_to_types(x)
assert res == x.type
@pytest.mark.parametrize(
    "indices, set_instead_of_inc, exp_res",
    [
        ((ivector("idx"),), False, "inc_subtensor(x[idx], z)"),
        ((ivector("idx"),), True, "set_subtensor(x[idx], z)"),
        ((ivector("idx"), slice(None, 5)), True, "set_subtensor(x[idx, :5], z)"),
    ],
)
def test_pprint_AdvancedIncSubtensor(indices, set_instead_of_inc, exp_res):
    """Advanced inc/set graphs pretty-print as inc_subtensor/set_subtensor."""
    base = tensor4("x")
    values = tensor3("z")
    graph = advanced_inc_subtensor(
        base, values, *indices, set_instead_of_inc=set_instead_of_inc
    )
    assert pprint(graph) == exp_res
@pytest.mark.parametrize(
......@@ -3066,15 +3068,12 @@ def test_vectorize_subtensor_without_batch_indices():
(2,),
False,
),
# (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
# due to the exact same None constant being used there and in the np.newaxis)
pytest.param(
(lambda x, idx: x[:, idx, None]),
"(7,5,3),(2)->(7,2,1,3)",
(11, 7, 5, 3),
(2,),
False,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
(
(lambda x, idx: x[:, idx, idx, :]),
......@@ -3083,27 +3082,23 @@ def test_vectorize_subtensor_without_batch_indices():
(2,),
False,
),
# (not supported, because fallback Blockwise can't handle slices)
pytest.param(
(lambda x, idx: x[:, idx, :, idx]),
"(7,5,3,5),(2)->(2,7,3)",
(11, 7, 5, 3, 5),
(2,),
True,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
# Core x, batched idx
((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (7,), (11, 2), True),
# Batched x, batched idx
((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (11, 7), (11, 2), True),
# (not supported, because fallback Blockwise can't handle slices)
pytest.param(
(lambda x, idx: x[:, idx, :]),
"(t1,t2,t3),(idx)->(t1,tx,t3)",
(11, 7, 5, 3),
(11, 2),
True,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
],
)
......@@ -3238,3 +3233,37 @@ class TestBenchmarks:
)
fn.vm.allow_gc = gc
benchmark(fn, x_values)
def test_subtensor_hash_and_eq():
    """Subtensor-family Ops implement value-based ``__eq__``/``__hash__``.

    Ops with identical ``idx_list`` and flags must compare equal and hash
    equal (so graph merging can deduplicate them); differing Ops must
    compare unequal.

    Fix: the previous version asserted ``inc1 == inc3`` whenever
    ``hash(inc1) == hash(inc3)``, right after asserting ``inc1 != inc3`` —
    a hash collision between unequal objects is permitted by Python's data
    model, so that check could only fail spuriously. It has been removed.
    """
    s1 = Subtensor(idx_list=[slice(None, None, None), 0])
    s2 = Subtensor(idx_list=[slice(None, None, None), 0])
    assert s1 == s2
    assert hash(s1) == hash(s2)

    # Different Op classes / idx_lists must not compare equal.
    s3 = AdvancedSubtensor(idx_list=[slice(None, None, None), 0])
    s4 = AdvancedIncSubtensor(idx_list=[slice(0, 1, None), 2])
    assert s3 != s4
    assert hash(s3) != hash(s4)
    assert s1 != s3

    inc1 = IncSubtensor(
        idx_list=[slice(None)], inplace=True, destroyhandler_tolerate_aliased=[(0, 1)]
    )
    inc2 = IncSubtensor(
        idx_list=[slice(None)], inplace=True, destroyhandler_tolerate_aliased=[(0, 1)]
    )
    inc3 = IncSubtensor(
        idx_list=[slice(None)], inplace=True, destroyhandler_tolerate_aliased=[(0, 2)]
    )
    assert inc1 == inc2
    assert hash(inc1) == hash(inc2)
    assert inc1 != inc3
    # NOTE: no assertion on hash(inc1) vs hash(inc3) — unequal objects are
    # allowed to collide; only equal objects are required to hash equal.

    s_mix1 = Subtensor(idx_list=[0, slice(None), slice(None, 1)])
    s_mix2 = Subtensor(idx_list=[0, slice(None), slice(None, 1)])
    assert s_mix1 == s_mix2
    assert hash(s_mix1) == hash(s_mix2)
......@@ -4,30 +4,8 @@ import pytensor
from pytensor import as_symbolic
from pytensor.graph.basic import Constant
from pytensor.tensor.math import argmax
from pytensor.tensor.type import iscalar, vector
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
def test_SliceType():
st = SliceType()
assert st == st.clone()
def test_make_slice_merge():
# In the past, this was crashing during compilation.
i = iscalar()
s1 = make_slice(0, i)
s2 = make_slice(0, i)
f = pytensor.function([i], [s1, s2])
nodes = f.maker.fgraph.apply_nodes
assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
from pytensor.tensor.type import vector
from pytensor.tensor.type_other import NoneConst, NoneTypeT
def test_none_Constant():
......@@ -47,8 +25,6 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past.
import pickle
import pytensor
x = vector("x")
y = argmax(x)
kwargs = {}
......@@ -60,11 +36,18 @@ def test_none_Constant():
def test_as_symbolic():
    """`as_symbolic` maps None and slices to their symbolic counterparts.

    Fix: removed stale pre-refactor lines that asserted
    ``res.owner.op == make_slice`` — the ``make_slice`` import is deleted
    by this change, so those lines would raise ``NameError``.
    """
    # Remove this when xtensor is not using symbolic slices
    from pytensor.tensor.type import iscalar
    from pytensor.tensor.type_other import SliceConstant, slicetype

    res = as_symbolic(None)
    assert res is NoneConst

    # A fully-constant slice becomes a SliceConstant carrying the raw slice.
    res = as_symbolic(slice(1, 2))
    assert isinstance(res, SliceConstant)
    assert res.type == slicetype
    assert res.data == slice(1, 2)

    # A slice with a symbolic bound produces a computed (owned) variable.
    i = iscalar()
    res = as_symbolic(slice(i))
    assert res.owner is not None
......@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar,
tensor3,
)
from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.variable import (
DenseTensorConstant,
DenseTensorVariable,
......@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z = x[:, i]
op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor]
assert op_types == [AdvancedSubtensor]
z = x[..., i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor]
assert op_types == [DimShuffle, AdvancedSubtensor]
z = x[i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论