提交 cc6bed1a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Revert "Refactor AdvancedSubtensor"

This reverts commit db7fa079.
上级 03afa5bb
...@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper): ...@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
} }
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr( tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", () app.op, "destroyhandler_tolerate_aliased", []
) )
assert isinstance(tolerate_aliased, tuple | list) assert isinstance(tolerate_aliased, list)
ignored = { ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
} }
......
...@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import ( ...@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice
BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
...@@ -34,8 +35,10 @@ slice length. ...@@ -34,8 +35,10 @@ slice length.
@jax_funcify.register(AdvancedSubtensor) @jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1) @jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs): def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists): def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, op.idx_list) indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs): ...@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register(IncSubtensor) @jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1) @jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs): def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
...@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs): ...@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
return x.at[indices].add(y) return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs): ...@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return jax_fn(x, indices, y) return jax_fn(x, indices, y)
return incsubtensor return incsubtensor
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
return advancedincsubtensor
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
...@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import ( ...@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice
@mlx_funcify.register(Subtensor) @mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs): def mlx_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists): def subtensor(x, *ilists):
indices = indices_from_subtensor( indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
[int(element) for element in ilists], op.idx_list
)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs): ...@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1) @mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def advanced_subtensor(x, *ilists): def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, op.idx_list) indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register(IncSubtensor) @mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1) @mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs): def mlx_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
def mlx_fn(x, indices, y): def mlx_fn(x, indices, y):
...@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs): ...@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x[indices] += y x[indices] += y
return x return x
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return mlx_fn(x, ilist, y) return mlx_fn(x, ilist, y)
return advancedincsubtensor return advancedincsubtensor
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
...@@ -10,17 +10,18 @@ from numba import types ...@@ -10,17 +10,18 @@ from numba import types
from numba.core.pythonapi import box from numba.core.pythonapi import box
import pytensor.link.numba.dispatch.basic as numba_basic import pytensor.link.numba.dispatch.basic as numba_basic
from pytensor.graph import Variable from pytensor.graph import Type
from pytensor.link.numba.cache import ( from pytensor.link.numba.cache import (
compile_numba_function_src, compile_numba_function_src,
) )
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl, generate_fallback_impl,
register_funcify_and_cache_key, register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
) )
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.tensor import TensorType, TensorVariable from pytensor.tensor import TensorType
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -28,8 +29,8 @@ from pytensor.tensor.subtensor import ( ...@@ -28,8 +29,8 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor1, AdvancedSubtensor1,
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice, NoneTypeT
def slice_new(self, start, stop, step): def slice_new(self, start, stop, step):
...@@ -117,6 +118,15 @@ def numba_deepcopy_slice(x): ...@@ -117,6 +118,15 @@ def numba_deepcopy_slice(x):
return deepcopy_slice return deepcopy_slice
@register_funcify_default_op_cache_key(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs):
@numba_basic.numba_njit
def makeslice(*x):
return slice(*x)
return makeslice
def subtensor_op_cache_key(op, **extra_fields): def subtensor_op_cache_key(op, **extra_fields):
key_parts = [type(op), tuple(extra_fields.items())] key_parts = [type(op), tuple(extra_fields.items())]
if hasattr(op, "idx_list"): if hasattr(op, "idx_list"):
...@@ -146,36 +156,35 @@ def subtensor_op_cache_key(op, **extra_fields): ...@@ -146,36 +156,35 @@ def subtensor_op_cache_key(op, **extra_fields):
def numba_funcify_default_subtensor(op, node, **kwargs): def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array.""" """Create a Python function that assembles and uses an index on an array."""
def convert_indices(indices_iterator, entry): def convert_indices(indice_names, entry):
if isinstance(entry, int): if indice_names and isinstance(entry, Type):
name, var = next(indices_iterator) return next(indice_names)
if var.ndim == 0 and isinstance(var.type, TensorType):
return f"{name}.item()"
return name
elif isinstance(entry, slice): elif isinstance(entry, slice):
return ( return (
f"slice({convert_indices(indices_iterator, entry.start)}, " f"slice({convert_indices(indice_names, entry.start)}, "
f"{convert_indices(indices_iterator, entry.stop)}, " f"{convert_indices(indice_names, entry.stop)}, "
f"{convert_indices(indices_iterator, entry.step)})" f"{convert_indices(indice_names, entry.step)})"
) )
elif isinstance(entry, type(None)): elif isinstance(entry, type(None)):
return "None" return "None"
else: else:
raise ValueError(f"Unknown index type: {entry}") raise ValueError()
set_or_inc = isinstance( set_or_inc = isinstance(
op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
) )
index_start_idx = 1 + int(set_or_inc) index_start_idx = 1 + int(set_or_inc)
op_indices = list(node.inputs[index_start_idx:]) op_indices = list(node.inputs[index_start_idx:])
idx_list = op.idx_list idx_list = getattr(op, "idx_list", None)
idx_names = [f"idx_{i}" for i in range(len(op_indices))] idx_names = [f"idx_{i}" for i in range(len(op_indices))]
input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names] input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names]
indices_iterator = iter(zip(idx_names, op_indices)) idx_names_iterator = iter(idx_names)
indices_creation_src = tuple( indices_creation_src = (
convert_indices(indices_iterator, idx) for idx in idx_list tuple(convert_indices(idx_names_iterator, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
) )
if len(indices_creation_src) == 1: if len(indices_creation_src) == 1:
...@@ -231,24 +240,20 @@ def {function_name}({", ".join(input_names)}): ...@@ -231,24 +240,20 @@ def {function_name}({", ".join(input_names)}):
@register_funcify_and_cache_key(AdvancedIncSubtensor) @register_funcify_and_cache_key(AdvancedIncSubtensor)
def numba_funcify_AdvancedSubtensor(op, node, **kwargs): def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if isinstance(op, AdvancedSubtensor): if isinstance(op, AdvancedSubtensor):
_x, *index_variables = node.inputs _x, _y, idxs = node.inputs[0], None, node.inputs[1:]
else: else:
_x, _y, *index_variables = node.inputs _x, _y, *idxs = node.inputs
reconstructed_indices = indices_from_subtensor(index_variables, op.idx_list) adv_idxs = [
{
adv_idxs = [] "axis": i,
for i, idx in enumerate(reconstructed_indices): "dtype": idx.type.dtype,
if isinstance(idx, TensorVariable): "bcast": idx.type.broadcastable,
# This is an advanced tensor index "ndim": idx.type.ndim,
adv_idxs.append( }
{ for i, idx in enumerate(idxs)
"axis": i, if isinstance(idx.type, TensorType)
"dtype": idx.type.dtype, ]
"bcast": idx.type.broadcastable,
"ndim": idx.type.ndim,
}
)
must_ignore_duplicates = ( must_ignore_duplicates = (
isinstance(op, AdvancedIncSubtensor) isinstance(op, AdvancedIncSubtensor)
...@@ -260,10 +265,13 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -260,10 +265,13 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
) )
) )
# Special implementation for integer indices that respects duplicates
if ( if (
not must_ignore_duplicates not must_ignore_duplicates
and len(adv_idxs) >= 1 and len(adv_idxs) >= 1
and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs) and all(adv_idx["dtype"] != "bool" for adv_idx in adv_idxs)
# Implementation does not support newaxis
and not any(isinstance(idx.type, NoneTypeT) for idx in idxs)
): ):
return vector_integer_advanced_indexing(op, node, **kwargs) return vector_integer_advanced_indexing(op, node, **kwargs)
...@@ -391,6 +399,7 @@ def vector_integer_advanced_indexing( ...@@ -391,6 +399,7 @@ def vector_integer_advanced_indexing(
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape)) y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed) # Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast y_bcast = y_bcast
# Index over tuples of raveled advanced indices and update buffer # Index over tuples of raveled advanced indices and update buffer
...@@ -451,90 +460,45 @@ def vector_integer_advanced_indexing( ...@@ -451,90 +460,45 @@ def vector_integer_advanced_indexing(
return x return x
""" """
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor): if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
x, *index_variables = node.inputs x, *idxs = node.inputs
else: else:
x, y, *index_variables = node.inputs x, y, *idxs = node.inputs
[out] = node.outputs [out] = node.outputs
reconstructed_indices = indices_from_subtensor(index_variables, op.idx_list)
idx_args = [f"idx{i}" for i in range(len(index_variables))]
var_to_arg = dict(zip(index_variables, idx_args))
idxs = []
def get_idx_str(val, is_slice_component=False):
if val is None:
return "None"
if isinstance(val, Variable) and val in var_to_arg:
arg = var_to_arg[val]
if val.ndim == 0 and is_slice_component:
return f"{arg}.item()"
return arg
raise ValueError(f"Unexpected index value: {val}")
for idx in reconstructed_indices:
if isinstance(idx, slice):
start = get_idx_str(idx.start, is_slice_component=True)
stop = get_idx_str(idx.stop, is_slice_component=True)
step = get_idx_str(idx.step, is_slice_component=True)
idxs.append(f"slice({start}, {stop}, {step})")
else:
# It's a direct index variable
idxs.append(get_idx_str(idx, is_slice_component=False))
adv_indices_pos = tuple( adv_indices_pos = tuple(
i for i, idx in enumerate(reconstructed_indices) if not isinstance(idx, slice) i for i, idx in enumerate(idxs) if isinstance(idx.type, TensorType)
) )
assert adv_indices_pos # Otherwise it's just basic indexing assert adv_indices_pos # Otherwise it's just basic indexing
basic_indices_pos = tuple( basic_indices_pos = tuple(
i for i, idx in enumerate(reconstructed_indices) if isinstance(idx, slice) i for i, idx in enumerate(idxs) if not isinstance(idx.type, TensorType)
) )
explicit_basic_indices_pos = (*basic_indices_pos, *range(len(idxs), x.type.ndim))
# Create index signature for generated function: "idx0, idx1, idx2, ..." # Create index signature and split them among basic and advanced
idx_signature = ", ".join(idx_args) idx_signature = ", ".join(f"idx{i}" for i in range(len(idxs)))
adv_indices = [f"idx{i}" for i in adv_indices_pos]
basic_indices = [f"idx{i}" for i in basic_indices_pos]
# String representations of advanced and basic indices for codegen # Define transpose axis so that advanced indexing dims are on the front
adv_indices = [idxs[i] for i in adv_indices_pos] adv_axis_front_order = (*adv_indices_pos, *explicit_basic_indices_pos)
basic_indices = [idxs[i] for i in basic_indices_pos] adv_axis_front_transpose_needed = adv_axis_front_order != tuple(range(x.ndim))
adv_idx_ndim = max(idxs[i].ndim for i in adv_indices_pos)
to_tuple = create_tuple_string # alias to make code more readable below # Helper needed for basic indexing after moving advanced indices to the front
basic_indices_with_none_slices = ", ".join(
(*((":",) * len(adv_indices)), *basic_indices)
)
# Compute number of dimensions in advanced indices (after broadcasting) # Position of the first advanced index dimension after indexing the array
if len(adv_indices_pos) == 1: if (np.diff(adv_indices_pos) > 1).any():
adv_idx = reconstructed_indices[adv_indices_pos[0]] # If not consecutive, it's always at the front
adv_idx_ndim = adv_idx.ndim # type: ignore[union-attr] out_adv_axis_pos = 0
else: else:
# Multiple advanced indices - use max ndim (broadcast result ndim) # Otherwise wherever the first advanced index is located
adv_idx_ndim = max(reconstructed_indices[i].ndim for i in adv_indices_pos) # type: ignore[union-attr]
# Determine output position of advanced indexed dimensions
# If advanced indices are consecutive, they go in the first advanced index position
# Otherwise they go at the beginning
if adv_indices_pos == tuple(range(adv_indices_pos[0], adv_indices_pos[-1] + 1)):
# Consecutive - advanced dims will be at position of first advanced index
out_adv_axis_pos = adv_indices_pos[0] out_adv_axis_pos = adv_indices_pos[0]
else:
# Non-consecutive - advanced dims go at the front
out_adv_axis_pos = 0
# Include trailing dimensions not covered by explicit indices to_tuple = create_tuple_string # alias to make code more readable below
explicit_basic_indices_pos = (
*basic_indices_pos,
*range(len(reconstructed_indices), x.type.ndim),
)
# Compute transpose to move advanced indexed dims to the front
adv_axis_front_order = (*adv_indices_pos, *explicit_basic_indices_pos)
adv_axis_front_transpose_needed = adv_axis_front_order != tuple(range(x.type.ndim))
# Compute basic indices with "None" slices for dimensions that will be indexed by advanced indices
basic_indices_with_none_slices = ", ".join(
":" for _ in range(len(adv_indices_pos))
) + (", " + ", ".join(basic_indices) if basic_indices else "")
if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor): if isinstance(op, AdvancedSubtensor1 | AdvancedSubtensor):
# Define transpose axis on the output to restore original meaning # Define transpose axis on the output to restore original meaning
...@@ -593,8 +557,7 @@ def vector_integer_advanced_indexing( ...@@ -593,8 +557,7 @@ def vector_integer_advanced_indexing(
else: else:
# Make implicit dims of y explicit to simplify code # Make implicit dims of y explicit to simplify code
# Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis # Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis
indexed_ndim = x[tuple(reconstructed_indices)].type.ndim indexed_ndim = x[tuple(idxs)].type.ndim
y_expand_dims = [":"] * y.type.ndim y_expand_dims = [":"] * y.type.ndim
y_implicit_dims = range(indexed_ndim - y.type.ndim) y_implicit_dims = range(indexed_ndim - y.type.ndim)
for axis in y_implicit_dims: for axis in y_implicit_dims:
......
...@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import ( ...@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice, SliceType
def check_negative_steps(indices): def check_negative_steps(indices):
...@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs): ...@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return subtensor return subtensor
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(start, stop, step):
# Torch does not like numpy integers in indexing slices
return slice(
None if start is None else int(start),
None if stop is None else int(stop),
None if step is None else int(step),
)
return makeslice
@pytorch_funcify.register(AdvancedSubtensor1) @pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor) @pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs): def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices): def advsubtensor(x, *indices):
indices = indices_from_subtensor(indices, op.idx_list)
check_negative_steps(indices) check_negative_steps(indices)
return x[indices] return x[indices]
...@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs): ...@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register(AdvancedIncSubtensor) @pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1) @pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False) ignore_duplicates = getattr(op, "ignore_duplicates", False)
if op.set_instead_of_inc: if op.set_instead_of_inc:
def adv_set_subtensor(x, y, *flattened_indices): def adv_set_subtensor(x, y, *indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1): if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices) op._check_runtime_broadcasting(node, x, y, indices)
...@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif ignore_duplicates: elif ignore_duplicates:
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices): def adv_inc_subtensor_no_duplicates(x, y, *indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1): if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices) op._check_runtime_broadcasting(node, x, y, indices)
...@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return adv_inc_subtensor_no_duplicates return adv_inc_subtensor_no_duplicates
else: else:
if any(isinstance(entry, slice) for entry in idx_list): if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
raise NotImplementedError( raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch" "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
) )
def adv_inc_subtensor(x, y, *flattened_indices): def adv_inc_subtensor(x, y, *indices):
indices = indices_from_subtensor(flattened_indices, idx_list) # Not needed because slices aren't supported
# Not needed because slices aren't supported in this path
# check_negative_steps(indices) # check_negative_steps(indices)
if not inplace: if not inplace:
x = x.clone() x = x.clone()
......
...@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape ...@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
basic_subtensor,
get_canonical_form_slice, get_canonical_form_slice,
get_idx_list, get_idx_list,
get_slice_elements,
set_subtensor, set_subtensor,
) )
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool: ...@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if not ( if not (
isinstance(op, IncSubtensor) isinstance(op, IncSubtensor)
and op.set_instead_of_inc and op.set_instead_of_inc
and op.idx_list == (slice(None, 0),) and op.idx_list == [slice(None, ps.int64)]
): ):
return False return False
...@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = get_idx_list(cl.inputs, cl.op.idx_list) this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps = None
slices[i] = None
break
# 2.3.2 extract the begin/end of the first dimension # 2.3.2 extract the begin/end of the first dimension
if i >= op_info.n_mit_mot: if i >= op_info.n_mit_mot:
...@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break break
else: else:
this_slice = get_idx_list(cl.inputs, cl.op.idx_list) this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
store_steps[i] = 0
break
if isinstance(this_slice[0], slice): if isinstance(this_slice[0], slice):
start = this_slice[0].start start = this_slice[0].start
...@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
) )
else: else:
fslice = sanitize(cnf_slice[0]) fslice = sanitize(cnf_slice[0])
nw_slice = (fslice, *old_slices[1:])
nw_pos = inv_compress_map[idx] nw_pos = inv_compress_map[idx]
new_o = basic_subtensor(new_outs[nw_pos], fslice, *old_slices[1:])
subtens = Subtensor(nw_slice)
# slice inputs
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx) replaced_outs.append(idx)
...@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
) )
nw_slice = (sanitize(position), *old_slices[1:]) nw_slice = (sanitize(position), *old_slices[1:])
new_o = basic_subtensor(new_outs[nw_pos], *nw_slice) subtens = Subtensor(nw_slice)
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)] old_new += [(old, new_o)]
......
...@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output ...@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
...@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value( ...@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, int): if isinstance(idx, Type):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value( ...@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and len(v.owner.op.idx_list) == 1 and len(v.owner.op.idx_list) == 1
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, int): if isinstance(idx, Type):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value( ...@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op = owner.op op = owner.op
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, int): if isinstance(idx, Type):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
......
...@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import ( ...@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import integer_dtypes from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.type_other import NoneTypeT, SliceType
def is_rv_used_in_graph(base_rv, node, fgraph): def is_rv_used_in_graph(base_rv, node, fgraph):
...@@ -237,15 +237,20 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -237,15 +237,20 @@ def local_subtensor_rv_lift(fgraph, node):
return False return False
# Parse indices # Parse indices
if isinstance(subtensor_op, Subtensor | AdvancedSubtensor): if isinstance(subtensor_op, Subtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list) indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else: else:
indices = node.inputs[1:] indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates) # If we wanted to support that we could rewrite it as subtensor + dimshuffle
if any(is_nd_advanced_idx(idx, integer_dtypes) for idx in indices): # and make use of the dimshuffle lift rewrite
return False # TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any(
is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices
):
return False
# Check that indexing does not act on support dims # Check that indexing does not act on support dims
batch_ndims = rv_op.batch_ndim(rv_node) batch_ndims = rv_op.batch_ndim(rv_node)
...@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices[batch_ndims:], non_bool_indices[batch_ndims:],
) )
for idx in supp_indices: for idx in supp_indices:
if idx != slice(None): if not (
isinstance(idx.type, SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
):
return False return False
n_discarded_idxs = len(supp_indices) n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs] indices = indices[:-n_discarded_idxs]
...@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim # Broadcasted dim
if curr_dim in bcast_param_dims: if curr_dim in bcast_param_dims:
# Slice indexing, keep degenerate dim by none-slicing # Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx, slice): if isinstance(idx, slice) or isinstance(idx.type, SliceType):
batch_indices.append(slice(None)) batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing # Integer indexing, drop degenerate dim by 0-indexing
else: else:
......
...@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.traversal import ancestors from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
from pytensor.scalar import ScalarType
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
as_tensor_variable, as_tensor_variable,
...@@ -841,16 +842,13 @@ def _is_shape_i_of_x( ...@@ -841,16 +842,13 @@ def _is_shape_i_of_x(
if isinstance(var.owner.op, Shape_i): if isinstance(var.owner.op, Shape_i):
return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore
# Match Subtensor((int,))(Shape(input), i) - single integer index into shape # Match Subtensor((ScalarType,))(Shape(input), i)
if isinstance(var.owner.op, Subtensor): if isinstance(var.owner.op, Subtensor):
idx_entry = (
var.owner.op.idx_list[0] if len(var.owner.op.idx_list) == 1 else None
)
return ( return (
# Check we have integer indexing operation # Check we have integer indexing operation
# (and not slice or multiple indexing) # (and not slice or multiple indexing)
len(var.owner.op.idx_list) == 1 len(var.owner.op.idx_list) == 1
and isinstance(idx_entry, int) and isinstance(var.owner.op.idx_list[0], ScalarType)
# Check we are indexing on the shape of x # Check we are indexing on the shape of x
and var.owner.inputs[0].owner is not None and var.owner.inputs[0].owner is not None
and isinstance(var.owner.inputs[0].owner.op, Shape) and isinstance(var.owner.inputs[0].owner.op, Shape)
......
import itertools import itertools
import sys import sys
import warnings
import numpy as np import numpy as np
...@@ -16,7 +15,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -16,7 +15,7 @@ from pytensor.graph.rewriting.basic import (
node_rewriter, node_rewriter,
) )
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import Add, ScalarConstant from pytensor.scalar import Add, ScalarConstant, ScalarType
from pytensor.scalar import constant as scalar_constant from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
...@@ -32,7 +31,6 @@ from pytensor.tensor.basic import ( ...@@ -32,7 +31,6 @@ from pytensor.tensor.basic import (
full, full,
get_scalar_constant_value, get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
moveaxis,
register_infer_shape, register_infer_shape,
switch, switch,
) )
...@@ -74,11 +72,10 @@ from pytensor.tensor.subtensor import ( ...@@ -74,11 +72,10 @@ from pytensor.tensor.subtensor import (
AdvancedSubtensor1, AdvancedSubtensor1,
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
_non_consecutive_adv_indexing,
advanced_inc_subtensor1, advanced_inc_subtensor1,
advanced_subtensor,
advanced_subtensor1, advanced_subtensor1,
as_index_constant, as_index_constant,
basic_subtensor,
get_canonical_form_slice, get_canonical_form_slice,
get_constant_idx, get_constant_idx,
get_idx_list, get_idx_list,
...@@ -87,6 +84,7 @@ from pytensor.tensor.subtensor import ( ...@@ -87,6 +84,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -156,10 +154,8 @@ def transform_take(a, indices, axis): ...@@ -156,10 +154,8 @@ def transform_take(a, indices, axis):
if len(shape_parts) > 1: if len(shape_parts) > 1:
shape = pytensor.tensor.concatenate(shape_parts) shape = pytensor.tensor.concatenate(shape_parts)
elif len(shape_parts) == 1:
shape = shape_parts[0]
else: else:
shape = () shape = shape_parts[0]
ndim = a.ndim + indices.ndim - 1 ndim = a.ndim + indices.ndim - 1
...@@ -167,11 +163,23 @@ def transform_take(a, indices, axis): ...@@ -167,11 +163,23 @@ def transform_take(a, indices, axis):
def is_full_slice(x): def is_full_slice(x):
warnings.warn( """Determine if `x` is a ``slice(None)`` or a symbolic equivalent."""
"The function is deprecated, use x==slice(None) instead.", if isinstance(x, slice):
DeprecationWarning, return x == slice(None)
)
return x == slice(None) if isinstance(x, Variable) and isinstance(x.type, SliceType):
if x.owner is None:
if isinstance(x, Constant):
return x.data == slice(None)
else:
# Root slice variable
return False
# Symbolic MakeSlice
# Ignores start = 0, step = 1 cases
return all(isinstance(i.type, NoneTypeT) for i in x.owner.inputs)
return False
def get_advsubtensor_axis(indices): def get_advsubtensor_axis(indices):
...@@ -186,13 +194,13 @@ def get_advsubtensor_axis(indices): ...@@ -186,13 +194,13 @@ def get_advsubtensor_axis(indices):
found_idx = False found_idx = False
axis = 0 axis = 0
for idx in indices: for idx in indices:
if not found_idx and idx == slice(None): if not found_idx and is_full_slice(idx):
# Preceding full slices # Preceding full slices
axis += 1 axis += 1
elif found_idx and not idx == slice(None): elif found_idx and not is_full_slice(idx):
# We don't handle multiple indices # We don't handle multiple indices
return return
elif found_idx and idx == slice(None): elif found_idx and is_full_slice(idx):
# Trailing full slices # Trailing full slices
continue continue
else: else:
...@@ -219,8 +227,9 @@ def local_replace_AdvancedSubtensor(fgraph, node): ...@@ -219,8 +227,9 @@ def local_replace_AdvancedSubtensor(fgraph, node):
if not isinstance(node.op, AdvancedSubtensor): if not isinstance(node.op, AdvancedSubtensor):
return return
indexed_var, *index_variables = node.inputs indexed_var = node.inputs[0]
indices = indices_from_subtensor(index_variables, node.op.idx_list) indices = node.inputs[1:]
axis = get_advsubtensor_axis(indices) axis = get_advsubtensor_axis(indices)
if axis is None or indices[axis].dtype == "bool": if axis is None or indices[axis].dtype == "bool":
...@@ -244,8 +253,9 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node): ...@@ -244,8 +253,9 @@ def local_AdvancedIncSubtensor_to_AdvancedIncSubtensor1(fgraph, node):
# `AdvancedIncSubtensor1` does not ignore duplicate index values # `AdvancedIncSubtensor1` does not ignore duplicate index values
return return
res, val, *index_variables = node.inputs res = node.inputs[0]
indices = indices_from_subtensor(index_variables, node.op.idx_list) val = node.inputs[1]
indices = node.inputs[2:]
axis = get_advsubtensor_axis(indices) axis = get_advsubtensor_axis(indices)
...@@ -418,7 +428,11 @@ def local_subtensor_merge(fgraph, node): ...@@ -418,7 +428,11 @@ def local_subtensor_merge(fgraph, node):
merged_slices += slices1[pos_1:] merged_slices += slices1[pos_1:]
merged_slices = tuple(as_index_constant(s) for s in merged_slices) merged_slices = tuple(as_index_constant(s) for s in merged_slices)
out = basic_subtensor(x, *merged_slices) subtens = Subtensor(merged_slices)
sl_ins = get_slice_elements(merged_slices, lambda x: isinstance(x, Variable))
# Do not call make_node for test_value
out = subtens(x, *sl_ins)
# Copy over previous output stacktrace # Copy over previous output stacktrace
# and stacktrace from previous slicing operation. # and stacktrace from previous slicing operation.
...@@ -449,8 +463,9 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): ...@@ -449,8 +463,9 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
remove_dim = [] remove_dim = []
node_inputs_idx = 1 node_inputs_idx = 1
for dim, elem in enumerate(idx): for dim, elem in enumerate(idx):
if isinstance(elem, int): if isinstance(elem, ScalarType):
# The idx is a integer position. # The idx is a ScalarType, ie a Type. This means the actual index
# is contained in node.inputs[1]
dim_index = node.inputs[node_inputs_idx] dim_index = node.inputs[node_inputs_idx]
if isinstance(dim_index, ScalarConstant): if isinstance(dim_index, ScalarConstant):
dim_index = dim_index.value dim_index = dim_index.value
...@@ -462,6 +477,9 @@ def local_subtensor_remove_broadcastable_index(fgraph, node): ...@@ -462,6 +477,9 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
elif isinstance(elem, slice): elif isinstance(elem, slice):
if elem != slice(None): if elem != slice(None):
return return
elif isinstance(elem, int | np.integer):
if elem in (0, -1) and node.inputs[0].broadcastable[dim]:
remove_dim.append(dim)
else: else:
raise TypeError("case not expected") raise TypeError("case not expected")
...@@ -488,29 +506,26 @@ def local_subtensor_inc_subtensor(fgraph, node): ...@@ -488,29 +506,26 @@ def local_subtensor_inc_subtensor(fgraph, node):
if not x.owner.op.set_instead_of_inc: if not x.owner.op.set_instead_of_inc:
return return
x_inc, y_inc, *inc_index_variables = x.owner.inputs if x.owner.inputs[2:] == node.inputs[1:] and tuple(
_sub_x, *sub_index_variables = node.inputs x.owner.op.idx_list
) == tuple(node.op.idx_list):
if (
inc_index_variables == sub_index_variables
and x.owner.op.idx_list == node.op.idx_list
):
out = node.outputs[0] out = node.outputs[0]
y = x.owner.inputs[1]
# If the dtypes differ, cast y into x.dtype # If the dtypes differ, cast y into x.dtype
if x.dtype != y_inc.dtype: if x.dtype != y.dtype:
y_inc = y_inc.astype(x.dtype) y = y.astype(x.dtype)
if ( if (
out.type.dtype == y_inc.type.dtype out.type.dtype == y.type.dtype
and out.type.broadcastable == y_inc.type.broadcastable and out.type.broadcastable == y.type.broadcastable
): ):
# if x[idx] and y have the same type, directly return y # if x[idx] and y have the same type, directly return y
return [y_inc] return [y]
else: else:
# The difference is related to broadcasting pattern # The difference is related to broadcasting pattern
assert out.broadcastable != y_inc.broadcastable assert out.broadcastable != y.broadcastable
# We have to alloc y to the shape of x[idx] # We have to alloc y to the shape of x[idx]
x_subtensor = node.op(x_inc, *inc_index_variables) x_subtensor = node.op(x.owner.inputs[0], *x.owner.inputs[2:])
return [alloc(y_inc, *x_subtensor.shape)] return [alloc(y, *x_subtensor.shape)]
else: else:
return return
...@@ -814,9 +829,9 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2): ...@@ -814,9 +829,9 @@ def merge_two_slices(fgraph, slice1, len1, slice2, len2):
raise ValueError("slice1 should be of type `slice`") raise ValueError("slice1 should be of type `slice`")
# Simple case where one of the slices is useless # Simple case where one of the slices is useless
if slice1 == slice(None): if is_full_slice(slice1):
return slice2 return slice2
elif slice2 == slice(None): elif is_full_slice(slice2):
return slice1 return slice1
sl1, reverse1 = get_canonical_form_slice(slice1, len1) sl1, reverse1 = get_canonical_form_slice(slice1, len1)
...@@ -1075,7 +1090,6 @@ compile.optdb.register( ...@@ -1075,7 +1090,6 @@ compile.optdb.register(
def local_inplace_AdvancedIncSubtensor(fgraph, node): def local_inplace_AdvancedIncSubtensor(fgraph, node):
if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor) and not node.op.inplace:
new_op = type(node.op)( new_op = type(node.op)(
node.op.idx_list,
inplace=True, inplace=True,
set_instead_of_inc=node.op.set_instead_of_inc, set_instead_of_inc=node.op.set_instead_of_inc,
ignore_duplicates=node.op.ignore_duplicates, ignore_duplicates=node.op.ignore_duplicates,
...@@ -1262,7 +1276,9 @@ def local_useless_inc_subtensor_alloc(fgraph, node): ...@@ -1262,7 +1276,9 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
""" """
if isinstance(node.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1): if isinstance(node.op, IncSubtensor | AdvancedIncSubtensor | AdvancedIncSubtensor1):
x, y, *index_variables = node.inputs x = node.inputs[0]
y = node.inputs[1]
i = node.inputs[2:]
if y.owner is not None and isinstance(y.owner.op, Alloc): if y.owner is not None and isinstance(y.owner.op, Alloc):
# `z` is the input of the Alloc op, i.e. at.alloc(z, <shape>) # `z` is the input of the Alloc op, i.e. at.alloc(z, <shape>)
...@@ -1281,11 +1297,11 @@ def local_useless_inc_subtensor_alloc(fgraph, node): ...@@ -1281,11 +1297,11 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
# Get the subtensor of `x` indexed by `i` in order to compare # Get the subtensor of `x` indexed by `i` in order to compare
# shapes later. # shapes later.
if isinstance(node.op, IncSubtensor): if isinstance(node.op, IncSubtensor):
xi = Subtensor(node.op.idx_list)(x, *index_variables) xi = Subtensor(node.op.idx_list)(x, *i)
elif isinstance(node.op, AdvancedIncSubtensor): elif isinstance(node.op, AdvancedIncSubtensor):
xi = AdvancedSubtensor(node.op.idx_list)(x, *index_variables) xi = advanced_subtensor(x, *i)
elif isinstance(node.op, AdvancedIncSubtensor1): elif isinstance(node.op, AdvancedIncSubtensor1):
xi = advanced_subtensor1(x, *index_variables) xi = advanced_subtensor1(x, *i)
else: else:
raise Exception("Should never happen!") raise Exception("Should never happen!")
...@@ -1345,7 +1361,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node): ...@@ -1345,7 +1361,7 @@ def local_useless_inc_subtensor_alloc(fgraph, node):
msg = "`x[i]` and `y` do not have the same shape." msg = "`x[i]` and `y` do not have the same shape."
z = Assert(msg)(z, *cond) z = Assert(msg)(z, *cond)
r = node.op(x, z, *index_variables) r = node.op(x, z, *i)
# Copy over stacktrace from previous output, since # Copy over stacktrace from previous output, since
# we don't expect problems when removing the intermediate # we don't expect problems when removing the intermediate
# alloc operation and so we still want to point at the line # alloc operation and so we still want to point at the line
...@@ -1477,7 +1493,8 @@ def local_uint_constant_indices(fgraph, node): ...@@ -1477,7 +1493,8 @@ def local_uint_constant_indices(fgraph, node):
x, *indices = node.inputs x, *indices = node.inputs
y = None y = None
new_indices = list(indices_from_subtensor(indices, node.op.idx_list)) idx_list = getattr(node.op, "idx_list", None)
new_indices = list(indices_from_subtensor(indices, idx_list))
has_new_index = False has_new_index = False
for i, index in enumerate(new_indices): for i, index in enumerate(new_indices):
...@@ -1527,7 +1544,14 @@ def local_uint_constant_indices(fgraph, node): ...@@ -1527,7 +1544,14 @@ def local_uint_constant_indices(fgraph, node):
if not has_new_index: if not has_new_index:
return False return False
new_indices = get_slice_elements(new_indices) if isinstance(op, Subtensor | IncSubtensor):
# Basic index Ops contain information about the dtype of the indices, so wee have to recreate them
props = op._props_dict()
props["idx_list"] = new_indices
op = type(op)(**props)
# Basic index Ops don't expect slices, but the respective start/step/stop
new_indices = get_slice_elements(new_indices)
new_args = (x, *new_indices) if y is None else (x, y, *new_indices) new_args = (x, *new_indices) if y is None else (x, y, *new_indices)
new_out = op(*new_args) new_out = op(*new_args)
copy_stack_trace(node.outputs[0], new_out) copy_stack_trace(node.outputs[0], new_out)
...@@ -1587,18 +1611,27 @@ def local_blockwise_inc_subtensor(fgraph, node): ...@@ -1587,18 +1611,27 @@ def local_blockwise_inc_subtensor(fgraph, node):
core_op = node.op.core_op core_op = node.op.core_op
x, y, *idxs = node.inputs x, y, *idxs = node.inputs
[out] = node.outputs [out] = node.outputs
advanced = isinstance(core_op, AdvancedIncSubtensor) if isinstance(core_op, AdvancedIncSubtensor):
if any(
if advanced and any(idx.type.dtype == "bool" for idx in idxs): (
# Get out if we have boolean indices as they cross dimension boundaries # Blockwise requires all inputs to be tensors so it is not possible
# / can't be safely broadcasted depending on their runtime content # to wrap an AdvancedIncSubtensor with slice / newaxis inputs, but we check again just in case
return None # If this is ever supported we need to pay attention to special behavior of numpy when advanced indices
# are separated by basic indices
isinstance(idx, SliceType | NoneTypeT)
# Also get out if we have boolean indices as they cross dimension boundaries
# / can't be safely broadcasted depending on their runtime content
or (idx.type.dtype == "bool")
)
for idx in idxs
):
return None
batch_ndim = node.op.batch_ndim(node) batch_ndim = node.op.batch_ndim(node)
idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]] idxs_core_ndim = [len(inp_sig) for inp_sig in node.op.inputs_sig[2:]]
max_idx_core_ndim = max(idxs_core_ndim, default=0) max_idx_core_ndim = max(idxs_core_ndim, default=0)
# Broadcast buffer to batch_shape # Step 1. Broadcast buffer to batch_shape
if x.type.broadcastable != out.type.broadcastable: if x.type.broadcastable != out.type.broadcastable:
batch_shape = [1] * batch_ndim batch_shape = [1] * batch_ndim
for inp in node.inputs: for inp in node.inputs:
...@@ -1615,61 +1648,58 @@ def local_blockwise_inc_subtensor(fgraph, node): ...@@ -1615,61 +1648,58 @@ def local_blockwise_inc_subtensor(fgraph, node):
x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:])) x = broadcast_to(x, (*batch_shape, *x.shape[batch_ndim:]))
assert x.type.broadcastable == out.type.broadcastable assert x.type.broadcastable == out.type.broadcastable
# Massage indices so they respect blockwise semantics while using regular indexing # Step 2. Massage indices so they respect blockwise semantics
core_idxs = [] if isinstance(core_op, IncSubtensor):
for idx_entry in core_op.idx_list: # For basic IncSubtensor there are two cases:
if isinstance(idx_entry, slice): # 1. Slice entries -> We need to squeeze away dummy dimensions so we can convert back to slice
# Squeeze away dummy dimensions so we can convert to slice # 2. Integers -> Can be used as is, but we try to squeeze away dummy batch dimensions
new_entries = [None, None, None] # in case we can end up with a basic IncSubtensor again
for i, slice_idx_entry in enumerate( core_idxs = []
(idx_entry.start, idx_entry.stop, idx_entry.step) counter = 0
): for idx in core_op.idx_list:
if slice_idx_entry is None: if isinstance(idx, slice):
continue # Squeeze away dummy dimensions so we can convert to slice
else: new_entries = [None, None, None]
new_entries[i] = new_entry = idxs[slice_idx_entry].squeeze() for i, entry in enumerate((idx.start, idx.stop, idx.step)):
if new_entry.ndim > 0: if entry is None:
# If the slice entry has dimensions after the squeeze we can't convert it to a slice continue
# We could try to convert to equivalent integer indices, but nothing guarantees else:
# that the slice is "square". new_entries[i] = new_entry = idxs[counter].squeeze()
return None counter += 1
squeezed_index = slice(*new_entries) if new_entry.ndim > 0:
else: # If the slice entry has dimensions after the squeeze we can't convert it to a slice
if advanced: # We could try to convert to equivalent integer indices, but nothing guarantees
# For AdvancedIncSubtensor we have tensor integer indices, # that the slice is "square".
# We need to expand batch indexes on the right, so they don't interact with core index dimensions return None
# We still squeeze on the left in case that allows us to use simpler indices core_idxs.append(slice(*new_entries))
squeezed_index = _squeeze_left(
shape_padright(
idxs[idx_entry], max_idx_core_ndim - idxs_core_ndim[idx_entry]
),
stop_at_dim=batch_ndim,
)
else: else:
# For basic IncSubtensor integers indices can be used as is, but we try to squeeze away dummy core_idxs.append(_squeeze_left(idxs[counter]))
# batch dimensions in case we can end up with a basic IncSubtensor again counter += 1
squeezed_index = _squeeze_left(idxs[idx_entry]) else:
# For AdvancedIncSubtensor we have tensor integer indices,
core_idxs.append(squeezed_index) # We need to expand batch indexes on the right, so they don't interact with core index dimensions
# We still squeeze on the left in case that allows us to use simpler indices
core_idxs = [
_squeeze_left(
shape_padright(idx, max_idx_core_ndim - idx_core_ndim),
stop_at_dim=batch_ndim,
)
for idx, idx_core_ndim in zip(idxs, idxs_core_ndim)
]
# Create new indices for the batch dimensions # Step 3. Create new indices for the new batch dimension of x
has_batched_indices = not all( if not all(
all(idx.type.broadcastable[:batch_ndim]) all(idx.type.broadcastable[:batch_ndim])
for idx in idxs for idx in idxs
if not isinstance(idx, slice) if not isinstance(idx, slice)
) ):
if has_batched_indices: # If indices have batch dimensions in the indices, they will interact with the new dimensions of x
# If indices have batch dimensions, we need to align them element-wise with the respective batch dimensions of x # We build vectorized indexing with new arange indices that do not interact with core indices or each other
# We achieve this by creating `arange` indices and adding expand_dims for correct broadcasting. # (i.e., they broadcast)
# Example:
# x = pt.zeros(5); idx = [0, 1, 0]; out = x[idx].set(y) # Note: due to how numpy handles non-consecutive advanced indexing (transposing it to the front),
# batch_x = pt.zeros((2, 5)); batch_idx = [[0, 1, 0], [1, 1, 2]] # we don't want to create a mix of slice(None), and arange() indices for the new batch dimension,
# batch_out = batch_x[[0, 1][:, None], batch_idx].set(y) # even if not all batch dimensions have corresponding batch indices.
# If instead batch_x = pt.zeros((2, 2, 5))
# batch_out = batch_x[[0, 1][:, None, None], [0, 1][None, 1, None], batch_idx]
# Note: For simplicity we use arange for all batch dimensions of x,
# even if not all may have corresponding batch index dimensions
batch_slices = [ batch_slices = [
shape_padright(arange(x_batch_shape, dtype="int64"), n) shape_padright(arange(x_batch_shape, dtype="int64"), n)
for (x_batch_shape, n) in zip( for (x_batch_shape, n) in zip(
...@@ -1685,49 +1715,29 @@ def local_blockwise_inc_subtensor(fgraph, node): ...@@ -1685,49 +1715,29 @@ def local_blockwise_inc_subtensor(fgraph, node):
new_idxs = (*batch_slices, *core_idxs) new_idxs = (*batch_slices, *core_idxs)
x_view = x[new_idxs] x_view = x[new_idxs]
# Introduce any implicit expand_dims on core dimension of y # Step 4. Introduce any implicit expand_dims on core dimension of y
missing_y_core_ndim = x_view.type.ndim - y.type.ndim missing_y_core_ndim = x_view.type.ndim - y.type.ndim
implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim)) implicit_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = expand_dims(y, implicit_axes) y = _squeeze_left(expand_dims(y, implicit_axes), stop_at_dim=batch_ndim)
# Transpose y if needed if isinstance(core_op, IncSubtensor):
if has_batched_indices: # Check if we can still use a basic IncSubtensor
# By introducing arange slices we may caused a transposition of the advanced group to the front if isinstance(x_view.owner.op, Subtensor):
# If this was not already happening in the core graph, we'll need to transpose y to align it correctly new_props = core_op._props_dict()
if max_idx_core_ndim and not ( new_props["idx_list"] = x_view.owner.op.idx_list
advanced and _non_consecutive_adv_indexing(core_idxs) new_core_op = type(core_op)(**new_props)
): symbolic_idxs = x_view.owner.inputs[1:]
integer_pos = [ new_out = new_core_op(x, y, *symbolic_idxs)
i for i, entry in enumerate(core_op.idx_list) if isinstance(entry, int) else:
] # We need to use AdvancedSet/IncSubtensor
slice_pos = [ if core_op.set_instead_of_inc:
i new_out = x[new_idxs].set(y)
for i, entry in enumerate(core_op.idx_list) else:
if isinstance(entry, slice) new_out = x[new_idxs].inc(y)
]
if slice_pos and integer_pos and (slice_pos[0] < integer_pos[-1]):
y = moveaxis(
y,
[batch_ndim + integer_pos[0] + i for i in range(max_idx_core_ndim)],
[batch_ndim + i for i in range(max_idx_core_ndim)],
)
else:
# Conversely if we tried to use `slice(None)` for the batch dimensions but there was already transposition
# in the core case, we'll need to move the batch slices of y to after the advanced indexing group
if advanced and _non_consecutive_adv_indexing(core_idxs):
y = moveaxis(
y,
[i for i in range(batch_ndim)], # noqa: C416
[max_idx_core_ndim + i for i in range(batch_ndim)],
)
# Remove useless left-batch dimensions of y (if any)
y = _squeeze_left(y, stop_at_dim=batch_ndim)
if core_op.set_instead_of_inc:
new_out = x[new_idxs].set(y)
else: else:
new_out = x[new_idxs].inc(y) # AdvancedIncSubtensor takes symbolic indices/slices directly, no need to create a new op
symbolic_idxs = x_view.owner.inputs[1:]
new_out = core_op(x, y, *symbolic_idxs)
copy_stack_trace(out, new_out) copy_stack_trace(out, new_out)
return [new_out] return [new_out]
...@@ -1744,12 +1754,10 @@ def bool_idx_to_nonzero(fgraph, node): ...@@ -1744,12 +1754,10 @@ def bool_idx_to_nonzero(fgraph, node):
else: else:
x, y, *idxs = node.inputs x, y, *idxs = node.inputs
idxs = indices_from_subtensor(idxs, node.op.idx_list)
bool_pos = { bool_pos = {
i i
for i, idx in enumerate(idxs) for i, idx in enumerate(idxs)
if isinstance(idx, TensorVariable) and idx.dtype == "bool" if (isinstance(idx.type, TensorType) and idx.dtype == "bool")
} }
if not bool_pos: if not bool_pos:
...@@ -1763,13 +1771,9 @@ def bool_idx_to_nonzero(fgraph, node): ...@@ -1763,13 +1771,9 @@ def bool_idx_to_nonzero(fgraph, node):
new_idxs.append(idx) new_idxs.append(idx)
if isinstance(node.op, AdvancedSubtensor): if isinstance(node.op, AdvancedSubtensor):
new_out = x[tuple(new_idxs)] new_out = node.op(x, *new_idxs)
else: else:
new_out = ( new_out = node.op(x, y, *new_idxs)
x[tuple(new_idxs)].set(y)
if node.op.set_instead_of_inc
else x[tuple(new_idxs)].inc(y)
)
return [copy_stack_trace(node.outputs[0], new_out)] return [copy_stack_trace(node.outputs[0], new_out)]
...@@ -1818,8 +1822,7 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node): ...@@ -1818,8 +1822,7 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node):
): ):
return None return None
x, y, *idx_variables = diag_x.owner.inputs x, y, *idxs = diag_x.owner.inputs
idxs = indices_from_subtensor(idx_variables, diag_x.owner.op.idx_list)
if not ( if not (
x.type.ndim >= 2 x.type.ndim >= 2
...@@ -1835,7 +1838,7 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node): ...@@ -1835,7 +1838,7 @@ def extract_diag_of_diagonal_set_subtensor(fgraph, node):
# Check all non-axis indices are full slices # Check all non-axis indices are full slices
axis = {op.axis1, op.axis2} axis = {op.axis1, op.axis2}
if not all(idx == slice(None) for i, idx in enumerate(idxs) if i not in axis): if not all(is_full_slice(idx) for i, idx in enumerate(idxs) if i not in axis):
return None return None
# Check axis indices are arange we would expect from setting on the diagonal # Check axis indices are arange we would expect from setting on the diagonal
......
...@@ -8,6 +8,7 @@ from pytensor import Variable ...@@ -8,6 +8,7 @@ from pytensor import Variable
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
Join, Join,
...@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize, register_stabilize,
) )
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.subtensor import register_useless from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
SpecifyShape, SpecifyShape,
...@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import ( ...@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index( ...@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index(
) -> bool: ) -> bool:
if isinstance(axis, int): if isinstance(axis, int):
axis = (axis,) axis = (axis,)
return any(ax < len(idxs) and not idxs[ax] == slice(None) for ax in axis) return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
def _lift_subtensor_non_axis( def _lift_subtensor_non_axis(
...@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis( ...@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable: TensorVariable, old_subtensor_variable: TensorVariable,
) -> None | list[TensorVariable]: ) -> None | list[TensorVariable]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions # Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices = [idx for idx in idx_tuple if not idx == slice(None)] real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
if len(real_indices) > 1 and variable.type.ndim > 1: if len(real_indices) > 1 and variable.type.ndim > 1:
# Split the subtensor # Split the subtensor
idx_to_keep = idx_tuple[axis] idx_to_keep = idx_tuple[axis]
...@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node): ...@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if len(idx_tuple) > batch_ndim: if len(idx_tuple) > batch_ndim:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only # Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:] batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
if all(idx == slice(None) for idx in batch_indices): if all(is_full_slice(idx) for idx in batch_indices):
# No batch indices, nothing to do # No batch indices, nothing to do
return None return None
elem_with_batch_indices = elem[batch_indices] elem_with_batch_indices = elem[batch_indices]
...@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node): ...@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict=False, strict=False,
) )
): ):
if dim_idx == slice(None): if is_full_slice(dim_idx):
# Full slice can be safely applied to all inputs # Full slice can be safely applied to all inputs
continue continue
...@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node): ...@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if i in expanded_axes: if i in expanded_axes:
if isinstance(idx_item, slice): if isinstance(idx_item, slice):
# Slice could be keeping or dropping this dimension # Slice could be keeping or dropping this dimension
if idx_item == slice(None): if is_full_slice(idx_item):
# A None slice, always keeps the dimension. # A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim # We skip the index, and later introduce the needed expand_dim
continue continue
...@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): ...@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices = get_idx_list(node.inputs, node.op.idx_list) indices = get_idx_list(node.inputs, node.op.idx_list)
if any(isinstance(index, slice) for index in indices): if any(
isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
for index in indices
):
return False return False
new_obj_arg = obj_arg[indices] new_obj_arg = obj_arg[indices]
...@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node):
(idx,) = idxs (idx,) = idxs
if isinstance(idx, int): if isinstance(idx, ps.ScalarType | TensorType):
idx = node.inputs[1] old_idx, idx = idx, node.inputs[1]
assert idx.type.is_super(old_idx)
elif isinstance(node.op, AdvancedSubtensor1): elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1] idx = node.inputs[1]
if isinstance(idx, Variable): if isinstance(idx, int | np.integer):
return [x.owner.inputs[idx]]
elif isinstance(idx, Variable):
if idx.ndim == 0: if idx.ndim == 0:
try: try:
v = get_underlying_scalar_constant_value( v = get_underlying_scalar_constant_value(
...@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node): ...@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node):
except NotScalarConstantError: except NotScalarConstantError:
return False return False
assert idx_val != np.newaxis
if not isinstance(shape_arg.type, TensorType): if not isinstance(shape_arg.type, TensorType):
return False return False
...@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node): ...@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return None return None
x, *adv_index_vars = adv_subtensor.owner.inputs x, *adv_idxs = adv_subtensor.owner.inputs
adv_idxs = indices_from_subtensor(adv_index_vars, adv_subtensor.owner.op.idx_list)
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if ( if any(
not all( (
( isinstance(adv_idx.type, NoneTypeT)
(isinstance(adv_idx, TensorVariable) and adv_idx.type.dtype != "bool") or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
or (isinstance(adv_idx, slice) and adv_idx == slice(None)) or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
)
for adv_idx in adv_idxs
) )
for adv_idx in adv_idxs
) or _non_consecutive_adv_indexing(adv_idxs): ) or _non_consecutive_adv_indexing(adv_idxs):
return None return None
for first_adv_idx_dim, adv_idx in enumerate(adv_idxs): for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
# We already made sure there were only None slices besides integer indexes # We already made sure there were only None slices besides integer indexes
if isinstance(adv_idx, TensorVariable): if isinstance(adv_idx.type, TensorType):
break break
else: # no-break else: # no-break
# Not sure if this should ever happen, but better safe than sorry # Not sure if this should ever happen, but better safe than sorry
...@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node): ...@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed) copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
x_after_index_lift = expand_dims(x_indexed, dropped_dims) x_after_index_lift = expand_dims(x_indexed, dropped_dims)
x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_index_vars) x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx) copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims) new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
......
...@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle ...@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Min, neg from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor from pytensor.tensor.subtensor import Subtensor
@register_uncanonicalize @register_uncanonicalize
...@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node): ...@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node):
if not all(broadcastable[i] for i in missing_dims): if not all(broadcastable[i] for i in missing_dims):
return False return False
# create a new index tuple for a new Subtensor # create a new idx_list for a new Subtensor object
# Reconstruct the full indices from the subtensor node, then replace # have to loop on idx_list and inputs
# dimensions that are being dropped by dimshuffle with scalar index 0 # inputs has the length of sum of non None elements of idx_list
x = input_.owner.inputs[0] # (check in slice!).
indices = list( # len(missing_dims) can be < len(idx_list), this happens if
indices_from_subtensor( # tensor was indexed such as x[scalar, :, :], check that as well
input_.owner.inputs[1:], input_.owner.op.idx_list new_idx_list = list(input_.owner.op.idx_list)
) new_inputs = [input_.owner.inputs[0]]
)
zero = constant(0) zero = constant(0)
j = 0
# Track which output dimension each index corresponds to slice_i = -1
# Scalar indices remove dimensions, slices keep them subtensor_removed_dims = 0
output_dim = 0 for i, idx in enumerate(input_.owner.op.idx_list):
for i, idx in enumerate(indices):
if isinstance(idx, slice): if isinstance(idx, slice):
# This slice produces an output dimension slice_i += 1
if output_dim in missing_dims: if slice_i in missing_dims:
# This output dimension is being dropped, so replace slice with scalar # Missing dim is a slice(None), remove by indexing by 0
if idx == slice(None): if idx == slice(None):
indices[i] = zero new_idx_list[i] = zero
new_inputs += [zero]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
else: else:
# Use the start of the slice (or 0 if None) if idx.start is None:
indices[i] = idx.start if idx.start is not None else zero start = zero
output_dim += 1 else:
# Scalar indices don't contribute to output dimensions start = input_.owner.inputs[1 + j]
j += 1
# Handle trailing dimensions that weren't explicitly indexed new_idx_list[i] = start
for input_dim in range(len(indices), x.ndim): new_inputs += [start]
if output_dim in missing_dims:
# This unindexed dimension is being dropped, index with 0 # Ignore useless stop and step input if there is one
indices.append(zero) for slice_attr in ("stop", "step"):
if getattr(idx, slice_attr) is not None:
j += 1
# Keep non-dropped slice inputs
else:
for slice_attr in ("start", "stop", "step"):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# Keep non-dropped non-slice inputs
else: else:
# This unindexed dimension is kept, index with slice(None) new_inputs += [input_.owner.inputs[1 + j]]
indices.append(slice(None)) j += 1
output_dim += 1 subtensor_removed_dims += 1
# Verify the trailing dimensions the subtensor didn't look at.
return [x[tuple(indices)]] for idx in range(len(input_.owner.op.idx_list), new_inputs[0].ndim):
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)]
return False return False
import logging import logging
import sys import sys
import warnings import warnings
from collections.abc import Callable, Sequence from collections.abc import Callable, Iterable, Sequence
from itertools import chain, groupby, zip_longest from itertools import chain, groupby, zip_longest
from typing import TypeVar, cast, overload from typing import cast, overload
import numpy as np import numpy as np
from numpy.lib.array_utils import normalize_axis_tuple from numpy.lib.array_utils import normalize_axis_tuple
...@@ -15,6 +15,7 @@ from pytensor.gradient import disconnected_type ...@@ -15,6 +15,7 @@ from pytensor.gradient import disconnected_type
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.graph.type import Type
from pytensor.graph.utils import MethodNotDefined from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
...@@ -37,114 +38,117 @@ from pytensor.tensor.basic import ( ...@@ -37,114 +38,117 @@ from pytensor.tensor.basic import (
) )
from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.blockwise import vectorize_node_fallback
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
from pytensor.tensor.math import add, clip from pytensor.tensor.math import add, clip
from pytensor.tensor.shape import ( from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
Reshape,
Shape_i,
specify_broadcastable,
)
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
bscalar,
complex_dtypes, complex_dtypes,
cscalar,
discrete_dtypes, discrete_dtypes,
dscalar,
fscalar,
integer_dtypes, integer_dtypes,
iscalar,
lscalar,
tensor, tensor,
ubscalar,
uiscalar,
ulscalar,
uwscalar,
wscalar,
zscalar,
)
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneSliceConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
) )
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
from pytensor.utils import unzip from pytensor.utils import unzip
_logger = logging.getLogger("pytensor.tensor.subtensor") _logger = logging.getLogger("pytensor.tensor.subtensor")
invalid_scal_types = (ps.float64, ps.float32, ps.float16)
T = TypeVar("T") scal_types = (
ps.int64,
ps.int32,
def flatten_index_variables( ps.int16,
idx_vars: Sequence[T | None | slice], ps.int8,
) -> tuple[list[int | slice], list[T]]: ps.uint64,
counter = 0 ps.uint32,
idx_list: list[int | slice] = [] ps.uint16,
flat_vars = [] ps.uint8,
for idx_var in idx_vars: )
if isinstance(idx_var, slice): tensor_types = (
slice_idx_list: list[None | int] = [] lscalar,
for arg_entry in (idx_var.start, idx_var.stop, idx_var.step): iscalar,
if arg_entry is None or ( wscalar,
isinstance(arg_entry, Variable) bscalar,
and isinstance(arg_entry.type, NoneTypeT) ulscalar,
): uiscalar,
slice_idx_list.append(None) uwscalar,
else: ubscalar,
flat_vars.append(arg_entry) )
slice_idx_list.append(counter) invalid_tensor_types = (
counter += 1 fscalar,
idx_list.append(slice(*slice_idx_list)) dscalar,
else: cscalar,
flat_vars.append(idx_var) zscalar,
idx_list.append(counter) )
counter += 1
return idx_list, flat_vars
def unflatten_index_variables(
flat_indices: Sequence[T],
idx_list: Sequence[slice | int],
) -> tuple[slice | T, ...]:
indices: list[T | slice] = []
for idx_entry in idx_list:
if isinstance(idx_entry, int):
indices.append(flat_indices[idx_entry])
elif isinstance(idx_entry, slice):
start, stop, step = idx_entry.start, idx_entry.stop, idx_entry.step
indices.append(
slice(
None if idx_entry.start is None else flat_indices[start],
None if idx_entry.stop is None else flat_indices[stop],
None if idx_entry.step is None else flat_indices[step],
)
)
else:
raise ValueError(f"idx_entry must be int or slice, got {type(idx_entry)}")
return tuple(indices)
def indices_from_subtensor( def indices_from_subtensor(
op_indices: Sequence[Variable], op_indices: Iterable[ScalarConstant],
idx_list: tuple[slice | int, ...], idx_list: list[Type | slice | Variable] | None,
) -> tuple[slice | Variable, ...]: ) -> tuple[slice | Variable, ...]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created. """Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.
Parameters Parameters
---------- ==========
op_indices op_indices
The flattened indices obtained from ``x.inputs``, when ``x`` is a ``*Subtensor*`` node. The flattened indices obtained from ``x.inputs``, when ``x`` is a
``*Subtensor*`` node.
idx_list idx_list
The values describing each dimension's index. This is obtained from The values describing the types of each dimension's index. This is
``op.idx_list``. Entries can be: obtained from ``op.idx_list``, when ``op`` is a ``*Subtensor*``
- Integer positions (indices into op_indices) ``Op``.
- slice objects with int/None components
Returns
-------
tuple[slice | Variable, ...]
A tuple containing a mix of ``slice`` objects and ``Variable`` objects.
Each element corresponds to one indexing dimension:
- ``slice`` objects for slice-based indexing (e.g., ``x[1:3]``)
- ``Variable`` objects for scalar or array-based indexing
Callers should handle both types when iterating over the result.
Example Example
------- =======
array, *op_indices = subtensor_node.inputs array, *op_indices = subtensor_node.inputs
indices = indices_from_subtensor(op_indices, subtensor_node.op.idx_list) idx_list = getattr(subtensor_node.op, "idx_list", None)
indices = indices_from_subtensor(op_indices, idx_list)
""" """
return unflatten_index_variables(op_indices, idx_list)
def convert_indices(indices, entry):
"""Reconstruct ``*Subtensor*`` index input parameter entries."""
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return rval
elif isinstance(entry, slice):
return slice(
convert_indices(indices, entry.start),
convert_indices(indices, entry.stop),
convert_indices(indices, entry.step),
)
else:
return entry
op_indices = list(op_indices)
return (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
if idx_list
else tuple(op_indices)
)
def as_index_constant( def as_index_constant(
...@@ -178,7 +182,7 @@ def as_index_literal(idx: None) -> None: ... ...@@ -178,7 +182,7 @@ def as_index_literal(idx: None) -> None: ...
@overload @overload
def as_index_literal(idx: slice) -> slice: ... def as_index_literal(idx: slice | SliceConstant) -> slice: ...
@overload @overload
...@@ -190,7 +194,14 @@ def as_index_literal(idx: Variable): ... ...@@ -190,7 +194,14 @@ def as_index_literal(idx: Variable): ...
def as_index_literal( def as_index_literal(
idx: None | int | np.integer | slice | ScalarConstant | TensorConstant | Variable, idx: None
| int
| np.integer
| slice
| SliceConstant
| ScalarConstant
| TensorConstant
| Variable,
) -> int | np.integer | slice | None: ) -> int | np.integer | slice | None:
"""Convert a symbolic index element to its Python equivalent. """Convert a symbolic index element to its Python equivalent.
...@@ -213,6 +224,9 @@ def as_index_literal( ...@@ -213,6 +224,9 @@ def as_index_literal(
if not isinstance(idx, Variable): if not isinstance(idx, Variable):
raise TypeError(f"Not an index element: {idx}") raise TypeError(f"Not an index element: {idx}")
if isinstance(idx.type, NoneTypeT):
return None
if isinstance(idx, ScalarConstant): if isinstance(idx, ScalarConstant):
return cast(int, idx.data) return cast(int, idx.data)
...@@ -226,6 +240,13 @@ def as_index_literal( ...@@ -226,6 +240,13 @@ def as_index_literal(
if isinstance(idx, TensorConstant): if isinstance(idx, TensorConstant):
return cast(int, idx.data.item()) return cast(int, idx.data.item())
if isinstance(idx, SliceConstant):
return cast(slice, idx.data)
if isinstance(idx.type, SliceType):
assert idx.owner is not None
return slice(*map(as_index_literal, idx.owner.inputs))
# Other kinds of variables are not supported # Other kinds of variables are not supported
raise NotScalarConstantError() raise NotScalarConstantError()
...@@ -254,8 +275,10 @@ def get_canonical_form_slice( ...@@ -254,8 +275,10 @@ def get_canonical_form_slice(
) -> tuple[slice | TensorVariable, int | TensorVariable]: ) -> tuple[slice | TensorVariable, int | TensorVariable]:
"""Convert indices or slices to canonical form. """Convert indices or slices to canonical form.
Handles slice objects with ScalarVariable (including ScalarConstant) or None components. Scalar integer indices or python Slices with Scalar/None attributes
Vector indices and advanced indexing operations are handled separately by AdvancedSubtensor. used in basic Subtensor Ops are supported.
Symbolic slices (of SliceType) or vector indices
used in advanced Subtensor Ops are not supported.
Given a slice [start:stop:step] transform it into a canonical form Given a slice [start:stop:step] transform it into a canonical form
that respects the conventions imposed by python and numpy. that respects the conventions imposed by python and numpy.
...@@ -469,20 +492,16 @@ def get_canonical_form_slice( ...@@ -469,20 +492,16 @@ def get_canonical_form_slice(
return slice(nw_start, nw_stop, nw_step), 1 return slice(nw_start, nw_stop, nw_step), 1
def slice_len(slc, n): def range_len(slc):
"""Compute the length of a slice for an array of a given length. """Length of a `range` object.
We're essentially computing `len(range(*slc.indices(n)))`.
Adapted from CPython. Adapted from CPython.
""" """
from pytensor.tensor import and_, gt, lt, switch from pytensor.tensor import and_, gt, lt, switch
# TODO: Do we need to do this or should we expect `slc` to already be canonicalized?
canon_slc, _ = get_canonical_form_slice(slc, n)
start, stop, step = tuple( start, stop, step = tuple(
as_index_constant(a) for a in [canon_slc.start, canon_slc.stop, canon_slc.step] as_index_constant(a) for a in [slc.start, slc.stop, slc.step]
) )
return switch( return switch(
and_(gt(step, 0), lt(start, stop)), and_(gt(step, 0), lt(start, stop)),
...@@ -495,6 +514,31 @@ def slice_len(slc, n): ...@@ -495,6 +514,31 @@ def slice_len(slc, n):
) )
def slice_len(slc, n):
"""Compute the length of a slice for an array of a given length.
We're essentially computing `len(range(*slc.indices(n)))`.
"""
# TODO: Do we need to do this or should we expect `slc` to
# already be canonicalized?
canon_slc, _ = get_canonical_form_slice(slc, n)
return range_len(canon_slc)
def is_basic_idx(idx):
"""Determine if an index is of the NumPy basic type.
XXX: This only checks a single index, so an integer is *not* considered a
basic index, because--depending on the other indices its used with--an
integer can indicate advanced indexing.
"""
return isinstance(idx, slice | type(None)) or isinstance(
getattr(idx, "type", None), SliceType | NoneTypeT
)
def basic_shape(shape, indices): def basic_shape(shape, indices):
r"""Computes the shape resulting from basic NumPy indexing. r"""Computes the shape resulting from basic NumPy indexing.
...@@ -513,8 +557,25 @@ def basic_shape(shape, indices): ...@@ -513,8 +557,25 @@ def basic_shape(shape, indices):
for n, idx in zip(shape[: len(indices)], indices, strict=True): for n, idx in zip(shape[: len(indices)], indices, strict=True):
if isinstance(idx, slice): if isinstance(idx, slice):
res_shape += (slice_len(idx, n),) res_shape += (slice_len(idx, n),)
elif isinstance(getattr(idx, "type", None), SliceType):
if idx.owner is None:
if not isinstance(idx, Constant):
# This is an input slice, we can't reason symbolically on it.
# We don't even know if we will get None entries or integers
res_shape += (None,)
continue
else:
sl: slice = idx.data
slice_inputs = (sl.start, sl.stop, sl.step)
elif isinstance(idx.owner.op, MakeSlice):
slice_inputs = idx.owner.inputs
else:
raise ValueError(f"Unexpected Slice producing Op {idx.owner.op}")
res_shape += (slice_len(slice(*slice_inputs), n),)
elif idx is None: elif idx is None:
res_shape += (ps.ScalarConstant(ps.int64, 1),) res_shape += (ps.ScalarConstant(ps.int64, 1),)
elif isinstance(getattr(idx, "type", None), NoneTypeT):
res_shape += (ps.ScalarConstant(ps.int64, 1),)
else: else:
raise ValueError(f"Invalid index type: {idx}") raise ValueError(f"Invalid index type: {idx}")
return res_shape return res_shape
...@@ -532,12 +593,14 @@ def group_indices(indices): ...@@ -532,12 +593,14 @@ def group_indices(indices):
""" """
idx_groups = [] idx_groups = []
dim_num = -1 dim_num = -1
for basic, grp_indices in groupby(indices, key=lambda x: isinstance(x, slice)): for basic, grp_indices in groupby(indices, key=is_basic_idx):
enum_grp_indices = [] enum_grp_indices = []
for idx in grp_indices: for idx in grp_indices:
# We "zip" the dimension number to each index, which means we can't # We "zip" the dimension number to each index, which means we can't
# count indices that add new axes # count indices that add new axes
if idx is not None: if (idx is not None) and not isinstance(
getattr(idx, "type", None), NoneTypeT
):
dim_num += 1 dim_num += 1
enum_grp_indices.append((dim_num, idx)) enum_grp_indices.append((dim_num, idx))
...@@ -584,7 +647,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False): ...@@ -584,7 +647,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
idx_groups = sorted(idx_groups, key=lambda x: x[0]) idx_groups = sorted(idx_groups, key=lambda x: x[0])
idx_groups = groupby( idx_groups = groupby(
chain.from_iterable(d_idx for _, d_idx in idx_groups), chain.from_iterable(d_idx for _, d_idx in idx_groups),
key=lambda x: isinstance(x[1], slice), key=lambda x: is_basic_idx(x[1]),
) )
for basic, grp_dim_indices in idx_groups: for basic, grp_dim_indices in idx_groups:
...@@ -644,6 +707,72 @@ def get_slice_elements( ...@@ -644,6 +707,72 @@ def get_slice_elements(
return ret return ret
def index_vars_to_types(entry, slice_ok=True):
r"""Change references to `Variable`s into references to `Type`s.
The `Subtensor.idx_list` field is unique to each `Subtensor` instance. It
is not unique to each `Apply` node, so it should not refer to specific
`Variable`s.
TODO WRITEME: This function also accepts an `entry` already being a `Type`;
when would that happen?
"""
if (
isinstance(entry, np.ndarray | Variable)
and hasattr(entry, "dtype")
and entry.dtype == "bool"
):
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
if isinstance(entry, Variable) and (
entry.type in invalid_scal_types or entry.type in invalid_tensor_types
):
raise TypeError("Expected an integer")
if isinstance(entry, Variable) and entry.type in scal_types:
return entry.type
elif isinstance(entry, Type) and entry in scal_types:
return entry
if (
isinstance(entry, Variable)
and entry.type in tensor_types
and all(entry.type.broadcastable)
):
return ps.get_scalar_type(entry.type.dtype)
elif isinstance(entry, Type) and entry in tensor_types and all(entry.broadcastable):
return ps.get_scalar_type(entry.dtype)
elif slice_ok and isinstance(entry, slice):
a = entry.start
b = entry.stop
c = entry.step
if a is not None:
slice_a = index_vars_to_types(a, False)
else:
slice_a = None
if b is not None and b != sys.maxsize:
# The special "maxsize" case is probably not needed here,
# as slices containing maxsize are not generated by
# __getslice__ anymore.
slice_b = index_vars_to_types(b, False)
else:
slice_b = None
if c is not None:
slice_c = index_vars_to_types(c, False)
else:
slice_c = None
return slice(slice_a, slice_b, slice_c)
elif isinstance(entry, int | np.integer):
raise TypeError()
else:
raise AdvancedIndexingError("Invalid index type or slice for Subtensor")
def get_constant_idx( def get_constant_idx(
idx_list, inputs, allow_partial=False, only_process_constants=False, elemwise=True idx_list, inputs, allow_partial=False, only_process_constants=False, elemwise=True
): ):
...@@ -674,7 +803,7 @@ def get_constant_idx( ...@@ -674,7 +803,7 @@ def get_constant_idx(
>>> a = matrix("a") >>> a = matrix("a")
>>> b = a[v, 1:3] >>> b = a[v, 1:3]
>>> b.owner.op.idx_list >>> b.owner.op.idx_list
(0, slice(1, 2, None)) (ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None))
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True)
[v, slice(np.int64(1), np.int64(3), None)] [v, slice(np.int64(1), np.int64(3), None)]
>>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs)
...@@ -706,11 +835,15 @@ def get_constant_idx( ...@@ -706,11 +835,15 @@ def get_constant_idx(
return list(map(conv, real_idx)) return list(map(conv, real_idx))
def as_scalar_index_variable(idx) -> ps.ScalarVariable: def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
idx = ps.as_scalar(idx) """Convert a value to a `ScalarType` variable."""
if idx.type.dtype not in integer_dtypes: # Since ps.as_scalar does not know about tensor types (it would
raise TypeError("basic indices must be integers") # create a circular import) , this method converts either a
return idx # type: ignore[no-any-return] # TensorVariable or a ScalarVariable to a scalar.
if isinstance(a, Variable) and isinstance(a.type, TensorType):
return pytensor.tensor.scalar_from_tensor(a)
else:
return ps.as_scalar(a)
def slice_static_length(slc, dim_length): def slice_static_length(slc, dim_length):
...@@ -731,71 +864,17 @@ def slice_static_length(slc, dim_length): ...@@ -731,71 +864,17 @@ def slice_static_length(slc, dim_length):
return len(range(*slice(*entries).indices(dim_length))) return len(range(*slice(*entries).indices(dim_length)))
class BaseSubtensor: class Subtensor(COp):
"""Base class for Subtensor operations that handles idx_list and hash/equality."""
def __init__(self, idx_list: Sequence[int | slice]):
index_counter = -1
for idx_entry in idx_list:
if isinstance(idx_entry, int):
if idx_entry != (index_counter + 1):
raise ValueError(
f"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter = idx_entry
elif isinstance(idx_entry, slice):
for slice_idx_entry in (
idx_entry.start,
idx_entry.stop,
idx_entry.step,
):
if slice_idx_entry is not None:
if not isinstance(slice_idx_entry, int):
raise ValueError(
f"idx_list slice entries must be None or integer, got {slice_idx_entry} in {idx_entry}"
)
if slice_idx_entry != (index_counter + 1):
raise ValueError(
f"idx_list entries should have consecutive integers, got {idx_list}"
)
index_counter = slice_idx_entry
else:
raise ValueError(
f"idx_list entries must be int or slice, got {idx_entry}"
)
self.n_index_vars = index_counter + 1
self.idx_list = tuple(idx_list)
def _hashable_idx_list(self):
"""Return a hashable version of idx_list (slices converted to tuples).
Slices are not hashable in Python < 3.12, so we convert them to tuples.
"""
return tuple(
(slice, entry.start, entry.stop, entry.step)
if isinstance(entry, slice)
else entry
for entry in self.idx_list
)
def __hash__(self):
# Temporary workaround: slices are hashable in Python 3.12+
props_values = tuple(
self._hashable_idx_list() if prop == "idx_list" else getattr(self, prop)
for prop in self.__props__
)
return hash((type(self), props_values))
class Subtensor(BaseSubtensor, COp):
"""Basic NumPy indexing operator.""" """Basic NumPy indexing operator."""
check_input = False check_input = False
view_map = {0: [0]} view_map = {0: [0]}
_f16_ok = True _f16_ok = True
__props__ = ("idx_list",) __props__ = ("idx_list",)
__hash__ = BaseSubtensor.__hash__
def __init__(self, idx_list):
# TODO: Provide the type of `self.idx_list`
self.idx_list = tuple(map(index_vars_to_types, idx_list))
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
""" """
...@@ -808,16 +887,23 @@ class Subtensor(BaseSubtensor, COp): ...@@ -808,16 +887,23 @@ class Subtensor(BaseSubtensor, COp):
""" """
x = as_tensor_variable(x) x = as_tensor_variable(x)
inputs = tuple(as_scalar_index_variable(a) for a in inputs) inputs = tuple(as_nontensor_scalar(a) for a in inputs)
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array") raise IndexError("too many indices for array")
input_positions = get_slice_elements( input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, int) idx_list, lambda entry: isinstance(entry, Type)
) )
assert len(inputs) == len(input_positions) assert len(inputs) == len(input_types)
for input, expected_type in zip(inputs, input_types, strict=True):
if not expected_type.is_super(input.type):
raise TypeError(
f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
)
padded = [ padded = [
*indices_from_subtensor(inputs, self.idx_list), *indices_from_subtensor(inputs, self.idx_list),
...@@ -838,10 +924,13 @@ class Subtensor(BaseSubtensor, COp): ...@@ -838,10 +924,13 @@ class Subtensor(BaseSubtensor, COp):
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
x, *index_variables = inputs x = inputs[0]
cdata = get_idx_list(inputs, self.idx_list)
if len(cdata) == 1:
cdata = cdata[0]
cdata = unflatten_index_variables(index_variables, self.idx_list) out[0] = np.asarray(x.__getitem__(cdata))
out[0] = np.asarray(x.__getitem__(tuple(cdata)))
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
def _is_constant(const, x): def _is_constant(const, x):
...@@ -889,7 +978,8 @@ class Subtensor(BaseSubtensor, COp): ...@@ -889,7 +978,8 @@ class Subtensor(BaseSubtensor, COp):
def grad(self, inputs, grads): def grad(self, inputs, grads):
(gz,) = grads (gz,) = grads
x, *index_variables = inputs x = inputs[0]
rest = inputs[1:]
if x.dtype in discrete_dtypes: if x.dtype in discrete_dtypes:
first = x.zeros_like(dtype=config.floatX) first = x.zeros_like(dtype=config.floatX)
else: else:
...@@ -898,28 +988,43 @@ class Subtensor(BaseSubtensor, COp): ...@@ -898,28 +988,43 @@ class Subtensor(BaseSubtensor, COp):
# We have an optimization that will convert this to a # We have an optimization that will convert this to a
# set subtensor here at: # set subtensor here at:
# pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor() # pytensor/tensor/opt.py:local_incsubtensor_of_zeros_to_setsubtensor()
first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *index_variables) first = IncSubtensor(self.idx_list)(x.zeros_like(), gz, *rest)
return [first, *(disconnected_type() for _ in range(len(index_variables)))] return [first, *(disconnected_type() for _ in range(len(rest)))]
def connection_pattern(self, node): def connection_pattern(self, node):
_x, *index_variables = node.inputs rval = [[True], *([False] for _ in node.inputs[1:])]
rval = [[True], *([False] for _ in index_variables)]
return rval return rval
def __hash__(self):
msg = []
for entry in self.idx_list:
if isinstance(entry, slice):
msg += [(entry.start, entry.stop, entry.step)]
else:
msg += [entry]
idx_list = tuple(msg)
# backport
# idx_list = tuple((entry.start, entry.stop, entry.step)
# if isinstance(entry, slice)
# else entry
# for entry in self.idx_list)
return hash(idx_list)
@staticmethod @staticmethod
def str_from_slice(entry): def str_from_slice(entry):
if entry.step is not None: if entry.step:
return ":".join( return ":".join(
( (
"start" if entry.start is not None else "", "start" if entry.start else "",
"stop" if entry.stop is not None else "", "stop" if entry.stop else "",
"step", "step",
) )
) )
if entry.stop is not None: if entry.stop:
return f"{'start' if entry.start is not None else ''}:stop" return f"{'start' if entry.start else ''}:stop"
if entry.start is not None: if entry.start:
return "start:" return "start:"
return ":" return ":"
...@@ -1002,7 +1107,12 @@ class Subtensor(BaseSubtensor, COp): ...@@ -1002,7 +1107,12 @@ class Subtensor(BaseSubtensor, COp):
return pos[1] return pos[1]
def init_entry(entry, depth=0): def init_entry(entry, depth=0):
if isinstance(entry, int): if isinstance(entry, np.integer | int):
init_cmds.append(f"subtensor_spec[{spec_pos()}] = {entry};")
inc_spec_pos(1)
if depth == 0:
is_slice.append(0)
elif isinstance(entry, Type):
init_cmds.append( init_cmds.append(
f"subtensor_spec[{spec_pos()}] = {inputs[input_pos()]};" f"subtensor_spec[{spec_pos()}] = {inputs[input_pos()]};"
) )
...@@ -1265,58 +1375,7 @@ class Subtensor(BaseSubtensor, COp): ...@@ -1265,58 +1375,7 @@ class Subtensor(BaseSubtensor, COp):
# (they should be defaulted to zeros_like by the global R_op) # (they should be defaulted to zeros_like by the global R_op)
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
_x, *index_variables = inputs return self(eval_points[0], *inputs[1:], return_list=True)
return self(eval_points[0], *index_variables, return_list=True)
def basic_subtensor(x, *index_variables):
idx_list, flat_index_vars = flatten_index_variables(index_variables)
return Subtensor(idx_list)(x, *flat_index_vars)
@_get_vector_length.register(Subtensor) # type: ignore
def _get_vector_length_Subtensor(op, var):
# If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
try:
indices = get_idx_list(var.owner.inputs, var.owner.op.idx_list)
start = (
None
if indices[0].start is None
else get_scalar_constant_value(indices[0].start)
)
stop = (
None
if indices[0].stop is None
else get_scalar_constant_value(indices[0].stop)
)
step = (
None
if indices[0].step is None
else get_scalar_constant_value(indices[0].step)
)
if start == stop:
return 0
arg_len = get_vector_length(var.owner.inputs[0])
return len(range(*slice(start, stop, step).indices(arg_len)))
except (ValueError, NotScalarConstantError):
raise ValueError(f"Length of {var} cannot be determined")
@_vectorize_node.register(Subtensor)
def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
"""Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
# TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs):
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
old_x, *_ = node.inputs
batch_ndims = batch_x.type.ndim - old_x.type.ndim
new_idx_list = (slice(None),) * batch_ndims + op.idx_list
return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs)
class SubtensorPrinter(Printer): class SubtensorPrinter(Printer):
...@@ -1328,28 +1387,25 @@ class SubtensorPrinter(Printer): ...@@ -1328,28 +1387,25 @@ class SubtensorPrinter(Printer):
input = inputs.pop(0) input = inputs.pop(0)
sidxs = [] sidxs = []
getattr(pstate, "precedence", None) getattr(pstate, "precedence", None)
def process_slice_component(comp):
"""Process a slice component, returning string representation."""
if comp is None:
return ""
elif isinstance(comp, int):
with set_precedence(pstate):
return pstate.pprinter.process(inputs.pop(0))
else:
return str(comp)
for entry in idxs: for entry in idxs:
if isinstance(entry, int): if isinstance(entry, ps.ScalarType):
with set_precedence(pstate): with set_precedence(pstate):
sidxs.append(pstate.pprinter.process(inputs.pop(0))) sidxs.append(pstate.pprinter.process(inputs.pop()))
elif isinstance(entry, slice): elif isinstance(entry, slice):
msg1 = process_slice_component(entry.start) if entry.start is None or entry.start == 0:
msg2 = process_slice_component(entry.stop) msg1 = ""
else:
msg1 = entry.start
if entry.stop is None or entry.stop == sys.maxsize:
msg2 = ""
else:
msg2 = entry.stop
if entry.step is None: if entry.step is None:
msg3 = "" msg3 = ""
else: else:
msg3 = f":{process_slice_component(entry.step)}" msg3 = f":{entry.step}"
sidxs.append(f"{msg1}:{msg2}{msg3}") sidxs.append(f"{msg1}:{msg2}{msg3}")
...@@ -1362,97 +1418,336 @@ class SubtensorPrinter(Printer): ...@@ -1362,97 +1418,336 @@ class SubtensorPrinter(Printer):
pprint.assign(Subtensor, SubtensorPrinter()) pprint.assign(Subtensor, SubtensorPrinter())
class IncSubtensor(BaseSubtensor, COp): @_vectorize_node.register(Subtensor)
""" def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs):
Increment a subtensor. """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices."""
This is like numpy's
x[i,j,k] += y # TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor
if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs):
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
It is used internally to implement the gradient on SubTensor. old_x, *_ = node.inputs
batch_ndims = batch_x.type.ndim - old_x.type.ndim
new_idx_list = (slice(None),) * batch_ndims + op.idx_list
return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs)
Parameters
----------
set_instead_of_inc
If True set the subtensor to the value instead of incrementing it by
that value.
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
""" """
Return x with the given subtensor overwritten by y.
check_input = False Parameters
__props__ = ( ----------
"idx_list", x
"inplace", Symbolic variable for the lvalue of = operation.
"set_instead_of_inc", y
"destroyhandler_tolerate_aliased", Symbolic variable for the rvalue of = operation.
) tolerate_inplace_aliasing
__hash__ = BaseSubtensor.__hash__ See inc_subtensor for documentation.
def __init__(
self,
idx_list,
inplace=False,
set_instead_of_inc=False,
destroyhandler_tolerate_aliased=None,
):
if destroyhandler_tolerate_aliased is None:
destroyhandler_tolerate_aliased = ()
super().__init__(idx_list)
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
self.destroyhandler_tolerate_aliased = tuple(destroyhandler_tolerate_aliased)
self.set_instead_of_inc = set_instead_of_inc
def __str__(self): Examples
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor" --------
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}" To replicate the numpy expression ``r[10:] = 5``, type
def make_node(self, x, y, *inputs): .. code-block:: python
"""
Parameters
----------
x
The tensor to increment.
y
The value to increment by.
inputs
The indeces/slices list to increment in combination with idx_list.
E.g. self._idx_list = (0, slice(1, None, None), 2, slice(3, None, 4)) from pytensor.tensor import set_subtensor, vector
tell to use inputs[0] as the first dim.
"""
x, y = map(as_tensor_variable, [x, y])
if y.ndim > x.ndim:
raise ValueError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs = tuple(map(as_scalar_index_variable, inputs))
idx_list = list(self.idx_list) r = vector("r")
if len(idx_list) > x.type.ndim: new_r = set_subtensor(r[10:], 5)
raise IndexError("too many indices for array")
if len(inputs) != self.n_index_vars: Consider using :meth:`pytensor.tensor.variable.TensorVariable.set` instead.
raise ValueError(
"Not enough inputs to fill in the Subtensor template.", inputs, idx_list
)
return Apply(self, (x, y, *inputs), [x.type()]) """
return inc_subtensor(
x,
y,
inplace,
set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
)
def decl_view(self):
return "PyArrayObject * zview = NULL;"
def perform(self, node, inputs, output_storage): def inc_subtensor(
x,
y,
inplace=False,
set_instead_of_inc=False,
tolerate_inplace_aliasing=False,
ignore_duplicates=False,
):
"""Update the value of an indexed array by a given amount.
This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
depending on the value of `ignore_duplicates`.
Parameters
----------
x
The symbolic result of a Subtensor operation.
y
The amount by which to increment the array.
inplace
Don't use. PyTensor will do in-place operations itself, when possible.
set_instead_of_inc
If True, do a set_subtensor instead.
tolerate_inplace_aliasing:
Allow `x` and `y` to be views of a single underlying array even while
working in-place. For correct results, `x` and `y` must not be overlapping
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``.
Examples
--------
To replicate the expression ``r[10:] += 5``:
.. code-block:: python
from pytensor.tensor import ivector, inc_subtensor
r = ivector("r")
new_r = inc_subtensor(r[10:], 5)
To replicate the expression ``r[[0, 1, 0]] += 5``:
.. code-block:: python
r = ivector("r")
new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.inc` instead.
"""
# First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable.
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if y.ndim > x.ndim:
raise TypeError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
dim_offset = x.ndim - y.ndim
for dim in range(y.ndim):
if x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]:
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a SpecifyShape Op to make sure it is the case.
y = specify_broadcastable(y, dim)
if x.owner is None:
raise TypeError("x must be the result of a subtensor operation")
# retrieve idx_list from x.owner
if isinstance(x.owner.op, Subtensor):
if tolerate_inplace_aliasing:
destroyhandler_tolerate_aliased = [[0, 1]]
else:
destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(
x.owner.op.idx_list,
inplace,
set_instead_of_inc,
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased,
)
real_x = x.owner.inputs[0]
real_idxargs = x.owner.inputs[1:]
return the_op(real_x, y, *real_idxargs)
elif isinstance(x.owner.op, AdvancedSubtensor1):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1]
if ignore_duplicates:
the_op = AdvancedIncSubtensor(
inplace, set_instead_of_inc=set_instead_of_inc, ignore_duplicates=True
)
else:
the_op = AdvancedIncSubtensor1(
inplace, set_instead_of_inc=set_instead_of_inc
)
return the_op(real_x, y, ilist)
elif isinstance(x.owner.op, AdvancedSubtensor):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1:]
the_op = AdvancedIncSubtensor(
inplace,
set_instead_of_inc=set_instead_of_inc,
ignore_duplicates=ignore_duplicates,
)
return the_op(real_x, y, *ilist)
elif isinstance(x.owner.op, DimShuffle):
inner_x = x.owner.inputs[0]
# In the dimshuffle case, there are in fact two dimshuffles:
# one to make the indexed dimension the last one,
# and one to put it back where it was. So, in the case where we have
# inc_subtensor(x[:,i], y), the graph is actually
# inc_subtensor((x.T)[i].T, y).
# We could get all the way to x, and then get rid of the dimshuffles
# completely, but the problem is that advanced_inc_subtensor1 can only
# work on the first (outer-most, left-most) dimension of x,
# just like advanced_subtensor1.
# So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
# transpose y if it is not a scalar or a vector), but then we need to
# return something that has the same shape as x, not as x.T (inner_x).
# So re-apply the outer dimshuffle on the new inc_subtensor,
# and return advanced_inc_subtensor1(x.T, i, y.T).T.
# Get the dimshuffle pattern to apply to y.
x_order = x.owner.op.new_order
y_order = ["x"] * x.ndim
for i, v in enumerate(x_order):
if v != "x" and (v - dim_offset) >= 0:
y_order[v - dim_offset] = i
inner_incsubtensor = inc_subtensor(
inner_x,
y.dimshuffle(y_order),
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
# The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here,
# instead of reusing x.owner.op().
return inner_incsubtensor.dimshuffle(x.owner.op.new_order)
elif isinstance(x.owner.op, Reshape):
# This case happens when the indices are not arranged as a vector, but
# as a higher-dimensional array. This is handled by the subtensor
# by flattening this list, taking the subtensor, then reshaping the
# result.
inner_x = x.owner.inputs[0]
# Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want.
# We also explicitly duplicate y to its broadcasted shape
# before we partially flatten it to inner_x dimension. This is
# not strictly needed in all cases, but it is easier this way.
if y.ndim > 0:
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y = alloc(y, *[x.shape[i] for i in range(x.ndim)])
flattened_y = expanded_y.reshape(inner_x.shape)
else:
flattened_y = y
inner_incsubtensor = inc_subtensor(
inner_x,
flattened_y,
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
return inner_incsubtensor
else:
raise TypeError("x must be the result of a subtensor operation")
class IncSubtensor(COp):
"""
Increment a subtensor.
This is like numpy's
x[i,j,k] += y
It is used internally to implement the gradient on SubTensor.
Parameters
----------
set_instead_of_inc
If True set the subtensor to the value instead of incrementing it by
that value.
"""
check_input = False
__props__ = ("idx_list", "inplace", "set_instead_of_inc")
def __init__(
self,
idx_list,
inplace=False,
set_instead_of_inc=False,
destroyhandler_tolerate_aliased=None,
):
if destroyhandler_tolerate_aliased is None:
destroyhandler_tolerate_aliased = []
self.idx_list = list(map(index_vars_to_types, idx_list))
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
self.destroyhandler_tolerate_aliased = list(destroyhandler_tolerate_aliased)
self.set_instead_of_inc = set_instead_of_inc
def __hash__(self):
idx_list = tuple(
(entry.start, entry.stop, entry.step) if isinstance(entry, slice) else entry
for entry in self.idx_list
)
return hash((type(self), idx_list, self.inplace, self.set_instead_of_inc))
def __str__(self):
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
def make_node(self, x, y, *inputs):
"""
Parameters
----------
x
The tensor to increment.
y
The value to increment by.
inputs: TODO WRITEME
"""
x, y = map(as_tensor_variable, [x, y])
if y.ndim > x.ndim:
raise ValueError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
inputs = tuple(map(as_nontensor_scalar, inputs))
idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
input_types = get_slice_elements(
idx_list, lambda entry: isinstance(entry, Type)
)
if len(inputs) != len(input_types):
raise IndexError(
"Not enough inputs to fill in the Subtensor template.", inputs, idx_list
)
for input, expected_type in zip(inputs, input_types, strict=True):
if not expected_type.is_super(input.type):
raise TypeError(
f"Wrong type for Subtensor template. Expected {input.type}, got {expected_type}."
)
return Apply(self, (x, y, *inputs), [x.type()])
def decl_view(self):
return "PyArrayObject * zview = NULL;"
def perform(self, node, inputs, output_storage):
x, y, *flat_indices = inputs x, y, *flat_indices = inputs
flat_indices_iterator = iter(flat_indices) flat_indices_iterator = iter(flat_indices)
indices = tuple( indices = tuple(
( (
next(flat_indices_iterator) next(flat_indices_iterator)
if isinstance(entry, int) if isinstance(entry, Type)
else slice( else slice(
None if entry.start is None else next(flat_indices_iterator), None if entry.start is None else next(flat_indices_iterator),
None if entry.stop is None else next(flat_indices_iterator), None if entry.stop is None else next(flat_indices_iterator),
...@@ -1697,18 +1992,17 @@ class IncSubtensor(BaseSubtensor, COp): ...@@ -1697,18 +1992,17 @@ class IncSubtensor(BaseSubtensor, COp):
return [None] return [None]
# Again we ignore eval points for indices because incsubtensor is # Again we ignore eval points for indices because incsubtensor is
# not differentiable wrt to those # not differentiable wrt to those
_x, _y, *index_variables = inputs return self(eval_points[0], eval_points[1], *inputs[2:], return_list=True)
return self(eval_points[0], eval_points[1], *index_variables, return_list=True)
def connection_pattern(self, node): def connection_pattern(self, node):
_x, _y, *index_variables = node.inputs rval = [[True], [True], *([False] for _ in node.inputs[2:])]
rval = [[True], [True], *([False] for _ in index_variables)]
return rval return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
(g_output,) = grads (g_output,) = grads
x, y, *index_variables = inputs x, y = inputs[:2]
idx_list = inputs[2:]
if x.dtype in discrete_dtypes: if x.dtype in discrete_dtypes:
# The output dtype is the same as x # The output dtype is the same as x
...@@ -1722,25 +2016,25 @@ class IncSubtensor(BaseSubtensor, COp): ...@@ -1722,25 +2016,25 @@ class IncSubtensor(BaseSubtensor, COp):
else: else:
if self.set_instead_of_inc: if self.set_instead_of_inc:
gx = set_subtensor( gx = set_subtensor(
Subtensor(idx_list=self.idx_list)(g_output, *index_variables), Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
pytensor.tensor.zeros_like(y), pytensor.tensor.zeros_like(y),
) )
else: else:
gx = g_output gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *index_variables) gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
gy = _sum_grad_over_bcasted_dims(y, gy) gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, *(disconnected_type() for _ in range(len(index_variables)))] return [gx, gy, *(disconnected_type() for _ in range(len(idx_list)))]
class IncSubtensorPrinter(SubtensorPrinter): class IncSubtensorPrinter(SubtensorPrinter):
def process(self, r, pstate): def process(self, r, pstate):
x, y, *index_variables = r.owner.inputs x, _y, *idx_args = r.owner.inputs
res = self._process(r.owner.op.idx_list, [x, *index_variables], pstate) res = self._process(r.owner.op.idx_list, [x, *idx_args], pstate)
with set_precedence(pstate, 1000): with set_precedence(pstate, 1000):
y_str = pstate.pprinter.process(y, pstate) y_str = pstate.pprinter.process(r.owner.inputs[1], pstate)
if r.owner.op.set_instead_of_inc: if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})" res = f"set_subtensor({res}, {y_str})"
...@@ -1801,13 +2095,9 @@ class AdvancedSubtensor1(COp): ...@@ -1801,13 +2095,9 @@ class AdvancedSubtensor1(COp):
# sparse_grad doesn't go in here since it only affects the output # sparse_grad doesn't go in here since it only affects the output
# of the grad() method. # of the grad() method.
__props__ = () __props__ = ()
idx_list = (0,)
_f16_ok = True _f16_ok = True
check_input = False check_input = False
def __hash__(self):
return hash(type(self))
def __init__(self, sparse_grad=False): def __init__(self, sparse_grad=False):
self.sparse_grad = sparse_grad self.sparse_grad = sparse_grad
...@@ -1831,8 +2121,7 @@ class AdvancedSubtensor1(COp): ...@@ -1831,8 +2121,7 @@ class AdvancedSubtensor1(COp):
output_storage[0][0] = x.take(i, axis=0, out=None) output_storage[0][0] = x.take(i, axis=0, out=None)
def connection_pattern(self, node): def connection_pattern(self, node):
_x, *index_variables = node.inputs rval = [[True], *([False] for _ in node.inputs[1:])]
rval = [[True], *([False] for _ in index_variables)]
return rval return rval
...@@ -1862,8 +2151,7 @@ class AdvancedSubtensor1(COp): ...@@ -1862,8 +2151,7 @@ class AdvancedSubtensor1(COp):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
_x, *index_variables = inputs return self.make_node(eval_points[0], *inputs[1:]).outputs
return self.make_node(eval_points[0], *index_variables).outputs
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
x, ilist = ishapes x, ilist = ishapes
...@@ -1957,17 +2245,13 @@ class AdvancedSubtensor1(COp): ...@@ -1957,17 +2245,13 @@ class AdvancedSubtensor1(COp):
advanced_subtensor1 = AdvancedSubtensor1() advanced_subtensor1 = AdvancedSubtensor1()
class AdvancedIncSubtensor1(BaseSubtensor, COp): class AdvancedIncSubtensor1(COp):
""" """
Increments a subtensor using advanced slicing (list of index). Increments a subtensor using advanced slicing (list of index).
""" """
__props__ = ( __props__ = ("inplace", "set_instead_of_inc")
"inplace",
"set_instead_of_inc",
)
idx_list = (0,)
check_input = False check_input = False
params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool) params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)
...@@ -1983,20 +2267,8 @@ class AdvancedIncSubtensor1(BaseSubtensor, COp): ...@@ -1983,20 +2267,8 @@ class AdvancedIncSubtensor1(BaseSubtensor, COp):
if inplace: if inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def __hash__(self):
return hash(
(
type(self),
self.inplace,
self.set_instead_of_inc,
)
)
def clone_inplace(self): def clone_inplace(self):
return self.__class__( return self.__class__(inplace=True, set_instead_of_inc=self.set_instead_of_inc)
inplace=True,
set_instead_of_inc=self.set_instead_of_inc,
)
def __str__(self): def __str__(self):
if self.inplace: if self.inplace:
...@@ -2222,8 +2494,7 @@ class AdvancedIncSubtensor1(BaseSubtensor, COp): ...@@ -2222,8 +2494,7 @@ class AdvancedIncSubtensor1(BaseSubtensor, COp):
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points[:2]: if None in eval_points[:2]:
return [None] return [None]
_x, _y, *index_variables = inputs return self.make_node(eval_points[0], eval_points[1], *inputs[2:]).outputs
return self.make_node(eval_points[0], eval_points[1], *index_variables).outputs
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [[True], [True], [False]] rval = [[True], [True], [False]]
...@@ -2256,8 +2527,15 @@ advanced_inc_subtensor1 = AdvancedIncSubtensor1() ...@@ -2256,8 +2527,15 @@ advanced_inc_subtensor1 = AdvancedIncSubtensor1()
advanced_set_subtensor1 = AdvancedIncSubtensor1(set_instead_of_inc=True) advanced_set_subtensor1 = AdvancedIncSubtensor1(set_instead_of_inc=True)
def as_tensor_index_variable(idx): def as_index_variable(idx):
"""Convert index to Variable form for advanced indexing.""" if idx is None:
return NoneConst.clone()
if isinstance(idx, slice):
return make_slice(idx)
if isinstance(idx, Variable) and isinstance(idx.type, SliceType):
return idx
if isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT):
return idx
idx = as_tensor_variable(idx) idx = as_tensor_variable(idx)
if idx.type.dtype not in discrete_dtypes: if idx.type.dtype not in discrete_dtypes:
raise TypeError("index must be integers or a boolean mask") raise TypeError("index must be integers or a boolean mask")
...@@ -2269,45 +2547,53 @@ def as_tensor_index_variable(idx): ...@@ -2269,45 +2547,53 @@ def as_tensor_index_variable(idx):
return idx return idx
class AdvancedSubtensor(BaseSubtensor, COp): def check_advanced_indexing_dimensions(input, idx_list):
"""Implements NumPy's advanced indexing.""" """
This function checks if the index list in idx_list is correct.
__props__ = ("idx_list",) If there are any boolean masks, we check if the mask has the
__hash__ = BaseSubtensor.__hash__ same shape as the input. This is enforced in NumPy 0.13.0 and
newer, but not by earlier versions. If the size is not the same,
def c_code_cache_version(self): this method raises an IndexError.
hv = Subtensor.helper_c_code_cache_version() """
if hv: dim_seen = 0
return (3, hv) for index in idx_list:
if index is np.newaxis:
# skip, does not count as an input dimension
pass
elif isinstance(index, np.ndarray) and index.dtype == "bool":
for i in range(index.ndim):
if index.shape[i] != input.shape[dim_seen + i]:
raise IndexError(
"boolean index did not match indexed array "
f"along dimension {int(dim_seen + i)}; dimension is "
f"{int(input.shape[dim_seen + i])} but "
f"corresponding boolean dimension is {index.shape[i]}"
)
dim_seen += index.ndim
else: else:
return () dim_seen += 1
def make_node(self, x, *index_variables):
if len(index_variables) != self.n_index_vars:
raise ValueError(
f"Expected {self.n_index_vars} inputs, got {len(index_variables)}"
)
x = as_tensor_variable(x) class AdvancedSubtensor(Op):
index_variables = tuple(as_tensor_index_variable(a) for a in index_variables) """Implements NumPy's advanced indexing."""
idx_list = self.idx_list __props__ = ()
if len(idx_list) > x.type.ndim:
raise IndexError("too many indices for array")
reconstructed_indices = unflatten_index_variables(index_variables, idx_list) def make_node(self, x, *indices):
x = as_tensor_variable(x)
indices = tuple(map(as_index_variable, indices))
explicit_indices = [] explicit_indices = []
for idx in reconstructed_indices: new_axes = []
if isinstance(idx, slice): for idx in indices:
explicit_indices.append(idx) if isinstance(idx.type, TensorType) and idx.dtype == "bool":
elif hasattr(idx, "dtype") and idx.dtype == "bool":
if idx.type.ndim == 0: if idx.type.ndim == 0:
raise NotImplementedError( raise NotImplementedError(
"Indexing with scalar booleans not supported" "Indexing with scalar booleans not supported"
) )
axis = len(explicit_indices) # Check static shape aligned
axis = len(explicit_indices) - len(new_axes)
indexed_shape = x.type.shape[axis : axis + idx.type.ndim] indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
for j, (indexed_length, indexer_length) in enumerate( for j, (indexed_length, indexer_length) in enumerate(
zip(indexed_shape, idx.type.shape) zip(indexed_shape, idx.type.shape)
...@@ -2325,27 +2611,48 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2325,27 +2611,48 @@ class AdvancedSubtensor(BaseSubtensor, COp):
if isinstance(idx, Constant): if isinstance(idx, Constant):
nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()] nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
else: else:
# Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
# and seeing that other integer indices cannot possible match it
nonzero_indices = idx.nonzero() nonzero_indices = idx.nonzero()
explicit_indices.extend(nonzero_indices) explicit_indices.extend(nonzero_indices)
else: else:
if isinstance(idx.type, NoneTypeT):
new_axes.append(len(explicit_indices))
explicit_indices.append(idx) explicit_indices.append(idx)
if len(explicit_indices) > x.type.ndim: if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
raise IndexError( raise IndexError(
f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices)} were indexed" f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
) )
# Perform basic and advanced indexing shape inference separately (no newaxis) # Perform basic and advanced indexing shape inference separately
basic_group_shape = [] basic_group_shape = []
advanced_indices = [] advanced_indices = []
adv_group_axis = None adv_group_axis = None
last_adv_group_axis = None last_adv_group_axis = None
if new_axes:
expanded_x_shape_list = list(x.type.shape)
for new_axis in new_axes:
expanded_x_shape_list.insert(new_axis, 1)
expanded_x_shape = tuple(expanded_x_shape_list)
else:
expanded_x_shape = x.type.shape
for i, (idx, dim_length) in enumerate( for i, (idx, dim_length) in enumerate(
zip_longest(explicit_indices, x.type.shape, fillvalue=slice(None)) zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
): ):
if isinstance(idx, slice): if isinstance(idx.type, NoneTypeT):
basic_group_shape.append(slice_static_length(idx, dim_length)) basic_group_shape.append(1) # New-axis
else: # TensorType (advanced index) elif isinstance(idx.type, SliceType):
if isinstance(idx, Constant):
basic_group_shape.append(slice_static_length(idx.data, dim_length))
elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
basic_group_shape.append(
slice_static_length(slice(*idx.owner.inputs), dim_length)
)
else:
# Symbolic root slice (owner is None), or slice operation we don't understand
basic_group_shape.append(None)
else: # TensorType
# Keep track of advanced group axis # Keep track of advanced group axis
if adv_group_axis is None: if adv_group_axis is None:
# First time we see an advanced index # First time we see an advanced index
...@@ -2380,15 +2687,14 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2380,15 +2687,14 @@ class AdvancedSubtensor(BaseSubtensor, COp):
return Apply( return Apply(
self, self,
[x, *index_variables], [x, *indices],
[tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))], [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
) )
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
return [None] return [None]
_x, *index_variables = inputs return self.make_node(eval_points[0], *inputs[1:]).outputs
return self.make_node(eval_points[0], *index_variables).outputs
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
def is_bool_index(idx): def is_bool_index(idx):
...@@ -2397,32 +2703,30 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2397,32 +2703,30 @@ class AdvancedSubtensor(BaseSubtensor, COp):
or getattr(idx, "dtype", None) == "bool" or getattr(idx, "dtype", None) == "bool"
) )
_x, *index_variables = node.inputs indices = node.inputs[1:]
full_indices = unflatten_index_variables(index_variables, self.idx_list)
index_shapes = [] index_shapes = []
for idx in full_indices: for idx, ishape in zip(indices, ishapes[1:], strict=True):
if isinstance(idx, slice): # Mixed bool indexes are converted to nonzero entries
shape0_op = Shape_i(0)
if is_bool_index(idx):
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
# The `ishapes` entries for `SliceType`s will be None, and
# we need to give `indexed_result_shape` the actual slices.
elif isinstance(getattr(idx, "type", None), SliceType):
index_shapes.append(idx) index_shapes.append(idx)
else: else:
shape0_op = Shape_i(0) index_shapes.append(ishape)
if is_bool_index(idx):
index_shapes.extend((shape0_op(nz_dim),) for nz_dim in nonzero(idx))
else:
input_shape_idx = (
index_variables.index(idx) + 1
) # +1 because ishapes[0] is x
index_shapes.append(ishapes[input_shape_idx])
res_shape = list( res_shape = list(
indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True) indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
) )
for i, res_dim_length in enumerate(res_shape): for i, res_dim_length in enumerate(res_shape):
if res_dim_length is None: if res_dim_length is None:
# This can happen when we have a Slice provided by the user (not a constant nor the result of MakeSlice)
# We must compute the Op to find its shape # We must compute the Op to find its shape
res_shape[i] = Shape_i(i)(node.out) res_shape[i] = Shape_i(i)(node.out)
adv_indices = [idx for idx in full_indices if not isinstance(idx, slice)] adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
bool_indices = [idx for idx in adv_indices if is_bool_index(idx)] bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
# Special logic when the only advanced index group is of bool type. # Special logic when the only advanced index group is of bool type.
...@@ -2433,7 +2737,7 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2433,7 +2737,7 @@ class AdvancedSubtensor(BaseSubtensor, COp):
# Because there are no more advanced index groups, there is exactly # Because there are no more advanced index groups, there is exactly
# one output dim per index variable up to the bool group. # one output dim per index variable up to the bool group.
# Note: Scalar integer indexing counts as advanced indexing. # Note: Scalar integer indexing counts as advanced indexing.
start_dim = full_indices.index(bool_index) start_dim = indices.index(bool_index)
res_shape[start_dim] = bool_index.sum() res_shape[start_dim] = bool_index.sum()
assert node.outputs[0].ndim == len(res_shape) assert node.outputs[0].ndim == len(res_shape)
...@@ -2441,31 +2745,25 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2441,31 +2745,25 @@ class AdvancedSubtensor(BaseSubtensor, COp):
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
check_advanced_indexing_dimensions(inputs[0], inputs[1:])
x, *index_variables = inputs rval = inputs[0].__getitem__(tuple(inputs[1:]))
full_indices = unflatten_index_variables(index_variables, self.idx_list)
rval = x.__getitem__(tuple(full_indices))
# When there are no arrays, we are not actually doing advanced # When there are no arrays, we are not actually doing advanced
# indexing, so __getitem__ will not return a copy. # indexing, so __getitem__ will not return a copy.
# Since no view_map is set, we need to copy the returned value # Since no view_map is set, we need to copy the returned value
if not any( if not any(
isinstance(idx, np.ndarray) and idx.ndim > 0 for idx in full_indices isinstance(v.type, TensorType) and v.ndim > 0 for v in node.inputs[1:]
): ):
rval = rval.copy() rval = rval.copy()
out[0] = rval out[0] = rval
def connection_pattern(self, node): def connection_pattern(self, node):
_x, *index_variables = node.inputs rval = [[True], *([False] for _ in node.inputs[1:])]
rval = [[True], *([False] for _ in index_variables)]
return rval return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
(gz,) = grads (gz,) = grads
x, *index_variables = inputs x = inputs[0]
if x.dtype in discrete_dtypes: if x.dtype in discrete_dtypes:
# The output dtype is the same as x # The output dtype is the same as x
gx = x.zeros_like(dtype=config.floatX) gx = x.zeros_like(dtype=config.floatX)
...@@ -2473,10 +2771,10 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2473,10 +2771,10 @@ class AdvancedSubtensor(BaseSubtensor, COp):
raise NotImplementedError("No support for complex grad yet") raise NotImplementedError("No support for complex grad yet")
else: else:
gx = x.zeros_like() gx = x.zeros_like()
rest = inputs[1:]
return [ return [
AdvancedIncSubtensor(self.idx_list)(gx, gz, *index_variables), advanced_inc_subtensor(gx, gz, *rest),
*(disconnected_type() for _ in range(len(index_variables))), *(disconnected_type() for _ in range(len(rest))),
] ]
@staticmethod @staticmethod
...@@ -2493,7 +2791,7 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2493,7 +2791,7 @@ class AdvancedSubtensor(BaseSubtensor, COp):
This function checks if the advanced indexing is non-consecutive, This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position. output array, regardless of their original position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
...@@ -2508,21 +2806,11 @@ class AdvancedSubtensor(BaseSubtensor, COp): ...@@ -2508,21 +2806,11 @@ class AdvancedSubtensor(BaseSubtensor, COp):
bool bool
True if the advanced indexing is non-consecutive, False otherwise. True if the advanced indexing is non-consecutive, False otherwise.
""" """
indices = indices_from_subtensor(node.inputs[1:], node.op.idx_list) _, *idxs = node.inputs
return _non_consecutive_adv_indexing(indices) return _non_consecutive_adv_indexing(idxs)
class AdvancedSubtensorPrinter(SubtensorPrinter): advanced_subtensor = AdvancedSubtensor()
def process(self, r, pstate):
return self._process(r.owner.op.idx_list, r.owner.inputs, pstate)
pprint.assign(AdvancedSubtensor, AdvancedSubtensorPrinter())
def advanced_subtensor(x, *index_variables):
idx_list, flat_index_vars = flatten_index_variables(index_variables)
return AdvancedSubtensor(idx_list)(x, *flat_index_vars)
@_vectorize_node.register(AdvancedSubtensor) @_vectorize_node.register(AdvancedSubtensor)
...@@ -2542,33 +2830,30 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs): ...@@ -2542,33 +2830,30 @@ def vectorize_advanced_subtensor(op: AdvancedSubtensor, node, *batch_inputs):
# which would put the indexed results to the left of the batch dimensions! # which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex # TODO: Not all cases must be handled by Blockwise, but the logic is complex
return vectorize_node_fallback(op, node, batch_x, *batch_idxs) # Blockwise doesn't accept None or Slices types so we raise informative error here
# TODO: Implement these internally, so Blockwise is always a safe fallback
if any(not isinstance(idx, TensorVariable) for idx in idxs):
raise NotImplementedError(
"Vectorized AdvancedSubtensor with batched indexes or non-consecutive advanced indexing "
"and slices or newaxis is currently not supported."
)
else:
return vectorize_node_fallback(op, node, batch_x, *batch_idxs)
# Otherwise we just need to add None slices for every new batch dim # Otherwise we just need to add None slices for every new batch dim
x_batch_ndim = batch_x.type.ndim - x.type.ndim x_batch_ndim = batch_x.type.ndim - x.type.ndim
new_idx_list = (slice(None),) * x_batch_ndim + op.idx_list empty_slices = (slice(None),) * x_batch_ndim
return type(op)(new_idx_list).make_node(batch_x, *batch_idxs) return op.make_node(batch_x, *empty_slices, *batch_idxs)
class AdvancedIncSubtensor(BaseSubtensor, Op): class AdvancedIncSubtensor(Op):
"""Increments a subtensor using advanced indexing.""" """Increments a subtensor using advanced indexing."""
__props__ = ( __props__ = ("inplace", "set_instead_of_inc", "ignore_duplicates")
"idx_list",
"inplace",
"set_instead_of_inc",
"ignore_duplicates",
)
__hash__ = BaseSubtensor.__hash__
def __init__( def __init__(
self, self, inplace=False, set_instead_of_inc=False, ignore_duplicates=False
idx_list,
inplace=False,
set_instead_of_inc=False,
ignore_duplicates=False,
): ):
super().__init__(idx_list)
self.set_instead_of_inc = set_instead_of_inc self.set_instead_of_inc = set_instead_of_inc
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
...@@ -2582,27 +2867,25 @@ class AdvancedIncSubtensor(BaseSubtensor, Op): ...@@ -2582,27 +2867,25 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
else "AdvancedIncSubtensor" else "AdvancedIncSubtensor"
) )
def make_node(self, x, y, *index_variables): def make_node(self, x, y, *inputs):
if len(index_variables) != self.n_index_vars:
raise ValueError(
f"Expected {self.n_index_vars} tensor inputs but got {len(index_variables)}"
)
index_variables = tuple(
as_tensor_index_variable(idx) for idx in index_variables
)
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
new_inputs = []
for inp in inputs:
if isinstance(inp, list | tuple):
inp = as_tensor_variable(inp)
new_inputs.append(inp)
return Apply( return Apply(
self, self,
[x, y, *index_variables], (x, y, *new_inputs),
[x.type()], [x.type()],
) )
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
x, y, *index_variables = inputs x, y, *indices = inputs
full_indices = unflatten_index_variables(index_variables, self.idx_list) check_advanced_indexing_dimensions(x, indices)
(out,) = out_ (out,) = out_
if not self.inplace: if not self.inplace:
...@@ -2611,29 +2894,28 @@ class AdvancedIncSubtensor(BaseSubtensor, Op): ...@@ -2611,29 +2894,28 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
out[0] = x out[0] = x
if self.set_instead_of_inc: if self.set_instead_of_inc:
out[0][tuple(full_indices)] = y out[0][tuple(indices)] = y
elif self.ignore_duplicates: elif self.ignore_duplicates:
out[0][tuple(full_indices)] += y out[0][tuple(indices)] += y
else: else:
np.add.at(out[0], tuple(full_indices), y) np.add.at(out[0], tuple(indices), y)
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
return [ishapes[0]] return [ishapes[0]]
def connection_pattern(self, node): def connection_pattern(self, node):
_x, _y, *index_variables = node.inputs rval = [[True], [True], *([False] for _ in node.inputs[2:])]
rval = [[True], [True], *([False] for _ in index_variables)]
return rval return rval
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points[:2]: if None in eval_points[:2]:
return [None] return [None]
_x, _y, *index_variables = inputs return self.make_node(eval_points[0], eval_points[1], *inputs[2:]).outputs
return self.make_node(eval_points[0], eval_points[1], *index_variables).outputs
def grad(self, inpt, output_gradients): def grad(self, inpt, output_gradients):
x, y, *index_variables = inpt x, y = inpt[:2]
idxs = inpt[2:]
(outgrad,) = output_gradients (outgrad,) = output_gradients
if x.dtype in discrete_dtypes: if x.dtype in discrete_dtypes:
# The output dtype is the same as x # The output dtype is the same as x
...@@ -2646,22 +2928,21 @@ class AdvancedIncSubtensor(BaseSubtensor, Op): ...@@ -2646,22 +2928,21 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
raise NotImplementedError("No support for complex grad yet") raise NotImplementedError("No support for complex grad yet")
else: else:
if self.set_instead_of_inc: if self.set_instead_of_inc:
gx = ( gx = advanced_set_subtensor(outgrad, y.zeros_like(), *idxs)
type(self)(self.idx_list, set_instead_of_inc=True)
.make_node(outgrad, y.zeros_like(), *index_variables)
.outputs[0]
)
else: else:
gx = outgrad gx = outgrad
gy = ( gy = advanced_subtensor(outgrad, *idxs)
AdvancedSubtensor(self.idx_list)
.make_node(outgrad, *index_variables)
.outputs[0]
)
# Make sure to sum gy over the dimensions of y that have been # Make sure to sum gy over the dimensions of y that have been
# added or broadcasted # added or broadcasted
gy = _sum_grad_over_bcasted_dims(y, gy) gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy, *(disconnected_type() for _ in range(len(index_variables)))] return [gx, gy, *(disconnected_type() for _ in range(len(idxs)))]
@staticmethod
def non_contiguous_adv_indexing(node: Apply) -> bool:
warnings.warn(
"Method was renamed to `non_consecutive_adv_indexing`", FutureWarning
)
return AdvancedIncSubtensor.non_consecutive_adv_indexing(node)
@staticmethod @staticmethod
def non_consecutive_adv_indexing(node: Apply) -> bool: def non_consecutive_adv_indexing(node: Apply) -> bool:
...@@ -2670,7 +2951,7 @@ class AdvancedIncSubtensor(BaseSubtensor, Op): ...@@ -2670,7 +2951,7 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
This function checks if the advanced indexing is non-consecutive, This function checks if the advanced indexing is non-consecutive,
in which case the advanced index dimensions are placed on the left of the in which case the advanced index dimensions are placed on the left of the
output array, regardless of their original position. output array, regardless of their original position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
...@@ -2685,257 +2966,16 @@ class AdvancedIncSubtensor(BaseSubtensor, Op): ...@@ -2685,257 +2966,16 @@ class AdvancedIncSubtensor(BaseSubtensor, Op):
bool bool
True if the advanced indexing is non-consecutive, False otherwise. True if the advanced indexing is non-consecutive, False otherwise.
""" """
indices = indices_from_subtensor(node.inputs[2:], node.op.idx_list) _, _, *idxs = node.inputs
return _non_consecutive_adv_indexing(indices) return _non_consecutive_adv_indexing(idxs)
def advanced_inc_subtensor(x, y, *args, **kwargs):
idx_list, flat_index_vars = flatten_index_variables(args)
return AdvancedIncSubtensor(idx_list, **kwargs)(x, y, *flat_index_vars)
def advanced_set_subtensor(x, y, *args, **kwargs): advanced_inc_subtensor = AdvancedIncSubtensor()
return advanced_inc_subtensor(x, y, *args, set_instead_of_inc=True, **kwargs) advanced_set_subtensor = AdvancedIncSubtensor(set_instead_of_inc=True)
advanced_inc_subtensor_nodup = AdvancedIncSubtensor(ignore_duplicates=True)
advanced_set_subtensor_nodup = AdvancedIncSubtensor(
class AdvancedIncSubtensorPrinter(SubtensorPrinter): set_instead_of_inc=True, ignore_duplicates=True
def process(self, r, pstate): )
x, y, *index_variables = r.owner.inputs
res = self._process(r.owner.op.idx_list, [x, *index_variables], pstate)
with set_precedence(pstate, 1000):
y_str = pstate.pprinter.process(y, pstate)
if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})"
else:
res = f"inc_subtensor({res}, {y_str})"
return res
pprint.assign(AdvancedIncSubtensor, AdvancedIncSubtensorPrinter())
def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False):
"""
Return x with the given subtensor overwritten by y.
Parameters
----------
x
Symbolic variable for the lvalue of = operation.
y
Symbolic variable for the rvalue of = operation.
tolerate_inplace_aliasing
See inc_subtensor for documentation.
Examples
--------
To replicate the numpy expression ``r[10:] = 5``, type
.. code-block:: python
from pytensor.tensor import set_subtensor, vector
r = vector("r")
new_r = set_subtensor(r[10:], 5)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.set` instead.
"""
return inc_subtensor(
x,
y,
inplace,
set_instead_of_inc=True,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
)
def inc_subtensor(
x,
y,
inplace=False,
set_instead_of_inc=False,
tolerate_inplace_aliasing=False,
ignore_duplicates=False,
):
"""Update the value of an indexed array by a given amount.
This is equivalent to ``x[indices] += y`` or ``np.add.at(x, indices, y)``,
depending on the value of `ignore_duplicates`.
Parameters
----------
x
The symbolic result of a Subtensor operation.
y
The amount by which to increment the array.
inplace
Don't use. PyTensor will do in-place operations itself, when possible.
set_instead_of_inc
If True, do a set_subtensor instead.
tolerate_inplace_aliasing:
Allow `x` and `y` to be views of a single underlying array even while
working in-place. For correct results, `x` and `y` must not be overlapping
views; if they overlap, the result of this `Op` will generally be
incorrect. This value has no effect if ``inplace=False``.
ignore_duplicates
This determines whether ``x[indices] += y`` is used or
``np.add.at(x, indices, y)``.
Examples
--------
To replicate the expression ``r[10:] += 5``:
.. code-block:: python
from pytensor.tensor import ivector, inc_subtensor
r = ivector("r")
new_r = inc_subtensor(r[10:], 5)
To replicate the expression ``r[[0, 1, 0]] += 5``:
.. code-block:: python
r = ivector("r")
new_r = inc_subtensor(r[[0, 1, 0]], 5, ignore_duplicates=True)
Consider using :meth:`pytensor.tensor.variable.TensorVariable.inc` instead.
"""
# First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable.
x = as_tensor_variable(x)
y = as_tensor_variable(y)
if y.ndim > x.ndim:
raise TypeError(
f"Trying to increment a {int(x.ndim)}-dimensional "
f"subtensor with a {int(y.ndim)}-dimensional value."
)
dim_offset = x.ndim - y.ndim
for dim in range(y.ndim):
if x.broadcastable[dim + dim_offset] and not y.broadcastable[dim]:
# It is acceptable to try to increment a subtensor with a
# broadcastable dim with a tensor that is not broadcastable
# on that dimension. However, its length must then be 1.
# We insert a SpecifyShape Op to make sure it is the case.
y = specify_broadcastable(y, dim)
if x.owner is None:
raise TypeError("x must be the result of a subtensor operation")
# retrieve idx_list from x.owner
if isinstance(x.owner.op, Subtensor):
if tolerate_inplace_aliasing:
destroyhandler_tolerate_aliased = [[0, 1]]
else:
destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(
x.owner.op.idx_list,
inplace,
set_instead_of_inc,
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased,
)
real_x, *index_variables = x.owner.inputs
return the_op(real_x, y, *index_variables)
elif isinstance(x.owner.op, AdvancedSubtensor1):
real_x = x.owner.inputs[0]
ilist = x.owner.inputs[1]
if ignore_duplicates:
the_op = AdvancedIncSubtensor(
(0,),
inplace,
set_instead_of_inc=set_instead_of_inc,
ignore_duplicates=True,
)
else:
the_op = AdvancedIncSubtensor1(
inplace, set_instead_of_inc=set_instead_of_inc
)
return the_op(real_x, y, ilist)
elif isinstance(x.owner.op, AdvancedSubtensor):
real_x, *index_variables = x.owner.inputs
the_op = AdvancedIncSubtensor(
x.owner.op.idx_list,
inplace,
set_instead_of_inc=set_instead_of_inc,
ignore_duplicates=ignore_duplicates,
)
return the_op(real_x, y, *index_variables)
elif isinstance(x.owner.op, DimShuffle):
inner_x = x.owner.inputs[0]
# In the dimshuffle case, there are in fact two dimshuffles:
# one to make the indexed dimension the last one,
# and one to put it back where it was. So, in the case where we have
# inc_subtensor(x[:,i], y), the graph is actually
# inc_subtensor((x.T)[i].T, y).
# We could get all the way to x, and then get rid of the dimshuffles
# completely, but the problem is that advanced_inc_subtensor1 can only
# work on the first (outer-most, left-most) dimension of x,
# just like advanced_subtensor1.
# So we call advanced_inc_subtensor1(x.T, i, y.T) (as we also need to
# transpose y if it is not a scalar or a vector), but then we need to
# return something that has the same shape as x, not as x.T (inner_x).
# So re-apply the outer dimshuffle on the new inc_subtensor,
# and return advanced_inc_subtensor1(x.T, i, y.T).T.
# Get the dimshuffle pattern to apply to y.
x_order = x.owner.op.new_order
y_order = ["x"] * x.ndim
for i, v in enumerate(x_order):
if v != "x" and (v - dim_offset) >= 0:
y_order[v - dim_offset] = i
inner_incsubtensor = inc_subtensor(
inner_x,
y.dimshuffle(y_order),
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
# The broadcastable pattern of inner_x may not be the same as
# the one of x, so we have to build a new dimshuffle here,
# instead of reusing x.owner.op().
return inner_incsubtensor.dimshuffle(x.owner.op.new_order)
elif isinstance(x.owner.op, Reshape):
# This case happens when the indices are not arranged as a vector, but
# as a higher-dimensional array. This is handled by the subtensor
# by flattening this list, taking the subtensor, then reshaping the
# result.
inner_x = x.owner.inputs[0]
# Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want.
# We also explicitly duplicate y to its broadcasted shape
# before we partially flatten it to inner_x dimension. This is
# not strictly needed in all cases, but it is easier this way.
if y.ndim > 0:
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y = alloc(y, *[x.shape[i] for i in range(x.ndim)])
flattened_y = expanded_y.reshape(inner_x.shape)
else:
flattened_y = y
inner_incsubtensor = inc_subtensor(
inner_x,
flattened_y,
inplace=inplace,
set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing,
ignore_duplicates=ignore_duplicates,
)
return inner_incsubtensor
else:
raise TypeError("x must be the result of a subtensor operation")
def take(a, indices, axis=None, mode="raise"): def take(a, indices, axis=None, mode="raise"):
...@@ -2981,6 +3021,39 @@ def take(a, indices, axis=None, mode="raise"): ...@@ -2981,6 +3021,39 @@ def take(a, indices, axis=None, mode="raise"):
return a[full_indices] return a[full_indices]
@_get_vector_length.register(Subtensor) # type: ignore
def _get_vector_length_Subtensor(op, var):
# If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
try:
indices = pytensor.tensor.subtensor.get_idx_list(
var.owner.inputs, var.owner.op.idx_list
)
start = (
None
if indices[0].start is None
else get_scalar_constant_value(indices[0].start)
)
stop = (
None
if indices[0].stop is None
else get_scalar_constant_value(indices[0].stop)
)
step = (
None
if indices[0].step is None
else get_scalar_constant_value(indices[0].step)
)
if start == stop:
return 0
arg_len = get_vector_length(var.owner.inputs[0])
return len(range(*slice(start, stop, step).indices(arg_len)))
except (ValueError, NotScalarConstantError):
raise ValueError(f"Length of {var} cannot be determined")
def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]: def slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
""" """
Construct tuple of slices to slice an array in the given dimension. Construct tuple of slices to slice an array in the given dimension.
......
...@@ -15,8 +15,9 @@ from pytensor.scalar import ( ...@@ -15,8 +15,9 @@ from pytensor.scalar import (
ComplexError, ComplexError,
) )
from pytensor.tensor import _get_vector_length from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import hash_from_ndarray from pytensor.tensor.utils import hash_from_ndarray
...@@ -454,14 +455,15 @@ class _tensor_py_operators: ...@@ -454,14 +455,15 @@ class _tensor_py_operators:
elif not isinstance(args, tuple): elif not isinstance(args, tuple):
args = (args,) args = (args,)
# Count the dimensions, check for bools and find ellipses.
ellipses = [] ellipses = []
index_dim_count = 0 index_dim_count = 0
for i, arg in enumerate(args): for i, arg in enumerate(args):
if arg is None or ( if arg is np.newaxis or arg is NoneConst:
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT) # no increase in index_dim_count
):
pass pass
elif arg is Ellipsis: elif arg is Ellipsis:
# no increase in index_dim_count
ellipses.append(i) ellipses.append(i)
elif ( elif (
isinstance(arg, np.ndarray | Variable) isinstance(arg, np.ndarray | Variable)
...@@ -503,41 +505,6 @@ class _tensor_py_operators: ...@@ -503,41 +505,6 @@ class _tensor_py_operators:
self.ndim - index_dim_count self.ndim - index_dim_count
) )
if any(
arg is None
or (isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT))
for arg in args
):
expansion_axes = []
new_args = []
# Track dims consumed by args and inserted `None`s after ellipsis
counter = 0
nones = 0
for arg in args:
if arg is None or (
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
expansion_axes.append(counter + nones) # Expand here
nones += 1
new_args.append(slice(None))
else:
new_args.append(arg)
consumed = 1
if hasattr(arg, "dtype") and arg.dtype == "bool":
consumed = arg.ndim
counter += consumed
expanded = pt.expand_dims(self, expansion_axes)
if all(
isinstance(arg, slice)
and arg.start is None
and arg.stop is None
and arg.step is None
for arg in new_args
):
return expanded
return expanded[tuple(new_args)]
def is_empty_array(val): def is_empty_array(val):
return (isinstance(val, tuple | list) and len(val) == 0) or ( return (isinstance(val, tuple | list) and len(val) == 0) or (
isinstance(val, np.ndarray) and val.size == 0 isinstance(val, np.ndarray) and val.size == 0
...@@ -553,16 +520,74 @@ class _tensor_py_operators: ...@@ -553,16 +520,74 @@ class _tensor_py_operators:
for inp in args for inp in args
) )
if all( # Determine if advanced indexing is needed or not. The logic is
( # already in `index_vars_to_types`: if it succeeds, standard indexing is
isinstance(arg, slice | int | float | np.number) # used; if it fails with `AdvancedIndexingError`, advanced indexing is
or (hasattr(arg, "ndim") and arg.ndim == 0 and arg.dtype != "bool") # used
) advanced = False
for arg in args for i, arg in enumerate(args):
): if includes_bool(arg):
return pt.subtensor.basic_subtensor(self, *args) advanced = True
else: break
if arg is not np.newaxis and arg is not NoneConst:
try:
pt.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError:
if advanced:
break
else:
advanced = True
if advanced:
return pt.subtensor.advanced_subtensor(self, *args) return pt.subtensor.advanced_subtensor(self, *args)
else:
if np.newaxis in args or NoneConst in args:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter = 0
pattern = []
new_args = []
for arg in args:
if arg is np.newaxis or arg is NoneConst:
pattern.append("x")
new_args.append(slice(None, None, None))
else:
pattern.append(counter)
counter += 1
new_args.append(arg)
pattern.extend(list(range(counter, self.ndim)))
view = self.dimshuffle(pattern)
full_slices = True
for arg in new_args:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if not (
isinstance(arg, slice)
and (arg.start is None or arg.start is NoneConst)
and (arg.stop is None or arg.stop is NoneConst)
and (arg.step is None or arg.step is NoneConst)
):
full_slices = False
if full_slices:
return view
else:
return view.__getitem__(tuple(new_args))
else:
return pt.subtensor.Subtensor(args)(
self,
*pt.subtensor.get_slice_elements(
args, lambda entry: isinstance(entry, Variable)
),
)
def __setitem__(self, key, value): def __setitem__(self, key, value):
raise TypeError( raise TypeError(
......
...@@ -2,10 +2,9 @@ from itertools import zip_longest ...@@ -2,10 +2,9 @@ from itertools import zip_longest
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import arange, specify_shape from pytensor.tensor import TensorType, arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index, IndexUpdate, index from pytensor.xtensor.indexing import Index, IndexUpdate, index
from pytensor.xtensor.rewriting.utils import register_lower_xtensor from pytensor.xtensor.rewriting.utils import register_lower_xtensor
...@@ -107,7 +106,7 @@ def _lower_index(node): ...@@ -107,7 +106,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension # We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor # This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible # and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs.append(to_basic_idx(idx)) aligned_idxs.append(idx)
basic_idx_axis.append(out_dims.index(x_dim)) basic_idx_axis.append(out_dims.index(x_dim))
else: else:
# Otherwise we need to convert the basic index into an equivalent advanced indexing # Otherwise we need to convert the basic index into an equivalent advanced indexing
...@@ -132,7 +131,7 @@ def _lower_index(node): ...@@ -132,7 +131,7 @@ def _lower_index(node):
if basic_idx_axis: if basic_idx_axis:
aligned_idxs = [ aligned_idxs = [
idx.squeeze(axis=basic_idx_axis) idx.squeeze(axis=basic_idx_axis)
if (isinstance(idx, TensorVariable) and idx.type.ndim > 0) if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
else idx else idx
for idx in aligned_idxs for idx in aligned_idxs
] ]
......
...@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern ...@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot, exp from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import ( from tests.graph.utils import (
MyOp, MyOp,
MyType, MyType,
...@@ -627,6 +629,21 @@ def test_pre_constant_merge(): ...@@ -627,6 +629,21 @@ def test_pre_constant_merge():
assert res == [o2] assert res == [o2]
assert o2.owner.inputs[2] is c2 assert o2.owner.inputs[2] is c2
# What is this supposed to test?
ms = MakeSlice()(1)
res = pre_constant_merge(empty_fgraph, [ms])
assert res == [ms]
const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))
assert isinstance(const_slice, Constant)
adv = AdvancedSubtensor()(matrix(), [2, 3], const_slice)
res = pre_constant_merge(empty_fgraph, adv)
assert res == [adv]
def test_pre_greedy_node_rewriter(): def test_pre_greedy_node_rewriter():
empty_fgraph = FunctionGraph([], []) empty_fgraph = FunctionGraph([], [])
...@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter(): ...@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter():
assert cst.owner.inputs[0] is o1 assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0] assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test?
ms = MakeSlice()(1)
cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant)
# Make sure constant of slice signature is hashable.
assert isinstance(hash(cst.signature()), int)
@pytest.mark.parametrize("tracks", [True, False]) @pytest.mark.parametrize("tracks", [True, False])
@pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0]) @pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0])
......
...@@ -225,37 +225,6 @@ def test_jax_IncSubtensor(): ...@@ -225,37 +225,6 @@ def test_jax_IncSubtensor():
compare_jax_and_py([], [out_pt], []) compare_jax_and_py([], [out_pt], [])
@pytest.mark.parametrize(
"func", (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1)
)
def test_jax_AdvancedIncSubtensor1_runtime_broadcast(func):
"""Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.
JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
requires explicit broadcastable dimensions. This test ensures we raise the same
error as the Python/C backend when runtime broadcasting would occur.
"""
from pytensor import function
y = pt.matrix("y", dtype="float64", shape=(None, None))
x = pt.zeros((10, 5))
idxs = np.repeat(np.arange(10), 2) # 20 indices
out = func(x, y, idxs)
f = function([y], out, mode="JAX")
# Should work with correctly sized y
f(np.ones((20, 5)))
# Should raise for runtime broadcasting on first dimension
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(np.ones((1, 5)))
# Should raise for runtime broadcasting on second dimension
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(np.ones((20, 1)))
def test_jax_IncSubtensor_boolean_indexing_reexpressible(): def test_jax_IncSubtensor_boolean_indexing_reexpressible():
"""Setting or incrementing values with boolean indexing. """Setting or incrementing values with boolean indexing.
......
...@@ -187,6 +187,27 @@ def test_mlx_inplace_variants(): ...@@ -187,6 +187,27 @@ def test_mlx_inplace_variants():
compare_mlx_and_py([], [out_pt], []) compare_mlx_and_py([], [out_pt], [])
@pytest.mark.xfail(
reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
"""Test MakeSlice operation."""
# Test slice creation
start = pt.iscalar("start")
stop = pt.iscalar("stop")
step = pt.iscalar("step")
# Create a slice using MakeSlice
slice_op = pt_subtensor.MakeSlice()
slice_pt = slice_op(start, stop, step)
# Use simple constant array instead of arange
x_pt = pt.constant(np.arange(10, dtype=np.float32))
out_pt = x_pt[slice_pt]
compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2])
def test_mlx_subtensor_edge_cases(): def test_mlx_subtensor_edge_cases():
"""Test edge cases and boundary conditions.""" """Test edge cases and boundary conditions."""
# Empty slices - use constant array # Empty slices - use constant array
......
...@@ -3,7 +3,9 @@ import contextlib ...@@ -3,7 +3,9 @@ import contextlib
import numpy as np import numpy as np
import pytest import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import Mode, as_symbolic
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import ( ...@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import (
inc_subtensor, inc_subtensor,
set_subtensor, set_subtensor,
) )
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
compare_numba_and_py,
numba_inplace_mode,
numba_mode,
)
rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
def test_slice(start, stop, step):
x = ps.int64("x")
sym_slice = as_symbolic(
slice(
x if start == "x" else start,
x if stop == "x" else stop,
x if step == "x" else step,
)
)
no_opt_mode = Mode(linker="numba", optimizer=None)
evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
assert isinstance(evaled_slice, slice)
if start == "x":
assert evaled_slice.start == -5
elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
# Numba can convert to 0 (and sometimes does) in this case
assert evaled_slice.start in (None, 0)
else:
assert evaled_slice.start == start
if stop == "x":
assert evaled_slice.stop == -5
else:
assert evaled_slice.stop == stop
if step == "x":
assert evaled_slice.step == -5
elif step is None:
# Numba can convert to 1 (and sometimes does) in this case
assert evaled_slice.step in (None, 1)
else:
assert evaled_slice.step == step
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
...@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]), ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
), ),
# Newaxis with vector indexing
(
as_tensor(np.arange(4 * 4).reshape((4, 4))),
(None, [0, 1, 2], [0, 1, 2]),
),
], ],
) )
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed @pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
...@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False, False,
False, False,
), ),
(
np.arange(4 * 4).reshape((4, 4)),
np.array(5), # Broadcasted scalar value
(None, [0, 1, 2], [0, 1, 2]), # Newaxis with vector indexing
False,
False,
),
], ],
) )
@pytest.mark.parametrize("inplace", (False, True)) @pytest.mark.parametrize("inplace", (False, True))
...@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor( ...@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor(
inplace, inplace,
): ):
# Need rewrite to support certain forms of advanced indexing without object mode # Need rewrite to support certain forms of advanced indexing without object mode
# Use inplace_mode when testing inplace operations to preserve inplace flag mode = numba_mode.including("specialize")
base_mode = numba_inplace_mode if inplace else numba_mode
mode = base_mode.including("specialize")
x_pt = pt.as_tensor(x).type("x") x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y") y_pt = pt.as_tensor(y).type("y")
...@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor( ...@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor(
x_orig = x.copy() x_orig = x.copy()
fn(x, y) fn(x, y)
assert not np.all(x == x_orig) assert not np.all(x == x_orig)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
# This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
# After which we can add these parametrizations to the relevant tests above
x = pt.matrix("x")
out = x[None, [0, 1, 2], [0, 1, 2]]
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
out = x[None, [0, 1, 2], [0, 1, 2]].inc(5)
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
...@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug(): ...@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10): # with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph(fgraph, include=("inplace",)) rewrite_graph(fgraph, include=("inplace",))
# Save original value to restore later pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
original_value = pytensor.config.tensor__insert_inplace_optimizer_validate_nb with pytest.warns(
try: FutureWarning,
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1 match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
with pytest.warns( ):
FutureWarning, rewrite_graph(fgraph, include=("inplace",))
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
finally:
# Restore original value to avoid affecting other tests
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = original_value
...@@ -52,6 +52,7 @@ from pytensor.tensor.type import ( ...@@ -52,6 +52,7 @@ from pytensor.tensor.type import (
tensor4, tensor4,
vector, vector,
) )
from pytensor.tensor.type_other import make_slice
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.unittest_tools import create_pytensor_param from tests.unittest_tools import create_pytensor_param
...@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices(): ...@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices():
assert isinstance(new_index, Constant) assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
# `AdvancedSubtensor`, two indices, one slice, convert # `AdvancedSubtensor`, two indices, one symbolic slice, convert
x = pt.matrix("x") x = pt.matrix("x")
indices = ( indices = (
pt.as_tensor_variable(np.array([1], np.int64)), pt.as_tensor_variable(np.array(1, np.int64)),
slice(None, 10), make_slice(slice(None, 10)),
) )
z = x[indices] z = x[indices]
...@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices(): ...@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices():
z_fn = pytensor.function([x], z, mode=mode) z_fn = pytensor.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1)) assert isinstance(subtensor_node.op, AdvancedSubtensor)
new_index = subtensor_node.inputs[1] new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant) assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
...@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor: ...@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor:
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out) fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn) assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype) test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype) test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y)) np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
...@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor: ...@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"basic_idx", "basic_idx",
[True, False], [
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
ids=["basic_idx", "adv_idx"], ids=["basic_idx", "adv_idx"],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor: ...@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor:
core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,)) core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,))
# The empty slice before core_idx, will lead to a transposition of the advanced view # The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with a new arange slice on the batched dimensions. # once it is paired with an new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing # That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out = core_a[0, :, core_idx].set(core_v) core_out = core_a[0, :, core_idx].set(core_v)
......
...@@ -32,6 +32,7 @@ from pytensor.tensor import ( ...@@ -32,6 +32,7 @@ from pytensor.tensor import (
lscalars, lscalars,
matrix, matrix,
shape, shape,
slicetype,
specify_shape, specify_shape,
tensor, tensor,
tensor3, tensor3,
...@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift: ...@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift:
( (
matrix(), matrix(),
(iscalar(), iscalar()), (iscalar(), iscalar()),
(slice(iscalar(), iscalar(), iscalar()),), (slicetype(),),
), ),
( (
matrix(), matrix(),
...@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant(): ...@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant():
(lambda x: x[:, [0, 1]][0], True), (lambda x: x[:, [0, 1]][0], True),
(lambda x: x[:, [0, 1], [0, 0]][1:], True), (lambda x: x[:, [0, 1], [0, 0]][1:], True),
(lambda x: x[:, [[0, 1], [0, 0]]][1:], True), (lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
(lambda x: x[:, None, [0, 1]][0], True),
# Not supported, basic indexing on advanced indexing dim # Not supported, basic indexing on advanced indexing dim
(lambda x: x[[0, 1]][0], False), (lambda x: x[[0, 1]][0], False),
# Not supported, basic indexing on the right of advanced indexing # Not implemented, basic indexing on the right of advanced indexing
(lambda x: x[[0, 1]][:, 0], False), (lambda x: x[[0, 1]][:, 0], False),
# Not implemented, complex flavors of advanced indexing # Not implemented, complex flavors of advanced indexing
(lambda x: x[:, None, [0, 1]][0], False),
(lambda x: x[:, 5:, [0, 1]][0], False), (lambda x: x[:, 5:, [0, 1]][0], False),
(lambda x: x[:, :, np.array([True, False, False])][0], False), (lambda x: x[:, :, np.array([True, False, False])][0], False),
(lambda x: x[[0, 1], :, [0, 1]][:, 0], False), (lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
......
...@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import ( ...@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback, vectorize_node_fallback,
) )
from pytensor.tensor.nlinalg import MatrixInverse, eig from pytensor.tensor.nlinalg import MatrixInverse, eig
from pytensor.tensor.random import normal
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
...@@ -116,18 +114,16 @@ def test_vectorize_blockwise(): ...@@ -116,18 +114,16 @@ def test_vectorize_blockwise():
def test_vectorize_node_fallback_unsupported_type(): def test_vectorize_node_fallback_unsupported_type():
rng = default_rng() x = tensor("x", shape=(2, 6))
node = normal(rng=rng).owner node = x[:, [0, 2, 4]].owner
with pytest.raises( with pytest.raises(
NotImplementedError, NotImplementedError,
match=re.escape( match=re.escape(
'Cannot vectorize node normal_rv{"(),()->()"}(' "Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
"DefaultGeneratorMakerOp.0, NoneConst{None}, 0.0, 1.0)"
" with input DefaultGeneratorMakerOp.0 of type RandomGeneratorType"
), ),
): ):
vectorize_node_fallback(node.op, node, *node.inputs) vectorize_node_fallback(node.op, node, node.inputs)
def check_blockwise_runtime_broadcasting(mode): def check_blockwise_runtime_broadcasting(mode):
......
...@@ -11,19 +11,20 @@ from numpy.testing import assert_array_equal ...@@ -11,19 +11,20 @@ from numpy.testing import assert_array_equal
import pytensor import pytensor
import pytensor.scalar as scal import pytensor.scalar as scal
import pytensor.tensor.basic as ptb import pytensor.tensor.basic as ptb
from pytensor import function, shared from pytensor import function
from pytensor.compile import DeepCopyOp from pytensor.compile import DeepCopyOp, shared
from pytensor.compile.io import In from pytensor.compile.io import In
from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Constant
from pytensor.graph.basic import equal_computations from pytensor.graph.basic import equal_computations
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.utils import is_same_graph from pytensor.graph.rewriting.utils import is_same_graph
from pytensor.link.numba import NumbaLinker from pytensor.link.numba import NumbaLinker
from pytensor.printing import pprint from pytensor.printing import pprint
from pytensor.scalar.basic import as_scalar, int16 from pytensor.scalar.basic import as_scalar, int16
from pytensor.tensor import as_tensor, constant, get_vector_length, ivector, vectorize from pytensor.tensor import as_tensor, constant, get_vector_length, vectorize
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf, lt, switch from pytensor.tensor.math import exp, isinf, lt, switch
...@@ -32,6 +33,7 @@ from pytensor.tensor.shape import specify_broadcastable, specify_shape ...@@ -32,6 +33,7 @@ from pytensor.tensor.shape import specify_broadcastable, specify_shape
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIndexingError,
AdvancedSubtensor, AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
IncSubtensor, IncSubtensor,
...@@ -47,6 +49,7 @@ from pytensor.tensor.subtensor import ( ...@@ -47,6 +49,7 @@ from pytensor.tensor.subtensor import (
flip, flip,
get_canonical_form_slice, get_canonical_form_slice,
inc_subtensor, inc_subtensor,
index_vars_to_types,
indexed_result_shape, indexed_result_shape,
set_subtensor, set_subtensor,
slice_at_axis, slice_at_axis,
...@@ -77,7 +80,13 @@ from pytensor.tensor.type import ( ...@@ -77,7 +80,13 @@ from pytensor.tensor.type import (
tensor5, tensor5,
vector, vector,
) )
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import (
NoneConst,
SliceConstant,
as_symbolic_slice,
make_slice,
slicetype,
)
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.tensor.utils import inplace_func, integers_ranged, random from tests.tensor.utils import inplace_func, integers_ranged, random
...@@ -97,12 +106,20 @@ def test_as_index_literal(): ...@@ -97,12 +106,20 @@ def test_as_index_literal():
assert res == slice(1, None) assert res == slice(1, None)
res = as_index_literal(slice(None, None, ptb.as_tensor(2))) res = as_index_literal(slice(None, None, ptb.as_tensor(2)))
assert res == slice(None, None, 2) assert res == slice(None, None, 2)
res = as_index_literal(SliceConstant(slicetype, slice(None)))
assert res == slice(None)
res = as_index_literal(make_slice(None, ptb.as_tensor(1)))
assert res == slice(None, 1)
res = as_index_literal(ptb.as_tensor(2)) res = as_index_literal(ptb.as_tensor(2))
assert res == 2 assert res == 2
res = as_index_literal(np.newaxis) res = as_index_literal(np.newaxis)
assert res is np.newaxis assert res is np.newaxis
res = as_index_literal(NoneConst)
assert res is np.newaxis
res = as_index_literal(NoneConst.clone())
assert res is np.newaxis
class TestGetCanonicalFormSlice: class TestGetCanonicalFormSlice:
...@@ -111,6 +128,8 @@ class TestGetCanonicalFormSlice: ...@@ -111,6 +128,8 @@ class TestGetCanonicalFormSlice:
[ [
NoneConst, NoneConst,
None, None,
as_symbolic_slice(slice(3, 7, 2)),
as_symbolic_slice(slice(3, int16(), 2)),
vector(), vector(),
], ],
) )
...@@ -118,19 +137,6 @@ class TestGetCanonicalFormSlice: ...@@ -118,19 +137,6 @@ class TestGetCanonicalFormSlice:
with pytest.raises(ValueError, match="not a supported slice"): with pytest.raises(ValueError, match="not a supported slice"):
get_canonical_form_slice(idx, 5) get_canonical_form_slice(idx, 5)
@pytest.mark.parametrize(
"idx,expected_direction",
[
(slice(3, 7, 2), 1),
(slice(None, None), 1),
(slice(None, None, -1), -1),
],
)
def test_python_slice_support(self, idx, expected_direction):
result, direction = get_canonical_form_slice(idx, 10)
assert isinstance(result, slice)
assert direction == expected_direction
def test_scalar_constant(self): def test_scalar_constant(self):
a = as_scalar(0) a = as_scalar(0)
length = lscalar() length = lscalar()
...@@ -402,7 +408,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -402,7 +408,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
f = inplace_func([], t, mode=mode) f = inplace_func([], t, mode=mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op, DeepCopyOp)] topo_ = [node for node in topo if not isinstance(node.op, DeepCopyOp)]
assert len(topo_) == length, f.dprint() assert len(topo_) == length
if length == 1: if length == 1:
assert isinstance(topo_[0].op, op_type) assert isinstance(topo_[0].op, op_type)
tval = f() tval = f()
...@@ -617,7 +623,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -617,7 +623,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
(3, DimShuffle, np.index_exp[..., [0, 2, 3]]), (3, DimShuffle, np.index_exp[..., [0, 2, 3]]),
(1, DimShuffle, np.index_exp[np.newaxis, ...]), (1, DimShuffle, np.index_exp[np.newaxis, ...]),
( (
4 if config.mode == "FAST_COMPILE" else 3, 1,
AdvancedSubtensor, AdvancedSubtensor,
np.index_exp[..., np.newaxis, [1, 2]], np.index_exp[..., np.newaxis, [1, 2]],
), ),
...@@ -1961,7 +1967,7 @@ class TestAdvancedSubtensor: ...@@ -1961,7 +1967,7 @@ class TestAdvancedSubtensor:
x = self.shared(x_val, name="x") x = self.shared(x_val, name="x")
y = tensor(dtype="float32", shape=(None,) * len(y_val.shape), name="y") y = tensor(dtype="float32", shape=(None,) * len(y_val.shape), name="y")
sym_idx = [ptb.as_tensor_variable(ix) for ix in idx] sym_idx = [ptb.as_tensor_variable(ix) for ix in idx]
expr = advanced_inc_subtensor(x, y, *sym_idx, inplace=inplace) expr = AdvancedIncSubtensor(inplace=inplace)(x, y, *sym_idx)
f = pytensor.function( f = pytensor.function(
[y], expr, mode=self.mode.excluding("inplace"), accept_inplace=inplace [y], expr, mode=self.mode.excluding("inplace"), accept_inplace=inplace
) )
...@@ -2297,29 +2303,20 @@ class TestAdvancedSubtensor: ...@@ -2297,29 +2303,20 @@ class TestAdvancedSubtensor:
def test_adv_sub_slice(self): def test_adv_sub_slice(self):
# Reported in https://github.com/Theano/Theano/issues/5898 # Reported in https://github.com/Theano/Theano/issues/5898
var = self.shared(np.zeros([3, 3], dtype=config.floatX)) var = self.shared(np.zeros([3, 3], dtype=config.floatX))
slc = slicetype()
f = pytensor.function([slc], var[slc], mode=self.mode)
s = slice(1, 3)
assert f(s).shape == (2, 3)
# Test with scalar variables for slice boundaries f_shape0 = pytensor.function([slc], var[slc].shape[0], mode=self.mode)
start = lscalar("start") assert f_shape0(s) == 2
stop = lscalar("stop")
# Create sliced output
f = pytensor.function([start, stop], var[start:stop], mode=self.mode)
result = f(1, 3)
assert result.shape == (2, 3)
f_shape0 = pytensor.function( f_shape1 = pytensor.function([slc], var[slc].shape[1], mode=self.mode)
[start, stop], var[start:stop].shape[0], mode=self.mode
)
assert f_shape0(1, 3) == 2
f_shape1 = pytensor.function(
[start, stop], var[start:stop].shape[1], mode=self.mode
)
assert not any( assert not any(
isinstance(node.op, AdvancedSubtensor) isinstance(node.op, AdvancedSubtensor)
for node in f_shape1.maker.fgraph.toposort() for node in f_shape1.maker.fgraph.toposort()
) )
assert f_shape1(1, 3) == 3 assert f_shape1(s) == 3
def test_adv_grouped(self): def test_adv_grouped(self):
# Reported in https://github.com/Theano/Theano/issues/6152 # Reported in https://github.com/Theano/Theano/issues/6152
...@@ -2801,8 +2798,8 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2801,8 +2798,8 @@ class TestInferShape(utt.InferShapeTester):
def test_advanced_subtensor_constant_slice(self): def test_advanced_subtensor_constant_slice(self):
x = dmatrix("x") x = dmatrix("x")
# Use Python slice directly instead of as_symbolic(slice()) constant_slice = pytensor.as_symbolic(slice(1, None, None))
constant_slice = slice(1, None, None) assert isinstance(constant_slice, Constant)
adv_indices = ptb.constant(np.zeros((2, 3)), dtype="int") adv_indices = ptb.constant(np.zeros((2, 3)), dtype="int")
y = advanced_subtensor(x, constant_slice, adv_indices) y = advanced_subtensor(x, constant_slice, adv_indices)
assert tuple(y.shape.eval({x: np.zeros((10, 10))})) == (9, 2, 3) assert tuple(y.shape.eval({x: np.zeros((10, 10))})) == (9, 2, 3)
...@@ -2811,7 +2808,7 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2811,7 +2808,7 @@ class TestInferShape(utt.InferShapeTester):
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
def test_basic_shape(): def test_basic_shape():
test_shape = (5, 4) test_shape = (5, 4)
test_indices = (slice(1, 3, None),) # Python slice instead of make_slice() test_indices = (make_slice(1, 3, None),)
res = basic_shape(test_shape, test_indices) res = basic_shape(test_shape, test_indices)
assert get_test_value(res) == (2,) assert get_test_value(res) == (2,)
...@@ -2849,6 +2846,18 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True ...@@ -2849,6 +2846,18 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), *test_idx[:1]), (slice(None, None), *test_idx[:1]),
), ),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None, *test_idx[1:2]),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(np.array(1), slice(None, None), None),
),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(slice(None, None), None, np.array(1)),
),
( (
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(*test_idx[:1], slice(None, None), *test_idx[1:2]), (*test_idx[:1], slice(None, None), *test_idx[1:2]),
...@@ -2857,6 +2866,10 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True ...@@ -2857,6 +2866,10 @@ test_idx = np.ix_(np.array([True, True]), np.array([True]), np.array([True, True
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)), np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(*test_idx[:1], slice(None, None), *test_idx[1:2], slice(None, None)), (*test_idx[:1], slice(None, None), *test_idx[1:2], slice(None, None)),
), ),
(
np.arange(np.prod((5, 6, 7, 8))).reshape((5, 6, 7, 8)),
(*test_idx[:1], None, *test_idx[1:2]),
),
(np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))), (np.arange(np.prod((5, 4))).reshape((5, 4)), ([1, 3, 2], slice(1, 3))),
(np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])), (np.arange(np.prod((5, 4))).reshape((5, 4)), (slice(1, 3), [1, 3, 2])),
( (
...@@ -2916,11 +2929,12 @@ def test_get_vector_length(): ...@@ -2916,11 +2929,12 @@ def test_get_vector_length():
"indices, exp_res", "indices, exp_res",
[ [
((0,), "x[0]"), ((0,), "x[0]"),
((slice(None, 2),), "x[:2]"), # TODO: The numbers should be printed
((slice(0, None),), "x[0:]"), ((slice(None, 2),), "x[:int64]"),
((slice(0, 2),), "x[0:2]"), ((slice(0, None),), "x[int64:]"),
((slice(0, 2, 2),), "x[0:2:2]"), ((slice(0, 2),), "x[int64:int64]"),
((slice(0, 2), 0, slice(0, 2)), "x[0:2, 0, 0:2]"), ((slice(0, 2, 2),), "x[int64:int64:int64]"),
((slice(0, 2), 0, slice(0, 2)), "x[int64:int64, 2, int64:int64]"),
], ],
) )
def test_pprint_Subtensor(indices, exp_res): def test_pprint_Subtensor(indices, exp_res):
...@@ -2934,7 +2948,7 @@ def test_pprint_Subtensor(indices, exp_res): ...@@ -2934,7 +2948,7 @@ def test_pprint_Subtensor(indices, exp_res):
[ [
((0,), False, "inc_subtensor(x[0], z)"), ((0,), False, "inc_subtensor(x[0], z)"),
((0,), True, "set_subtensor(x[0], z)"), ((0,), True, "set_subtensor(x[0], z)"),
((slice(0, 2),), True, "set_subtensor(x[0:2], z)"), ((slice(0, 2),), True, "set_subtensor(x[int64:int64], z)"),
], ],
) )
def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res): def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
...@@ -2944,38 +2958,22 @@ def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res): ...@@ -2944,38 +2958,22 @@ def test_pprint_IncSubtensor(indices, set_instead_of_inc, exp_res):
assert pprint(y) == exp_res assert pprint(y) == exp_res
@pytest.mark.parametrize( def test_index_vars_to_types():
"indices, exp_res", x = ptb.as_tensor_variable(np.array([True, False]))
[
# Vector index
((ivector("idx"),), "x[idx]"),
# Two vector indices
((ivector("idx"), ivector("idx2")), "x[idx, idx2]"),
# Vector index with scalar (triggers advanced indexing)
((ivector("idx"), 0), "x[idx, 0]"),
# Vector index with constant slice
((ivector("idx"), slice(0, 5)), "x[idx, 0:5]"),
],
)
def test_pprint_AdvancedSubtensor(indices, exp_res):
x = tensor4("x")
y = advanced_subtensor(x, *indices)
assert pprint(y) == exp_res
with pytest.raises(AdvancedIndexingError):
index_vars_to_types(x)
@pytest.mark.parametrize( with pytest.raises(TypeError):
"indices, set_instead_of_inc, exp_res", index_vars_to_types(1)
[
((ivector("idx"),), False, "inc_subtensor(x[idx], z)"), res = index_vars_to_types(iscalar)
((ivector("idx"),), True, "set_subtensor(x[idx], z)"), assert isinstance(res, scal.ScalarType)
((ivector("idx"), slice(None, 5)), True, "set_subtensor(x[idx, :5], z)"),
], x = scal.constant(1, dtype=np.uint8)
) assert isinstance(x.type, scal.ScalarType)
def test_pprint_AdvancedIncSubtensor(indices, set_instead_of_inc, exp_res): res = index_vars_to_types(x)
x = tensor4("x") assert res == x.type
z = tensor3("z")
y = advanced_inc_subtensor(x, z, *indices, set_instead_of_inc=set_instead_of_inc)
assert pprint(y) == exp_res
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -3068,12 +3066,15 @@ def test_vectorize_subtensor_without_batch_indices(): ...@@ -3068,12 +3066,15 @@ def test_vectorize_subtensor_without_batch_indices():
(2,), (2,),
False, False,
), ),
# (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
# due to the exact same None constant being used there and in the np.newaxis)
pytest.param( pytest.param(
(lambda x, idx: x[:, idx, None]), (lambda x, idx: x[:, idx, None]),
"(7,5,3),(2)->(7,2,1,3)", "(7,5,3),(2)->(7,2,1,3)",
(11, 7, 5, 3), (11, 7, 5, 3),
(2,), (2,),
False, False,
marks=pytest.mark.xfail(raises=NotImplementedError),
), ),
( (
(lambda x, idx: x[:, idx, idx, :]), (lambda x, idx: x[:, idx, idx, :]),
...@@ -3082,23 +3083,27 @@ def test_vectorize_subtensor_without_batch_indices(): ...@@ -3082,23 +3083,27 @@ def test_vectorize_subtensor_without_batch_indices():
(2,), (2,),
False, False,
), ),
# (not supported, because fallback Blocwise can't handle slices)
pytest.param( pytest.param(
(lambda x, idx: x[:, idx, :, idx]), (lambda x, idx: x[:, idx, :, idx]),
"(7,5,3,5),(2)->(2,7,3)", "(7,5,3,5),(2)->(2,7,3)",
(11, 7, 5, 3, 5), (11, 7, 5, 3, 5),
(2,), (2,),
True, True,
marks=pytest.mark.xfail(raises=NotImplementedError),
), ),
# Core x, batched idx # Core x, batched idx
((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (7,), (11, 2), True), ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (7,), (11, 2), True),
# Batched x, batched idx # Batched x, batched idx
((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (11, 7), (11, 2), True), ((lambda x, idx: x[idx]), "(t1),(idx)->(tx)", (11, 7), (11, 2), True),
# (not supported, because fallback Blocwise can't handle slices)
pytest.param( pytest.param(
(lambda x, idx: x[:, idx, :]), (lambda x, idx: x[:, idx, :]),
"(t1,t2,t3),(idx)->(t1,tx,t3)", "(t1,t2,t3),(idx)->(t1,tx,t3)",
(11, 7, 5, 3), (11, 7, 5, 3),
(11, 2), (11, 2),
True, True,
marks=pytest.mark.xfail(raises=NotImplementedError),
), ),
], ],
) )
...@@ -3233,37 +3238,3 @@ class TestBenchmarks: ...@@ -3233,37 +3238,3 @@ class TestBenchmarks:
) )
fn.vm.allow_gc = gc fn.vm.allow_gc = gc
benchmark(fn, x_values) benchmark(fn, x_values)
def test_subtensor_hash_and_eq():
s1 = Subtensor(idx_list=[slice(None, None, None), 0])
s2 = Subtensor(idx_list=[slice(None, None, None), 0])
assert s1 == s2
assert hash(s1) == hash(s2)
s3 = AdvancedSubtensor(idx_list=[slice(None, None, None), 0])
s4 = AdvancedIncSubtensor(idx_list=[slice(0, 1, None), 2])
assert s3 != s4
assert hash(s3) != hash(s4)
assert s1 != s3
inc1 = IncSubtensor(
idx_list=[slice(None)], inplace=True, destroyhandler_tolerate_aliased=[(0, 1)]
)
inc2 = IncSubtensor(
idx_list=[slice(None)], inplace=True, destroyhandler_tolerate_aliased=[(0, 1)]
)
inc3 = IncSubtensor(
idx_list=[slice(None)], inplace=True, destroyhandler_tolerate_aliased=[(0, 2)]
)
assert inc1 == inc2
assert hash(inc1) == hash(inc2)
assert inc1 != inc3
if hash(inc1) == hash(inc3):
assert inc1 == inc3
s_mix1 = Subtensor(idx_list=[0, slice(None), slice(None, 1)])
s_mix2 = Subtensor(idx_list=[0, slice(None), slice(None, 1)])
assert s_mix1 == s_mix2
assert hash(s_mix1) == hash(s_mix2)
...@@ -4,8 +4,30 @@ import pytensor ...@@ -4,8 +4,30 @@ import pytensor
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.tensor.math import argmax from pytensor.tensor.math import argmax
from pytensor.tensor.type import vector from pytensor.tensor.type import iscalar, vector
from pytensor.tensor.type_other import NoneConst, NoneTypeT from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
def test_SliceType():
st = SliceType()
assert st == st.clone()
def test_make_slice_merge():
# In the past, this was crahsing during compilation.
i = iscalar()
s1 = make_slice(0, i)
s2 = make_slice(0, i)
f = pytensor.function([i], [s1, s2])
nodes = f.maker.fgraph.apply_nodes
assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
def test_none_Constant(): def test_none_Constant():
...@@ -25,6 +47,8 @@ def test_none_Constant(): ...@@ -25,6 +47,8 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past. # This trigger equals that returned the wrong answer in the past.
import pickle import pickle
import pytensor
x = vector("x") x = vector("x")
y = argmax(x) y = argmax(x)
kwargs = {} kwargs = {}
...@@ -36,18 +60,11 @@ def test_none_Constant(): ...@@ -36,18 +60,11 @@ def test_none_Constant():
def test_as_symbolic(): def test_as_symbolic():
# Remove this when xtensor is not using symbolic slices
from pytensor.tensor.type import iscalar
from pytensor.tensor.type_other import SliceConstant, slicetype
res = as_symbolic(None) res = as_symbolic(None)
assert res is NoneConst assert res is NoneConst
res = as_symbolic(slice(iscalar()))
assert res.owner.op == make_slice
res = as_symbolic(slice(1, 2)) res = as_symbolic(slice(1, 2))
assert isinstance(res, SliceConstant) assert isinstance(res, SliceConstant)
assert res.type == slicetype
assert res.data == slice(1, 2)
i = iscalar()
res = as_symbolic(slice(i))
assert res.owner is not None
...@@ -35,7 +35,7 @@ from pytensor.tensor.type import ( ...@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar, scalar,
tensor3, tensor3,
) )
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.tensor.variable import ( from pytensor.tensor.variable import (
DenseTensorConstant, DenseTensorConstant,
DenseTensorVariable, DenseTensorVariable,
...@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor(): ...@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z = x[:, i] z = x[:, i]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [AdvancedSubtensor] assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[..., i, None] z = x[..., i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [DimShuffle, AdvancedSubtensor] assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[i, None] z = x[i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论