提交 db7fa079 作者: Jaan Erik Pihel 提交者: Ricardo Vieira

Refactor AdvancedSubtensor

- newaxis is handled as an explicit DimShuffle on the inputs
- slices are encoded internally, so the Ops only take numerical inputs

Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
上级 87470065
...@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper): ...@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
} }
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr( tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", [] app.op, "destroyhandler_tolerate_aliased", ()
) )
assert isinstance(tolerate_aliased, list) assert isinstance(tolerate_aliased, tuple | list)
ignored = { ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
} }
......
...@@ -8,7 +8,6 @@ from pytensor.tensor.subtensor import ( ...@@ -8,7 +8,6 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice
BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
...@@ -35,10 +34,8 @@ slice length. ...@@ -35,10 +34,8 @@ slice length.
@jax_funcify.register(AdvancedSubtensor) @jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1) @jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs): def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists): def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list) indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -48,10 +45,9 @@ def jax_funcify_Subtensor(op, node, **kwargs): ...@@ -48,10 +45,9 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register(IncSubtensor) @jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1) @jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs): def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
...@@ -62,7 +58,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs): ...@@ -62,7 +58,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
return x.at[indices].add(y) return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -73,29 +69,3 @@ def jax_funcify_IncSubtensor(op, node, **kwargs): ...@@ -73,29 +69,3 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return jax_fn(x, indices, y) return jax_fn(x, indices, y)
return incsubtensor return incsubtensor
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
return advancedincsubtensor
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
...@@ -10,15 +10,14 @@ from pytensor.tensor.subtensor import ( ...@@ -10,15 +10,14 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice
@mlx_funcify.register(Subtensor) @mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs): def mlx_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists): def subtensor(x, *ilists):
indices = indices_from_subtensor([int(element) for element in ilists], idx_list) indices = indices_from_subtensor(
[int(element) for element in ilists], op.idx_list
)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -30,10 +29,8 @@ def mlx_funcify_Subtensor(op, node, **kwargs): ...@@ -30,10 +29,8 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1) @mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def advanced_subtensor(x, *ilists): def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list) indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -45,8 +42,6 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -45,8 +42,6 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register(IncSubtensor) @mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1) @mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs): def mlx_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
def mlx_fn(x, indices, y): def mlx_fn(x, indices, y):
...@@ -63,7 +58,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs): ...@@ -63,7 +58,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x[indices] += y x[indices] += y
return x return x
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list): def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -95,11 +90,3 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -95,11 +90,3 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return mlx_fn(x, ilist, y) return mlx_fn(x, ilist, y)
return advancedincsubtensor return advancedincsubtensor
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
...@@ -9,7 +9,6 @@ from pytensor.tensor.subtensor import ( ...@@ -9,7 +9,6 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice, SliceType
def check_negative_steps(indices): def check_negative_steps(indices):
...@@ -47,23 +46,11 @@ def pytorch_funcify_Subtensor(op, node, **kwargs): ...@@ -47,23 +46,11 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return subtensor return subtensor
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(start, stop, step):
# Torch does not like numpy integers in indexing slices
return slice(
None if start is None else int(start),
None if stop is None else int(stop),
None if step is None else int(step),
)
return makeslice
@pytorch_funcify.register(AdvancedSubtensor1) @pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor) @pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs): def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices): def advsubtensor(x, *indices):
indices = indices_from_subtensor(indices, op.idx_list)
check_negative_steps(indices) check_negative_steps(indices)
return x[indices] return x[indices]
...@@ -102,12 +89,14 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs): ...@@ -102,12 +89,14 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register(AdvancedIncSubtensor) @pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1) @pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False) ignore_duplicates = getattr(op, "ignore_duplicates", False)
if op.set_instead_of_inc: if op.set_instead_of_inc:
def adv_set_subtensor(x, y, *indices): def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1): if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices) op._check_runtime_broadcasting(node, x, y, indices)
...@@ -120,7 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -120,7 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif ignore_duplicates: elif ignore_duplicates:
def adv_inc_subtensor_no_duplicates(x, y, *indices): def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1): if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices) op._check_runtime_broadcasting(node, x, y, indices)
...@@ -132,13 +122,14 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -132,13 +122,14 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return adv_inc_subtensor_no_duplicates return adv_inc_subtensor_no_duplicates
else: else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]): if any(isinstance(entry, slice) for entry in idx_list):
raise NotImplementedError( raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch" "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
) )
def adv_inc_subtensor(x, y, *indices): def adv_inc_subtensor(x, y, *flattened_indices):
# Not needed because slices aren't supported indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices) # check_negative_steps(indices)
if not inplace: if not inplace:
x = x.clone() x = x.clone()
......
...@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape ...@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
basic_subtensor,
get_canonical_form_slice, get_canonical_form_slice,
get_idx_list, get_idx_list,
get_slice_elements,
set_subtensor, set_subtensor,
) )
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool: ...@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if not ( if not (
isinstance(op, IncSubtensor) isinstance(op, IncSubtensor)
and op.set_instead_of_inc and op.set_instead_of_inc
and op.idx_list == [slice(None, ps.int64)] and op.idx_list == (slice(None, 0),)
): ):
return False return False
...@@ -1389,12 +1389,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1389,12 +1389,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = get_idx_list(cl.inputs, cl.op.idx_list) this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps = None
slices[i] = None
break
# 2.3.2 extract the begin/end of the first dimension # 2.3.2 extract the begin/end of the first dimension
if i >= op_info.n_mit_mot: if i >= op_info.n_mit_mot:
...@@ -1487,9 +1481,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1487,9 +1481,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break break
else: else:
this_slice = get_idx_list(cl.inputs, cl.op.idx_list) this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
store_steps[i] = 0
break
if isinstance(this_slice[0], slice): if isinstance(this_slice[0], slice):
start = this_slice[0].start start = this_slice[0].start
...@@ -1711,16 +1702,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1711,16 +1702,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
) )
else: else:
fslice = sanitize(cnf_slice[0]) fslice = sanitize(cnf_slice[0])
nw_slice = (fslice, *old_slices[1:])
nw_pos = inv_compress_map[idx] nw_pos = inv_compress_map[idx]
new_o = basic_subtensor(new_outs[nw_pos], fslice, *old_slices[1:])
subtens = Subtensor(nw_slice)
# slice inputs
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx) replaced_outs.append(idx)
...@@ -1771,11 +1755,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1771,11 +1755,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
) )
nw_slice = (sanitize(position), *old_slices[1:]) nw_slice = (sanitize(position), *old_slices[1:])
subtens = Subtensor(nw_slice) new_o = basic_subtensor(new_outs[nw_pos], *nw_slice)
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)] old_new += [(old, new_o)]
......
...@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output ...@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
...@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value( ...@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value( ...@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and len(v.owner.op.idx_list) == 1 and len(v.owner.op.idx_list) == 1
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value( ...@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op = owner.op op = owner.op
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, Type): if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
......
...@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import ( ...@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import integer_dtypes from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.type_other import NoneTypeT
def is_rv_used_in_graph(base_rv, node, fgraph): def is_rv_used_in_graph(base_rv, node, fgraph):
...@@ -237,20 +237,15 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -237,20 +237,15 @@ def local_subtensor_rv_lift(fgraph, node):
return False return False
# Parse indices # Parse indices
if isinstance(subtensor_op, Subtensor): if isinstance(subtensor_op, Subtensor | AdvancedSubtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list) indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else: else:
indices = node.inputs[1:] indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). # TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
# If we wanted to support that we could rewrite it as subtensor + dimshuffle # (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates)
# and make use of the dimshuffle lift rewrite if any(is_nd_advanced_idx(idx, integer_dtypes) for idx in indices):
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem return False
if any(
is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices
):
return False
# Check that indexing does not act on support dims # Check that indexing does not act on support dims
batch_ndims = rv_op.batch_ndim(rv_node) batch_ndims = rv_op.batch_ndim(rv_node)
...@@ -268,10 +263,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -268,10 +263,7 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices[batch_ndims:], non_bool_indices[batch_ndims:],
) )
for idx in supp_indices: for idx in supp_indices:
if not ( if idx != slice(None):
isinstance(idx.type, SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
):
return False return False
n_discarded_idxs = len(supp_indices) n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs] indices = indices[:-n_discarded_idxs]
...@@ -331,7 +323,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -331,7 +323,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim # Broadcasted dim
if curr_dim in bcast_param_dims: if curr_dim in bcast_param_dims:
# Slice indexing, keep degenerate dim by none-slicing # Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx, slice) or isinstance(idx.type, SliceType): if isinstance(idx, slice):
batch_indices.append(slice(None)) batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing # Integer indexing, drop degenerate dim by 0-indexing
else: else:
......
...@@ -17,7 +17,6 @@ from pytensor.graph.rewriting.basic import ( ...@@ -17,7 +17,6 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.traversal import ancestors from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
from pytensor.scalar import ScalarType
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
as_tensor_variable, as_tensor_variable,
...@@ -842,13 +841,16 @@ def _is_shape_i_of_x( ...@@ -842,13 +841,16 @@ def _is_shape_i_of_x(
if isinstance(var.owner.op, Shape_i): if isinstance(var.owner.op, Shape_i):
return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore
# Match Subtensor((ScalarType,))(Shape(input), i) # Match Subtensor((int,))(Shape(input), i) - single integer index into shape
if isinstance(var.owner.op, Subtensor): if isinstance(var.owner.op, Subtensor):
idx_entry = (
var.owner.op.idx_list[0] if len(var.owner.op.idx_list) == 1 else None
)
return ( return (
# Check we have integer indexing operation # Check we have integer indexing operation
# (and not slice or multiple indexing) # (and not slice or multiple indexing)
len(var.owner.op.idx_list) == 1 len(var.owner.op.idx_list) == 1
and isinstance(var.owner.op.idx_list[0], ScalarType) and isinstance(idx_entry, int)
# Check we are indexing on the shape of x # Check we are indexing on the shape of x
and var.owner.inputs[0].owner is not None and var.owner.inputs[0].owner is not None
and isinstance(var.owner.inputs[0].owner.op, Shape) and isinstance(var.owner.inputs[0].owner.op, Shape)
......
...@@ -8,7 +8,6 @@ from pytensor import Variable ...@@ -8,7 +8,6 @@ from pytensor import Variable
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
Join, Join,
...@@ -31,7 +30,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -31,7 +30,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize, register_stabilize,
) )
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless from pytensor.tensor.rewriting.subtensor import register_useless
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
SpecifyShape, SpecifyShape,
...@@ -50,7 +49,6 @@ from pytensor.tensor.subtensor import ( ...@@ -50,7 +49,6 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -71,7 +69,7 @@ def _axis_is_indexed_by_basic_index( ...@@ -71,7 +69,7 @@ def _axis_is_indexed_by_basic_index(
) -> bool: ) -> bool:
if isinstance(axis, int): if isinstance(axis, int):
axis = (axis,) axis = (axis,)
return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis) return any(ax < len(idxs) and not idxs[ax] == slice(None) for ax in axis)
def _lift_subtensor_non_axis( def _lift_subtensor_non_axis(
...@@ -83,7 +81,7 @@ def _lift_subtensor_non_axis( ...@@ -83,7 +81,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable: TensorVariable, old_subtensor_variable: TensorVariable,
) -> None | list[TensorVariable]: ) -> None | list[TensorVariable]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions # Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)] real_indices = [idx for idx in idx_tuple if not idx == slice(None)]
if len(real_indices) > 1 and variable.type.ndim > 1: if len(real_indices) > 1 and variable.type.ndim > 1:
# Split the subtensor # Split the subtensor
idx_to_keep = idx_tuple[axis] idx_to_keep = idx_tuple[axis]
...@@ -206,7 +204,7 @@ def local_subtensor_of_batch_dims(fgraph, node): ...@@ -206,7 +204,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if len(idx_tuple) > batch_ndim: if len(idx_tuple) > batch_ndim:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only # Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:] batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
if all(is_full_slice(idx) for idx in batch_indices): if all(idx == slice(None) for idx in batch_indices):
# No batch indices, nothing to do # No batch indices, nothing to do
return None return None
elem_with_batch_indices = elem[batch_indices] elem_with_batch_indices = elem[batch_indices]
...@@ -240,7 +238,7 @@ def local_subtensor_of_batch_dims(fgraph, node): ...@@ -240,7 +238,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict=False, strict=False,
) )
): ):
if is_full_slice(dim_idx): if dim_idx == slice(None):
# Full slice can be safely applied to all inputs # Full slice can be safely applied to all inputs
continue continue
...@@ -429,7 +427,7 @@ def local_subtensor_of_expand_dims(fgraph, node): ...@@ -429,7 +427,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if i in expanded_axes: if i in expanded_axes:
if isinstance(idx_item, slice): if isinstance(idx_item, slice):
# Slice could be keeping or dropping this dimension # Slice could be keeping or dropping this dimension
if is_full_slice(idx_item): if idx_item == slice(None):
# A None slice, always keeps the dimension. # A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim # We skip the index, and later introduce the needed expand_dim
continue continue
...@@ -648,10 +646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): ...@@ -648,10 +646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices = get_idx_list(node.inputs, node.op.idx_list) indices = get_idx_list(node.inputs, node.op.idx_list)
if any( if any(isinstance(index, slice) for index in indices):
isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
for index in indices
):
return False return False
new_obj_arg = obj_arg[indices] new_obj_arg = obj_arg[indices]
...@@ -702,15 +697,12 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -702,15 +697,12 @@ def local_subtensor_make_vector(fgraph, node):
(idx,) = idxs (idx,) = idxs
if isinstance(idx, ps.ScalarType | TensorType): if isinstance(idx, int):
old_idx, idx = idx, node.inputs[1] idx = node.inputs[1]
assert idx.type.is_super(old_idx)
elif isinstance(node.op, AdvancedSubtensor1): elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1] idx = node.inputs[1]
if isinstance(idx, int | np.integer): if isinstance(idx, Variable):
return [x.owner.inputs[idx]]
elif isinstance(idx, Variable):
if idx.ndim == 0: if idx.ndim == 0:
try: try:
v = get_underlying_scalar_constant_value( v = get_underlying_scalar_constant_value(
...@@ -833,8 +825,6 @@ def local_subtensor_shape_constant(fgraph, node): ...@@ -833,8 +825,6 @@ def local_subtensor_shape_constant(fgraph, node):
except NotScalarConstantError: except NotScalarConstantError:
return False return False
assert idx_val != np.newaxis
if not isinstance(shape_arg.type, TensorType): if not isinstance(shape_arg.type, TensorType):
return False return False
...@@ -871,22 +861,24 @@ def local_subtensor_of_adv_subtensor(fgraph, node): ...@@ -871,22 +861,24 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return None return None
x, *adv_idxs = adv_subtensor.owner.inputs x, *adv_index_vars = adv_subtensor.owner.inputs
adv_idxs = indices_from_subtensor(adv_index_vars, adv_subtensor.owner.op.idx_list)
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if any( if (
( not all(
isinstance(adv_idx.type, NoneTypeT) (
or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool") (isinstance(adv_idx, TensorVariable) and adv_idx.type.dtype != "bool")
or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx)) or (isinstance(adv_idx, slice) and adv_idx == slice(None))
)
for adv_idx in adv_idxs
) )
for adv_idx in adv_idxs
) or _non_consecutive_adv_indexing(adv_idxs): ) or _non_consecutive_adv_indexing(adv_idxs):
return None return None
for first_adv_idx_dim, adv_idx in enumerate(adv_idxs): for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
# We already made sure there were only None slices besides integer indexes # We already made sure there were only None slices besides integer indexes
if isinstance(adv_idx.type, TensorType): if isinstance(adv_idx, TensorVariable):
break break
else: # no-break else: # no-break
# Not sure if this should ever happen, but better safe than sorry # Not sure if this should ever happen, but better safe than sorry
...@@ -909,7 +901,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node): ...@@ -909,7 +901,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed) copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
x_after_index_lift = expand_dims(x_indexed, dropped_dims) x_after_index_lift = expand_dims(x_indexed, dropped_dims)
x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs) x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_index_vars)
copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx) copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims) new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
......
...@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle ...@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Min, neg from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
@register_uncanonicalize @register_uncanonicalize
...@@ -193,60 +193,42 @@ def local_dimshuffle_subtensor(fgraph, node): ...@@ -193,60 +193,42 @@ def local_dimshuffle_subtensor(fgraph, node):
if not all(broadcastable[i] for i in missing_dims): if not all(broadcastable[i] for i in missing_dims):
return False return False
# create a new idx_list for a new Subtensor object # create a new index tuple for a new Subtensor
# have to loop on idx_list and inputs # Reconstruct the full indices from the subtensor node, then replace
# inputs has the length of sum of non None elements of idx_list # dimensions that are being dropped by dimshuffle with scalar index 0
# (check in slice!). x = input_.owner.inputs[0]
# len(missing_dims) can be < len(idx_list), this happens if indices = list(
# tensor was indexed such as x[scalar, :, :], check that as well indices_from_subtensor(
new_idx_list = list(input_.owner.op.idx_list) input_.owner.inputs[1:], input_.owner.op.idx_list
new_inputs = [input_.owner.inputs[0]] )
)
zero = constant(0) zero = constant(0)
j = 0
slice_i = -1 # Track which output dimension each index corresponds to
subtensor_removed_dims = 0 # Scalar indices remove dimensions, slices keep them
for i, idx in enumerate(input_.owner.op.idx_list): output_dim = 0
for i, idx in enumerate(indices):
if isinstance(idx, slice): if isinstance(idx, slice):
slice_i += 1 # This slice produces an output dimension
if slice_i in missing_dims: if output_dim in missing_dims:
# Missing dim is a slice(None), remove by indexing by 0 # This output dimension is being dropped, so replace slice with scalar
if idx == slice(None): if idx == slice(None):
new_idx_list[i] = zero indices[i] = zero
new_inputs += [zero]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
else: else:
if idx.start is None: # Use the start of the slice (or 0 if None)
start = zero indices[i] = idx.start if idx.start is not None else zero
else: output_dim += 1
start = input_.owner.inputs[1 + j] # Scalar indices don't contribute to output dimensions
j += 1
new_idx_list[i] = start # Handle trailing dimensions that weren't explicitly indexed
new_inputs += [start] for input_dim in range(len(indices), x.ndim):
if output_dim in missing_dims:
# Ignore useless stop and step input if there is one # This unindexed dimension is being dropped, index with 0
for slice_attr in ("stop", "step"): indices.append(zero)
if getattr(idx, slice_attr) is not None:
j += 1
# Keep non-dropped slice inputs
else:
for slice_attr in ("start", "stop", "step"):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# Keep non-dropped non-slice inputs
else: else:
new_inputs += [input_.owner.inputs[1 + j]] # This unindexed dimension is kept, index with slice(None)
j += 1 indices.append(slice(None))
subtensor_removed_dims += 1 output_dim += 1
# Verify the trailing dimensions the subtensor didn't look at.
for idx in range(len(input_.owner.op.idx_list), new_inputs[0].ndim): return [x[tuple(indices)]]
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)]
return False return False
...@@ -15,9 +15,8 @@ from pytensor.scalar import ( ...@@ -15,9 +15,8 @@ from pytensor.scalar import (
ComplexError, ComplexError,
) )
from pytensor.tensor import _get_vector_length from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import hash_from_ndarray from pytensor.tensor.utils import hash_from_ndarray
...@@ -455,15 +454,14 @@ class _tensor_py_operators: ...@@ -455,15 +454,14 @@ class _tensor_py_operators:
elif not isinstance(args, tuple): elif not isinstance(args, tuple):
args = (args,) args = (args,)
# Count the dimensions, check for bools and find ellipses.
ellipses = [] ellipses = []
index_dim_count = 0 index_dim_count = 0
for i, arg in enumerate(args): for i, arg in enumerate(args):
if arg is np.newaxis or arg is NoneConst: if arg is None or (
# no increase in index_dim_count isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
pass pass
elif arg is Ellipsis: elif arg is Ellipsis:
# no increase in index_dim_count
ellipses.append(i) ellipses.append(i)
elif ( elif (
isinstance(arg, np.ndarray | Variable) isinstance(arg, np.ndarray | Variable)
...@@ -505,6 +503,41 @@ class _tensor_py_operators: ...@@ -505,6 +503,41 @@ class _tensor_py_operators:
self.ndim - index_dim_count self.ndim - index_dim_count
) )
if any(
arg is None
or (isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT))
for arg in args
):
expansion_axes = []
new_args = []
# Track dims consumed by args and inserted `None`s after ellipsis
counter = 0
nones = 0
for arg in args:
if arg is None or (
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
expansion_axes.append(counter + nones) # Expand here
nones += 1
new_args.append(slice(None))
else:
new_args.append(arg)
consumed = 1
if hasattr(arg, "dtype") and arg.dtype == "bool":
consumed = arg.ndim
counter += consumed
expanded = pt.expand_dims(self, expansion_axes)
if all(
isinstance(arg, slice)
and arg.start is None
and arg.stop is None
and arg.step is None
for arg in new_args
):
return expanded
return expanded[tuple(new_args)]
def is_empty_array(val): def is_empty_array(val):
return (isinstance(val, tuple | list) and len(val) == 0) or ( return (isinstance(val, tuple | list) and len(val) == 0) or (
isinstance(val, np.ndarray) and val.size == 0 isinstance(val, np.ndarray) and val.size == 0
...@@ -520,74 +553,16 @@ class _tensor_py_operators: ...@@ -520,74 +553,16 @@ class _tensor_py_operators:
for inp in args for inp in args
) )
# Determine if advanced indexing is needed or not. The logic is if all(
# already in `index_vars_to_types`: if it succeeds, standard indexing is (
# used; if it fails with `AdvancedIndexingError`, advanced indexing is isinstance(arg, slice | int | float | np.number)
# used or (hasattr(arg, "ndim") and arg.ndim == 0 and arg.dtype != "bool")
advanced = False )
for i, arg in enumerate(args): for arg in args
if includes_bool(arg): ):
advanced = True return pt.subtensor.basic_subtensor(self, *args)
break
if arg is not np.newaxis and arg is not NoneConst:
try:
pt.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError:
if advanced:
break
else:
advanced = True
if advanced:
return pt.subtensor.advanced_subtensor(self, *args)
else: else:
if np.newaxis in args or NoneConst in args: return pt.subtensor.advanced_subtensor(self, *args)
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter = 0
pattern = []
new_args = []
for arg in args:
if arg is np.newaxis or arg is NoneConst:
pattern.append("x")
new_args.append(slice(None, None, None))
else:
pattern.append(counter)
counter += 1
new_args.append(arg)
pattern.extend(list(range(counter, self.ndim)))
view = self.dimshuffle(pattern)
full_slices = True
for arg in new_args:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if not (
isinstance(arg, slice)
and (arg.start is None or arg.start is NoneConst)
and (arg.stop is None or arg.stop is NoneConst)
and (arg.step is None or arg.step is NoneConst)
):
full_slices = False
if full_slices:
return view
else:
return view.__getitem__(tuple(new_args))
else:
return pt.subtensor.Subtensor(args)(
self,
*pt.subtensor.get_slice_elements(
args, lambda entry: isinstance(entry, Variable)
),
)
def __setitem__(self, key, value): def __setitem__(self, key, value):
raise TypeError( raise TypeError(
......
...@@ -2,9 +2,10 @@ from itertools import zip_longest ...@@ -2,9 +2,10 @@ from itertools import zip_longest
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, arange, specify_shape from pytensor.tensor import arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index, IndexUpdate, index from pytensor.xtensor.indexing import Index, IndexUpdate, index
from pytensor.xtensor.rewriting.utils import register_lower_xtensor from pytensor.xtensor.rewriting.utils import register_lower_xtensor
...@@ -106,7 +107,7 @@ def _lower_index(node): ...@@ -106,7 +107,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension # We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor # This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible # and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs.append(idx) aligned_idxs.append(to_basic_idx(idx))
basic_idx_axis.append(out_dims.index(x_dim)) basic_idx_axis.append(out_dims.index(x_dim))
else: else:
# Otherwise we need to convert the basic index into an equivalent advanced indexing # Otherwise we need to convert the basic index into an equivalent advanced indexing
...@@ -131,7 +132,7 @@ def _lower_index(node): ...@@ -131,7 +132,7 @@ def _lower_index(node):
if basic_idx_axis: if basic_idx_axis:
aligned_idxs = [ aligned_idxs = [
idx.squeeze(axis=basic_idx_axis) idx.squeeze(axis=basic_idx_axis)
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0) if (isinstance(idx, TensorVariable) and idx.type.ndim > 0)
else idx else idx
for idx in aligned_idxs for idx in aligned_idxs
] ]
......
...@@ -26,9 +26,7 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern ...@@ -26,9 +26,7 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot, exp from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import ( from tests.graph.utils import (
MyOp, MyOp,
MyType, MyType,
...@@ -629,21 +627,6 @@ def test_pre_constant_merge(): ...@@ -629,21 +627,6 @@ def test_pre_constant_merge():
assert res == [o2] assert res == [o2]
assert o2.owner.inputs[2] is c2 assert o2.owner.inputs[2] is c2
# What is this supposed to test?
ms = MakeSlice()(1)
res = pre_constant_merge(empty_fgraph, [ms])
assert res == [ms]
const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))
assert isinstance(const_slice, Constant)
adv = AdvancedSubtensor()(matrix(), [2, 3], const_slice)
res = pre_constant_merge(empty_fgraph, adv)
assert res == [adv]
def test_pre_greedy_node_rewriter(): def test_pre_greedy_node_rewriter():
empty_fgraph = FunctionGraph([], []) empty_fgraph = FunctionGraph([], [])
...@@ -679,15 +662,6 @@ def test_pre_greedy_node_rewriter(): ...@@ -679,15 +662,6 @@ def test_pre_greedy_node_rewriter():
assert cst.owner.inputs[0] is o1 assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0] assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test?
ms = MakeSlice()(1)
cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant)
# Make sure constant of slice signature is hashable.
assert isinstance(hash(cst.signature()), int)
@pytest.mark.parametrize("tracks", [True, False]) @pytest.mark.parametrize("tracks", [True, False])
@pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0]) @pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0])
......
...@@ -225,6 +225,37 @@ def test_jax_IncSubtensor(): ...@@ -225,6 +225,37 @@ def test_jax_IncSubtensor():
compare_jax_and_py([], [out_pt], []) compare_jax_and_py([], [out_pt], [])
@pytest.mark.parametrize(
"func", (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1)
)
def test_jax_AdvancedIncSubtensor1_runtime_broadcast(func):
"""Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.
JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
requires explicit broadcastable dimensions. This test ensures we raise the same
error as the Python/C backend when runtime broadcasting would occur.
"""
from pytensor import function
y = pt.matrix("y", dtype="float64", shape=(None, None))
x = pt.zeros((10, 5))
idxs = np.repeat(np.arange(10), 2) # 20 indices
out = func(x, y, idxs)
f = function([y], out, mode="JAX")
# Should work with correctly sized y
f(np.ones((20, 5)))
# Should raise for runtime broadcasting on first dimension
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(np.ones((1, 5)))
# Should raise for runtime broadcasting on second dimension
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(np.ones((20, 1)))
def test_jax_IncSubtensor_boolean_indexing_reexpressible(): def test_jax_IncSubtensor_boolean_indexing_reexpressible():
"""Setting or incrementing values with boolean indexing. """Setting or incrementing values with boolean indexing.
......
...@@ -187,27 +187,6 @@ def test_mlx_inplace_variants(): ...@@ -187,27 +187,6 @@ def test_mlx_inplace_variants():
compare_mlx_and_py([], [out_pt], []) compare_mlx_and_py([], [out_pt], [])
@pytest.mark.xfail(
reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
"""Test MakeSlice operation."""
# Test slice creation
start = pt.iscalar("start")
stop = pt.iscalar("stop")
step = pt.iscalar("step")
# Create a slice using MakeSlice
slice_op = pt_subtensor.MakeSlice()
slice_pt = slice_op(start, stop, step)
# Use simple constant array instead of arange
x_pt = pt.constant(np.arange(10, dtype=np.float32))
out_pt = x_pt[slice_pt]
compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2])
def test_mlx_subtensor_edge_cases(): def test_mlx_subtensor_edge_cases():
"""Test edge cases and boundary conditions.""" """Test edge cases and boundary conditions."""
# Empty slices - use constant array # Empty slices - use constant array
......
...@@ -3,9 +3,7 @@ import contextlib ...@@ -3,9 +3,7 @@ import contextlib
import numpy as np import numpy as np
import pytest import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import Mode, as_symbolic
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -20,51 +18,16 @@ from pytensor.tensor.subtensor import ( ...@@ -20,51 +18,16 @@ from pytensor.tensor.subtensor import (
inc_subtensor, inc_subtensor,
set_subtensor, set_subtensor,
) )
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode from tests.link.numba.test_basic import (
compare_numba_and_py,
numba_inplace_mode,
numba_mode,
)
rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
def test_slice(start, stop, step):
x = ps.int64("x")
sym_slice = as_symbolic(
slice(
x if start == "x" else start,
x if stop == "x" else stop,
x if step == "x" else step,
)
)
no_opt_mode = Mode(linker="numba", optimizer=None)
evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
assert isinstance(evaled_slice, slice)
if start == "x":
assert evaled_slice.start == -5
elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
# Numba can convert to 0 (and sometimes does) in this case
assert evaled_slice.start in (None, 0)
else:
assert evaled_slice.start == start
if stop == "x":
assert evaled_slice.stop == -5
else:
assert evaled_slice.stop == stop
if step == "x":
assert evaled_slice.step == -5
elif step is None:
# Numba can convert to 1 (and sometimes does) in this case
assert evaled_slice.step in (None, 1)
else:
assert evaled_slice.step == step
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
...@@ -182,6 +145,11 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -182,6 +145,11 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]), ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
), ),
# Newaxis with vector indexing
(
as_tensor(np.arange(4 * 4).reshape((4, 4))),
(None, [0, 1, 2], [0, 1, 2]),
),
], ],
) )
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed @pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
...@@ -447,6 +415,13 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -447,6 +415,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False, False,
False, False,
), ),
(
np.arange(4 * 4).reshape((4, 4)),
np.array(5), # Broadcasted scalar value
(None, [0, 1, 2], [0, 1, 2]), # Newaxis with vector indexing
False,
False,
),
], ],
) )
@pytest.mark.parametrize("inplace", (False, True)) @pytest.mark.parametrize("inplace", (False, True))
...@@ -460,7 +435,9 @@ def test_AdvancedIncSubtensor( ...@@ -460,7 +435,9 @@ def test_AdvancedIncSubtensor(
inplace, inplace,
): ):
# Need rewrite to support certain forms of advanced indexing without object mode # Need rewrite to support certain forms of advanced indexing without object mode
mode = numba_mode.including("specialize") # Use inplace_mode when testing inplace operations to preserve inplace flag
base_mode = numba_inplace_mode if inplace else numba_mode
mode = base_mode.including("specialize")
x_pt = pt.as_tensor(x).type("x") x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y") y_pt = pt.as_tensor(y).type("y")
...@@ -514,22 +491,3 @@ def test_AdvancedIncSubtensor( ...@@ -514,22 +491,3 @@ def test_AdvancedIncSubtensor(
x_orig = x.copy() x_orig = x.copy()
fn(x, y) fn(x, y)
assert not np.all(x == x_orig) assert not np.all(x == x_orig)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
# This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
# After which we can add these parametrizations to the relevant tests above
x = pt.matrix("x")
out = x[None, [0, 1, 2], [0, 1, 2]]
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
out = x[None, [0, 1, 2], [0, 1, 2]].inc(5)
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
...@@ -1642,9 +1642,15 @@ def test_InplaceElemwiseOptimizer_bug(): ...@@ -1642,9 +1642,15 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10): # with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph(fgraph, include=("inplace",)) rewrite_graph(fgraph, include=("inplace",))
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1 # Save original value to restore later
with pytest.warns( original_value = pytensor.config.tensor__insert_inplace_optimizer_validate_nb
FutureWarning, try:
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated", pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
): with pytest.warns(
rewrite_graph(fgraph, include=("inplace",)) FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
finally:
# Restore original value to avoid affecting other tests
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = original_value
...@@ -52,7 +52,6 @@ from pytensor.tensor.type import ( ...@@ -52,7 +52,6 @@ from pytensor.tensor.type import (
tensor4, tensor4,
vector, vector,
) )
from pytensor.tensor.type_other import make_slice
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.unittest_tools import create_pytensor_param from tests.unittest_tools import create_pytensor_param
...@@ -1701,11 +1700,11 @@ def test_local_uint_constant_indices(): ...@@ -1701,11 +1700,11 @@ def test_local_uint_constant_indices():
assert isinstance(new_index, Constant) assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
# `AdvancedSubtensor`, two indices, one symbolic slice, convert # `AdvancedSubtensor`, two indices, one slice, convert
x = pt.matrix("x") x = pt.matrix("x")
indices = ( indices = (
pt.as_tensor_variable(np.array(1, np.int64)), pt.as_tensor_variable(np.array([1], np.int64)),
make_slice(slice(None, 10)), slice(None, 10),
) )
z = x[indices] z = x[indices]
...@@ -1792,7 +1791,7 @@ def test_local_uint_constant_indices(): ...@@ -1792,7 +1791,7 @@ def test_local_uint_constant_indices():
z_fn = pytensor.function([x], z, mode=mode) z_fn = pytensor.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor) assert isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1))
new_index = subtensor_node.inputs[1] new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant) assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
...@@ -1843,7 +1842,6 @@ class TestBlockwiseIncSubtensor: ...@@ -1843,7 +1842,6 @@ class TestBlockwiseIncSubtensor:
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out) fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn) assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype) test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype) test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y)) np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
...@@ -1948,15 +1946,7 @@ class TestBlockwiseIncSubtensor: ...@@ -1948,15 +1946,7 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"basic_idx", "basic_idx",
[ [True, False],
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
ids=["basic_idx", "adv_idx"], ids=["basic_idx", "adv_idx"],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1973,7 +1963,7 @@ class TestBlockwiseIncSubtensor: ...@@ -1973,7 +1963,7 @@ class TestBlockwiseIncSubtensor:
core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,)) core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,))
# The empty slice before core_idx, will lead to a transposition of the advanced view # The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with an new arange slice on the batched dimensions. # once it is paired with a new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing # That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out = core_a[0, :, core_idx].set(core_v) core_out = core_a[0, :, core_idx].set(core_v)
......
...@@ -32,7 +32,6 @@ from pytensor.tensor import ( ...@@ -32,7 +32,6 @@ from pytensor.tensor import (
lscalars, lscalars,
matrix, matrix,
shape, shape,
slicetype,
specify_shape, specify_shape,
tensor, tensor,
tensor3, tensor3,
...@@ -557,7 +556,7 @@ class TestLocalSubtensorSpecifyShapeLift: ...@@ -557,7 +556,7 @@ class TestLocalSubtensorSpecifyShapeLift:
( (
matrix(), matrix(),
(iscalar(), iscalar()), (iscalar(), iscalar()),
(slicetype(),), (slice(iscalar(), iscalar(), iscalar()),),
), ),
( (
matrix(), matrix(),
...@@ -789,12 +788,12 @@ def test_local_subtensor_shape_constant(): ...@@ -789,12 +788,12 @@ def test_local_subtensor_shape_constant():
(lambda x: x[:, [0, 1]][0], True), (lambda x: x[:, [0, 1]][0], True),
(lambda x: x[:, [0, 1], [0, 0]][1:], True), (lambda x: x[:, [0, 1], [0, 0]][1:], True),
(lambda x: x[:, [[0, 1], [0, 0]]][1:], True), (lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
(lambda x: x[:, None, [0, 1]][0], True),
# Not supported, basic indexing on advanced indexing dim # Not supported, basic indexing on advanced indexing dim
(lambda x: x[[0, 1]][0], False), (lambda x: x[[0, 1]][0], False),
# Not implemented, basic indexing on the right of advanced indexing # Not supported, basic indexing on the right of advanced indexing
(lambda x: x[[0, 1]][:, 0], False), (lambda x: x[[0, 1]][:, 0], False),
# Not implemented, complex flavors of advanced indexing # Not implemented, complex flavors of advanced indexing
(lambda x: x[:, None, [0, 1]][0], False),
(lambda x: x[:, 5:, [0, 1]][0], False), (lambda x: x[:, 5:, [0, 1]][0], False),
(lambda x: x[:, :, np.array([True, False, False])][0], False), (lambda x: x[:, :, np.array([True, False, False])][0], False),
(lambda x: x[[0, 1], :, [0, 1]][:, 0], False), (lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
......
...@@ -31,6 +31,8 @@ from pytensor.tensor.blockwise import ( ...@@ -31,6 +31,8 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback, vectorize_node_fallback,
) )
from pytensor.tensor.nlinalg import MatrixInverse, eig from pytensor.tensor.nlinalg import MatrixInverse, eig
from pytensor.tensor.random import normal
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
...@@ -114,16 +116,18 @@ def test_vectorize_blockwise(): ...@@ -114,16 +116,18 @@ def test_vectorize_blockwise():
def test_vectorize_node_fallback_unsupported_type(): def test_vectorize_node_fallback_unsupported_type():
x = tensor("x", shape=(2, 6)) rng = default_rng()
node = x[:, [0, 2, 4]].owner node = normal(rng=rng).owner
with pytest.raises( with pytest.raises(
NotImplementedError, NotImplementedError,
match=re.escape( match=re.escape(
"Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice" 'Cannot vectorize node normal_rv{"(),()->()"}('
"DefaultGeneratorMakerOp.0, NoneConst{None}, 0.0, 1.0)"
" with input DefaultGeneratorMakerOp.0 of type RandomGeneratorType"
), ),
): ):
vectorize_node_fallback(node.op, node, node.inputs) vectorize_node_fallback(node.op, node, *node.inputs)
def check_blockwise_runtime_broadcasting(mode): def check_blockwise_runtime_broadcasting(mode):
......
...@@ -4,30 +4,8 @@ import pytensor ...@@ -4,30 +4,8 @@ import pytensor
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.tensor.math import argmax from pytensor.tensor.math import argmax
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import vector
from pytensor.tensor.type_other import ( from pytensor.tensor.type_other import NoneConst, NoneTypeT
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
def test_SliceType():
st = SliceType()
assert st == st.clone()
def test_make_slice_merge():
# In the past, this was crahsing during compilation.
i = iscalar()
s1 = make_slice(0, i)
s2 = make_slice(0, i)
f = pytensor.function([i], [s1, s2])
nodes = f.maker.fgraph.apply_nodes
assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
def test_none_Constant(): def test_none_Constant():
...@@ -47,8 +25,6 @@ def test_none_Constant(): ...@@ -47,8 +25,6 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past. # This trigger equals that returned the wrong answer in the past.
import pickle import pickle
import pytensor
x = vector("x") x = vector("x")
y = argmax(x) y = argmax(x)
kwargs = {} kwargs = {}
...@@ -60,11 +36,18 @@ def test_none_Constant(): ...@@ -60,11 +36,18 @@ def test_none_Constant():
def test_as_symbolic(): def test_as_symbolic():
# Remove this when xtensor is not using symbolic slices
from pytensor.tensor.type import iscalar
from pytensor.tensor.type_other import SliceConstant, slicetype
res = as_symbolic(None) res = as_symbolic(None)
assert res is NoneConst assert res is NoneConst
res = as_symbolic(slice(iscalar()))
assert res.owner.op == make_slice
res = as_symbolic(slice(1, 2)) res = as_symbolic(slice(1, 2))
assert isinstance(res, SliceConstant) assert isinstance(res, SliceConstant)
assert res.type == slicetype
assert res.data == slice(1, 2)
i = iscalar()
res = as_symbolic(slice(i))
assert res.owner is not None
...@@ -35,7 +35,7 @@ from pytensor.tensor.type import ( ...@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar, scalar,
tensor3, tensor3,
) )
from pytensor.tensor.type_other import MakeSlice, NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.variable import ( from pytensor.tensor.variable import (
DenseTensorConstant, DenseTensorConstant,
DenseTensorVariable, DenseTensorVariable,
...@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor(): ...@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z = x[:, i] z = x[:, i]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor] assert op_types == [AdvancedSubtensor]
z = x[..., i, None] z = x[..., i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor] assert op_types == [DimShuffle, AdvancedSubtensor]
z = x[i, None] z = x[i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论