提交 cc6bed1a authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Revert "Refactor AdvancedSubtensor"

This reverts commit db7fa079.
上级 03afa5bb
...@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper): ...@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
} }
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr( tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", () app.op, "destroyhandler_tolerate_aliased", []
) )
assert isinstance(tolerate_aliased, tuple | list) assert isinstance(tolerate_aliased, list)
ignored = { ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
} }
......
...@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import ( ...@@ -8,6 +8,7 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice
BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
...@@ -34,8 +35,10 @@ slice length. ...@@ -34,8 +35,10 @@ slice length.
@jax_funcify.register(AdvancedSubtensor) @jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1) @jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs): def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists): def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, op.idx_list) indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs): ...@@ -45,9 +48,10 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register(IncSubtensor) @jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1) @jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs): def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
...@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs): ...@@ -58,7 +62,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def jax_fn(x, indices, y): def jax_fn(x, indices, y):
return x.at[indices].add(y) return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs): ...@@ -69,3 +73,29 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return jax_fn(x, indices, y) return jax_fn(x, indices, y)
return incsubtensor return incsubtensor
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
return advancedincsubtensor
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
...@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import ( ...@@ -10,14 +10,15 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice
@mlx_funcify.register(Subtensor) @mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs): def mlx_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists): def subtensor(x, *ilists):
indices = indices_from_subtensor( indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
[int(element) for element in ilists], op.idx_list
)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs): ...@@ -29,8 +30,10 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register(AdvancedSubtensor) @mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1) @mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def advanced_subtensor(x, *ilists): def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, op.idx_list) indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs): ...@@ -42,6 +45,8 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register(IncSubtensor) @mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1) @mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs): def mlx_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
def mlx_fn(x, indices, y): def mlx_fn(x, indices, y):
...@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs): ...@@ -58,7 +63,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x[indices] += y x[indices] += y
return x return x
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
...@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -90,3 +95,11 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return mlx_fn(x, ilist, y) return mlx_fn(x, ilist, y)
return advancedincsubtensor return advancedincsubtensor
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
...@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import ( ...@@ -9,6 +9,7 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type_other import MakeSlice, SliceType
def check_negative_steps(indices): def check_negative_steps(indices):
...@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs): ...@@ -46,11 +47,23 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return subtensor return subtensor
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
def makeslice(start, stop, step):
# Torch does not like numpy integers in indexing slices
return slice(
None if start is None else int(start),
None if stop is None else int(stop),
None if step is None else int(step),
)
return makeslice
@pytorch_funcify.register(AdvancedSubtensor1) @pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor) @pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs): def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices): def advsubtensor(x, *indices):
indices = indices_from_subtensor(indices, op.idx_list)
check_negative_steps(indices) check_negative_steps(indices)
return x[indices] return x[indices]
...@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs): ...@@ -89,14 +102,12 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register(AdvancedIncSubtensor) @pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1) @pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False) ignore_duplicates = getattr(op, "ignore_duplicates", False)
if op.set_instead_of_inc: if op.set_instead_of_inc:
def adv_set_subtensor(x, y, *flattened_indices): def adv_set_subtensor(x, y, *indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1): if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices) op._check_runtime_broadcasting(node, x, y, indices)
...@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -109,8 +120,7 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif ignore_duplicates: elif ignore_duplicates:
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices): def adv_inc_subtensor_no_duplicates(x, y, *indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1): if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices) op._check_runtime_broadcasting(node, x, y, indices)
...@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -122,14 +132,13 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return adv_inc_subtensor_no_duplicates return adv_inc_subtensor_no_duplicates
else: else:
if any(isinstance(entry, slice) for entry in idx_list): if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
raise NotImplementedError( raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch" "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
) )
def adv_inc_subtensor(x, y, *flattened_indices): def adv_inc_subtensor(x, y, *indices):
indices = indices_from_subtensor(flattened_indices, idx_list) # Not needed because slices aren't supported
# Not needed because slices aren't supported in this path
# check_negative_steps(indices) # check_negative_steps(indices)
if not inplace: if not inplace:
x = x.clone() x = x.clone()
......
...@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape ...@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
IncSubtensor, IncSubtensor,
Subtensor, Subtensor,
basic_subtensor,
get_canonical_form_slice, get_canonical_form_slice,
get_idx_list, get_idx_list,
get_slice_elements,
set_subtensor, set_subtensor,
) )
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool: ...@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if not ( if not (
isinstance(op, IncSubtensor) isinstance(op, IncSubtensor)
and op.set_instead_of_inc and op.set_instead_of_inc
and op.idx_list == (slice(None, 0),) and op.idx_list == [slice(None, ps.int64)]
): ):
return False return False
...@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1389,6 +1389,12 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else: else:
# 2.3.1 extract idx list of subtensor # 2.3.1 extract idx list of subtensor
this_slice = get_idx_list(cl.inputs, cl.op.idx_list) this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps = None
slices[i] = None
break
# 2.3.2 extract the begin/end of the first dimension # 2.3.2 extract the begin/end of the first dimension
if i >= op_info.n_mit_mot: if i >= op_info.n_mit_mot:
...@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1481,6 +1487,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break break
else: else:
this_slice = get_idx_list(cl.inputs, cl.op.idx_list) this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
store_steps[i] = 0
break
if isinstance(this_slice[0], slice): if isinstance(this_slice[0], slice):
start = this_slice[0].start start = this_slice[0].start
...@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1702,9 +1711,16 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
) )
else: else:
fslice = sanitize(cnf_slice[0]) fslice = sanitize(cnf_slice[0])
nw_slice = (fslice, *old_slices[1:])
nw_pos = inv_compress_map[idx] nw_pos = inv_compress_map[idx]
new_o = basic_subtensor(new_outs[nw_pos], fslice, *old_slices[1:])
subtens = Subtensor(nw_slice)
# slice inputs
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx) replaced_outs.append(idx)
...@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1755,7 +1771,11 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
) )
nw_slice = (sanitize(position), *old_slices[1:]) nw_slice = (sanitize(position), *old_slices[1:])
new_o = basic_subtensor(new_outs[nw_pos], *nw_slice) subtens = Subtensor(nw_slice)
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]] new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)] old_new += [(old, new_o)]
......
...@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output ...@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape from pytensor.graph.type import HasShape, Type
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
...@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value( ...@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:] var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, int): if isinstance(idx, Type):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value( ...@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and len(v.owner.op.idx_list) == 1 and len(v.owner.op.idx_list) == 1
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, int): if isinstance(idx, Type):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
...@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value( ...@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op = owner.op op = owner.op
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, int): if isinstance(idx, Type):
idx = _get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
......
...@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import ( ...@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import integer_dtypes from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.type_other import NoneTypeT, SliceType
def is_rv_used_in_graph(base_rv, node, fgraph): def is_rv_used_in_graph(base_rv, node, fgraph):
...@@ -237,14 +237,19 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -237,14 +237,19 @@ def local_subtensor_rv_lift(fgraph, node):
return False return False
# Parse indices # Parse indices
if isinstance(subtensor_op, Subtensor | AdvancedSubtensor): if isinstance(subtensor_op, Subtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list) indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else: else:
indices = node.inputs[1:] indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem # TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
# (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates) if any(
if any(is_nd_advanced_idx(idx, integer_dtypes) for idx in indices): is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices
):
return False return False
# Check that indexing does not act on support dims # Check that indexing does not act on support dims
...@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -263,7 +268,10 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices[batch_ndims:], non_bool_indices[batch_ndims:],
) )
for idx in supp_indices: for idx in supp_indices:
if idx != slice(None): if not (
isinstance(idx.type, SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
):
return False return False
n_discarded_idxs = len(supp_indices) n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs] indices = indices[:-n_discarded_idxs]
...@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -323,7 +331,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim # Broadcasted dim
if curr_dim in bcast_param_dims: if curr_dim in bcast_param_dims:
# Slice indexing, keep degenerate dim by none-slicing # Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx, slice): if isinstance(idx, slice) or isinstance(idx.type, SliceType):
batch_indices.append(slice(None)) batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing # Integer indexing, drop degenerate dim by 0-indexing
else: else:
......
...@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import ( ...@@ -17,6 +17,7 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.traversal import ancestors from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
from pytensor.scalar import ScalarType
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
as_tensor_variable, as_tensor_variable,
...@@ -841,16 +842,13 @@ def _is_shape_i_of_x( ...@@ -841,16 +842,13 @@ def _is_shape_i_of_x(
if isinstance(var.owner.op, Shape_i): if isinstance(var.owner.op, Shape_i):
return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore
# Match Subtensor((int,))(Shape(input), i) - single integer index into shape # Match Subtensor((ScalarType,))(Shape(input), i)
if isinstance(var.owner.op, Subtensor): if isinstance(var.owner.op, Subtensor):
idx_entry = (
var.owner.op.idx_list[0] if len(var.owner.op.idx_list) == 1 else None
)
return ( return (
# Check we have integer indexing operation # Check we have integer indexing operation
# (and not slice or multiple indexing) # (and not slice or multiple indexing)
len(var.owner.op.idx_list) == 1 len(var.owner.op.idx_list) == 1
and isinstance(idx_entry, int) and isinstance(var.owner.op.idx_list[0], ScalarType)
# Check we are indexing on the shape of x # Check we are indexing on the shape of x
and var.owner.inputs[0].owner is not None and var.owner.inputs[0].owner is not None
and isinstance(var.owner.inputs[0].owner.op, Shape) and isinstance(var.owner.inputs[0].owner.op, Shape)
......
...@@ -8,6 +8,7 @@ from pytensor import Variable ...@@ -8,6 +8,7 @@ from pytensor import Variable
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
Join, Join,
...@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -30,7 +31,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize, register_stabilize,
) )
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.subtensor import register_useless from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
from pytensor.tensor.shape import ( from pytensor.tensor.shape import (
Shape, Shape,
SpecifyShape, SpecifyShape,
...@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import ( ...@@ -49,6 +50,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor, indices_from_subtensor,
) )
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index( ...@@ -69,7 +71,7 @@ def _axis_is_indexed_by_basic_index(
) -> bool: ) -> bool:
if isinstance(axis, int): if isinstance(axis, int):
axis = (axis,) axis = (axis,)
return any(ax < len(idxs) and not idxs[ax] == slice(None) for ax in axis) return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
def _lift_subtensor_non_axis( def _lift_subtensor_non_axis(
...@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis( ...@@ -81,7 +83,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable: TensorVariable, old_subtensor_variable: TensorVariable,
) -> None | list[TensorVariable]: ) -> None | list[TensorVariable]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions # Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices = [idx for idx in idx_tuple if not idx == slice(None)] real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
if len(real_indices) > 1 and variable.type.ndim > 1: if len(real_indices) > 1 and variable.type.ndim > 1:
# Split the subtensor # Split the subtensor
idx_to_keep = idx_tuple[axis] idx_to_keep = idx_tuple[axis]
...@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node): ...@@ -204,7 +206,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if len(idx_tuple) > batch_ndim: if len(idx_tuple) > batch_ndim:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only # Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:] batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
if all(idx == slice(None) for idx in batch_indices): if all(is_full_slice(idx) for idx in batch_indices):
# No batch indices, nothing to do # No batch indices, nothing to do
return None return None
elem_with_batch_indices = elem[batch_indices] elem_with_batch_indices = elem[batch_indices]
...@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node): ...@@ -238,7 +240,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict=False, strict=False,
) )
): ):
if dim_idx == slice(None): if is_full_slice(dim_idx):
# Full slice can be safely applied to all inputs # Full slice can be safely applied to all inputs
continue continue
...@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node): ...@@ -427,7 +429,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if i in expanded_axes: if i in expanded_axes:
if isinstance(idx_item, slice): if isinstance(idx_item, slice):
# Slice could be keeping or dropping this dimension # Slice could be keeping or dropping this dimension
if idx_item == slice(None): if is_full_slice(idx_item):
# A None slice, always keeps the dimension. # A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim # We skip the index, and later introduce the needed expand_dim
continue continue
...@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node): ...@@ -646,7 +648,10 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices = get_idx_list(node.inputs, node.op.idx_list) indices = get_idx_list(node.inputs, node.op.idx_list)
if any(isinstance(index, slice) for index in indices): if any(
isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
for index in indices
):
return False return False
new_obj_arg = obj_arg[indices] new_obj_arg = obj_arg[indices]
...@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -697,12 +702,15 @@ def local_subtensor_make_vector(fgraph, node):
(idx,) = idxs (idx,) = idxs
if isinstance(idx, int): if isinstance(idx, ps.ScalarType | TensorType):
idx = node.inputs[1] old_idx, idx = idx, node.inputs[1]
assert idx.type.is_super(old_idx)
elif isinstance(node.op, AdvancedSubtensor1): elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1] idx = node.inputs[1]
if isinstance(idx, Variable): if isinstance(idx, int | np.integer):
return [x.owner.inputs[idx]]
elif isinstance(idx, Variable):
if idx.ndim == 0: if idx.ndim == 0:
try: try:
v = get_underlying_scalar_constant_value( v = get_underlying_scalar_constant_value(
...@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node): ...@@ -825,6 +833,8 @@ def local_subtensor_shape_constant(fgraph, node):
except NotScalarConstantError: except NotScalarConstantError:
return False return False
assert idx_val != np.newaxis
if not isinstance(shape_arg.type, TensorType): if not isinstance(shape_arg.type, TensorType):
return False return False
...@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node): ...@@ -861,24 +871,22 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return None return None
x, *adv_index_vars = adv_subtensor.owner.inputs x, *adv_idxs = adv_subtensor.owner.inputs
adv_idxs = indices_from_subtensor(adv_index_vars, adv_subtensor.owner.op.idx_list)
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices # Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if ( if any(
not all(
( (
(isinstance(adv_idx, TensorVariable) and adv_idx.type.dtype != "bool") isinstance(adv_idx.type, NoneTypeT)
or (isinstance(adv_idx, slice) and adv_idx == slice(None)) or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
) )
for adv_idx in adv_idxs for adv_idx in adv_idxs
)
) or _non_consecutive_adv_indexing(adv_idxs): ) or _non_consecutive_adv_indexing(adv_idxs):
return None return None
for first_adv_idx_dim, adv_idx in enumerate(adv_idxs): for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
# We already made sure there were only None slices besides integer indexes # We already made sure there were only None slices besides integer indexes
if isinstance(adv_idx, TensorVariable): if isinstance(adv_idx.type, TensorType):
break break
else: # no-break else: # no-break
# Not sure if this should ever happen, but better safe than sorry # Not sure if this should ever happen, but better safe than sorry
...@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node): ...@@ -901,7 +909,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed) copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
x_after_index_lift = expand_dims(x_indexed, dropped_dims) x_after_index_lift = expand_dims(x_indexed, dropped_dims)
x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_index_vars) x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx) copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims) new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
......
...@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle ...@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Min, neg from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor from pytensor.tensor.subtensor import Subtensor
@register_uncanonicalize @register_uncanonicalize
...@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node): ...@@ -193,42 +193,60 @@ def local_dimshuffle_subtensor(fgraph, node):
if not all(broadcastable[i] for i in missing_dims): if not all(broadcastable[i] for i in missing_dims):
return False return False
# create a new index tuple for a new Subtensor # create a new idx_list for a new Subtensor object
# Reconstruct the full indices from the subtensor node, then replace # have to loop on idx_list and inputs
# dimensions that are being dropped by dimshuffle with scalar index 0 # inputs has the length of sum of non None elements of idx_list
x = input_.owner.inputs[0] # (check in slice!).
indices = list( # len(missing_dims) can be < len(idx_list), this happens if
indices_from_subtensor( # tensor was indexed such as x[scalar, :, :], check that as well
input_.owner.inputs[1:], input_.owner.op.idx_list new_idx_list = list(input_.owner.op.idx_list)
) new_inputs = [input_.owner.inputs[0]]
)
zero = constant(0) zero = constant(0)
j = 0
# Track which output dimension each index corresponds to slice_i = -1
# Scalar indices remove dimensions, slices keep them subtensor_removed_dims = 0
output_dim = 0 for i, idx in enumerate(input_.owner.op.idx_list):
for i, idx in enumerate(indices):
if isinstance(idx, slice): if isinstance(idx, slice):
# This slice produces an output dimension slice_i += 1
if output_dim in missing_dims: if slice_i in missing_dims:
# This output dimension is being dropped, so replace slice with scalar # Missing dim is a slice(None), remove by indexing by 0
if idx == slice(None): if idx == slice(None):
indices[i] = zero new_idx_list[i] = zero
new_inputs += [zero]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
else: else:
# Use the start of the slice (or 0 if None) if idx.start is None:
indices[i] = idx.start if idx.start is not None else zero start = zero
output_dim += 1
# Scalar indices don't contribute to output dimensions
# Handle trailing dimensions that weren't explicitly indexed
for input_dim in range(len(indices), x.ndim):
if output_dim in missing_dims:
# This unindexed dimension is being dropped, index with 0
indices.append(zero)
else: else:
# This unindexed dimension is kept, index with slice(None) start = input_.owner.inputs[1 + j]
indices.append(slice(None)) j += 1
output_dim += 1 new_idx_list[i] = start
new_inputs += [start]
return [x[tuple(indices)]] # Ignore useless stop and step input if there is one
for slice_attr in ("stop", "step"):
if getattr(idx, slice_attr) is not None:
j += 1
# Keep non-dropped slice inputs
else:
for slice_attr in ("start", "stop", "step"):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# Keep non-dropped non-slice inputs
else:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
subtensor_removed_dims += 1
# Verify the trailing dimensions the subtensor didn't look at.
for idx in range(len(input_.owner.op.idx_list), new_inputs[0].ndim):
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)]
return False return False
...@@ -15,8 +15,9 @@ from pytensor.scalar import ( ...@@ -15,8 +15,9 @@ from pytensor.scalar import (
ComplexError, ComplexError,
) )
from pytensor.tensor import _get_vector_length from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import hash_from_ndarray from pytensor.tensor.utils import hash_from_ndarray
...@@ -454,14 +455,15 @@ class _tensor_py_operators: ...@@ -454,14 +455,15 @@ class _tensor_py_operators:
elif not isinstance(args, tuple): elif not isinstance(args, tuple):
args = (args,) args = (args,)
# Count the dimensions, check for bools and find ellipses.
ellipses = [] ellipses = []
index_dim_count = 0 index_dim_count = 0
for i, arg in enumerate(args): for i, arg in enumerate(args):
if arg is None or ( if arg is np.newaxis or arg is NoneConst:
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT) # no increase in index_dim_count
):
pass pass
elif arg is Ellipsis: elif arg is Ellipsis:
# no increase in index_dim_count
ellipses.append(i) ellipses.append(i)
elif ( elif (
isinstance(arg, np.ndarray | Variable) isinstance(arg, np.ndarray | Variable)
...@@ -503,41 +505,6 @@ class _tensor_py_operators: ...@@ -503,41 +505,6 @@ class _tensor_py_operators:
self.ndim - index_dim_count self.ndim - index_dim_count
) )
if any(
arg is None
or (isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT))
for arg in args
):
expansion_axes = []
new_args = []
# Track dims consumed by args and inserted `None`s after ellipsis
counter = 0
nones = 0
for arg in args:
if arg is None or (
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
expansion_axes.append(counter + nones) # Expand here
nones += 1
new_args.append(slice(None))
else:
new_args.append(arg)
consumed = 1
if hasattr(arg, "dtype") and arg.dtype == "bool":
consumed = arg.ndim
counter += consumed
expanded = pt.expand_dims(self, expansion_axes)
if all(
isinstance(arg, slice)
and arg.start is None
and arg.stop is None
and arg.step is None
for arg in new_args
):
return expanded
return expanded[tuple(new_args)]
def is_empty_array(val): def is_empty_array(val):
return (isinstance(val, tuple | list) and len(val) == 0) or ( return (isinstance(val, tuple | list) and len(val) == 0) or (
isinstance(val, np.ndarray) and val.size == 0 isinstance(val, np.ndarray) and val.size == 0
...@@ -553,16 +520,74 @@ class _tensor_py_operators: ...@@ -553,16 +520,74 @@ class _tensor_py_operators:
for inp in args for inp in args
) )
if all( # Determine if advanced indexing is needed or not. The logic is
( # already in `index_vars_to_types`: if it succeeds, standard indexing is
isinstance(arg, slice | int | float | np.number) # used; if it fails with `AdvancedIndexingError`, advanced indexing is
or (hasattr(arg, "ndim") and arg.ndim == 0 and arg.dtype != "bool") # used
) advanced = False
for arg in args for i, arg in enumerate(args):
): if includes_bool(arg):
return pt.subtensor.basic_subtensor(self, *args) advanced = True
break
if arg is not np.newaxis and arg is not NoneConst:
try:
pt.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError:
if advanced:
break
else: else:
advanced = True
if advanced:
return pt.subtensor.advanced_subtensor(self, *args) return pt.subtensor.advanced_subtensor(self, *args)
else:
if np.newaxis in args or NoneConst in args:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter = 0
pattern = []
new_args = []
for arg in args:
if arg is np.newaxis or arg is NoneConst:
pattern.append("x")
new_args.append(slice(None, None, None))
else:
pattern.append(counter)
counter += 1
new_args.append(arg)
pattern.extend(list(range(counter, self.ndim)))
view = self.dimshuffle(pattern)
full_slices = True
for arg in new_args:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if not (
isinstance(arg, slice)
and (arg.start is None or arg.start is NoneConst)
and (arg.stop is None or arg.stop is NoneConst)
and (arg.step is None or arg.step is NoneConst)
):
full_slices = False
if full_slices:
return view
else:
return view.__getitem__(tuple(new_args))
else:
return pt.subtensor.Subtensor(args)(
self,
*pt.subtensor.get_slice_elements(
args, lambda entry: isinstance(entry, Variable)
),
)
def __setitem__(self, key, value): def __setitem__(self, key, value):
raise TypeError( raise TypeError(
......
...@@ -2,10 +2,9 @@ from itertools import zip_longest ...@@ -2,10 +2,9 @@ from itertools import zip_longest
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import arange, specify_shape from pytensor.tensor import TensorType, arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
from pytensor.tensor.type_other import NoneTypeT, SliceType from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index, IndexUpdate, index from pytensor.xtensor.indexing import Index, IndexUpdate, index
from pytensor.xtensor.rewriting.utils import register_lower_xtensor from pytensor.xtensor.rewriting.utils import register_lower_xtensor
...@@ -107,7 +106,7 @@ def _lower_index(node): ...@@ -107,7 +106,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension # We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor # This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible # and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs.append(to_basic_idx(idx)) aligned_idxs.append(idx)
basic_idx_axis.append(out_dims.index(x_dim)) basic_idx_axis.append(out_dims.index(x_dim))
else: else:
# Otherwise we need to convert the basic index into an equivalent advanced indexing # Otherwise we need to convert the basic index into an equivalent advanced indexing
...@@ -132,7 +131,7 @@ def _lower_index(node): ...@@ -132,7 +131,7 @@ def _lower_index(node):
if basic_idx_axis: if basic_idx_axis:
aligned_idxs = [ aligned_idxs = [
idx.squeeze(axis=basic_idx_axis) idx.squeeze(axis=basic_idx_axis)
if (isinstance(idx, TensorVariable) and idx.type.ndim > 0) if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
else idx else idx
for idx in aligned_idxs for idx in aligned_idxs
] ]
......
...@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern ...@@ -26,7 +26,9 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot, exp from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import ( from tests.graph.utils import (
MyOp, MyOp,
MyType, MyType,
...@@ -627,6 +629,21 @@ def test_pre_constant_merge(): ...@@ -627,6 +629,21 @@ def test_pre_constant_merge():
assert res == [o2] assert res == [o2]
assert o2.owner.inputs[2] is c2 assert o2.owner.inputs[2] is c2
# What is this supposed to test?
ms = MakeSlice()(1)
res = pre_constant_merge(empty_fgraph, [ms])
assert res == [ms]
const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))
assert isinstance(const_slice, Constant)
adv = AdvancedSubtensor()(matrix(), [2, 3], const_slice)
res = pre_constant_merge(empty_fgraph, adv)
assert res == [adv]
def test_pre_greedy_node_rewriter(): def test_pre_greedy_node_rewriter():
empty_fgraph = FunctionGraph([], []) empty_fgraph = FunctionGraph([], [])
...@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter(): ...@@ -662,6 +679,15 @@ def test_pre_greedy_node_rewriter():
assert cst.owner.inputs[0] is o1 assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0] assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test?
ms = MakeSlice()(1)
cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant)
# Make sure constant of slice signature is hashable.
assert isinstance(hash(cst.signature()), int)
@pytest.mark.parametrize("tracks", [True, False]) @pytest.mark.parametrize("tracks", [True, False])
@pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0]) @pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0])
......
...@@ -225,37 +225,6 @@ def test_jax_IncSubtensor(): ...@@ -225,37 +225,6 @@ def test_jax_IncSubtensor():
compare_jax_and_py([], [out_pt], []) compare_jax_and_py([], [out_pt], [])
@pytest.mark.parametrize(
"func", (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1)
)
def test_jax_AdvancedIncSubtensor1_runtime_broadcast(func):
"""Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.
JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
requires explicit broadcastable dimensions. This test ensures we raise the same
error as the Python/C backend when runtime broadcasting would occur.
"""
from pytensor import function
y = pt.matrix("y", dtype="float64", shape=(None, None))
x = pt.zeros((10, 5))
idxs = np.repeat(np.arange(10), 2) # 20 indices
out = func(x, y, idxs)
f = function([y], out, mode="JAX")
# Should work with correctly sized y
f(np.ones((20, 5)))
# Should raise for runtime broadcasting on first dimension
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(np.ones((1, 5)))
# Should raise for runtime broadcasting on second dimension
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(np.ones((20, 1)))
def test_jax_IncSubtensor_boolean_indexing_reexpressible(): def test_jax_IncSubtensor_boolean_indexing_reexpressible():
"""Setting or incrementing values with boolean indexing. """Setting or incrementing values with boolean indexing.
......
...@@ -187,6 +187,27 @@ def test_mlx_inplace_variants(): ...@@ -187,6 +187,27 @@ def test_mlx_inplace_variants():
compare_mlx_and_py([], [out_pt], []) compare_mlx_and_py([], [out_pt], [])
@pytest.mark.xfail(
reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
"""Test MakeSlice operation."""
# Test slice creation
start = pt.iscalar("start")
stop = pt.iscalar("stop")
step = pt.iscalar("step")
# Create a slice using MakeSlice
slice_op = pt_subtensor.MakeSlice()
slice_pt = slice_op(start, stop, step)
# Use simple constant array instead of arange
x_pt = pt.constant(np.arange(10, dtype=np.float32))
out_pt = x_pt[slice_pt]
compare_mlx_and_py([start, stop, step], [out_pt], [1, 8, 2])
def test_mlx_subtensor_edge_cases(): def test_mlx_subtensor_edge_cases():
"""Test edge cases and boundary conditions.""" """Test edge cases and boundary conditions."""
# Empty slices - use constant array # Empty slices - use constant array
......
...@@ -3,7 +3,9 @@ import contextlib ...@@ -3,7 +3,9 @@ import contextlib
import numpy as np import numpy as np
import pytest import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import Mode, as_symbolic
from pytensor.tensor import as_tensor from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
...@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import ( ...@@ -18,16 +20,51 @@ from pytensor.tensor.subtensor import (
inc_subtensor, inc_subtensor,
set_subtensor, set_subtensor,
) )
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
compare_numba_and_py,
numba_inplace_mode,
numba_mode,
)
rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
def test_slice(start, stop, step):
x = ps.int64("x")
sym_slice = as_symbolic(
slice(
x if start == "x" else start,
x if stop == "x" else stop,
x if step == "x" else step,
)
)
no_opt_mode = Mode(linker="numba", optimizer=None)
evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
assert isinstance(evaled_slice, slice)
if start == "x":
assert evaled_slice.start == -5
elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
# Numba can convert to 0 (and sometimes does) in this case
assert evaled_slice.start in (None, 0)
else:
assert evaled_slice.start == start
if stop == "x":
assert evaled_slice.stop == -5
else:
assert evaled_slice.stop == stop
if step == "x":
assert evaled_slice.step == -5
elif step is None:
# Numba can convert to 1 (and sometimes does) in this case
assert evaled_slice.step in (None, 1)
else:
assert evaled_slice.step == step
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
...@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds(): ...@@ -145,11 +182,6 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))), as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]), ([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
), ),
# Newaxis with vector indexing
(
as_tensor(np.arange(4 * 4).reshape((4, 4))),
(None, [0, 1, 2], [0, 1, 2]),
),
], ],
) )
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed @pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
...@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices): ...@@ -415,13 +447,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False, False,
False, False,
), ),
(
np.arange(4 * 4).reshape((4, 4)),
np.array(5), # Broadcasted scalar value
(None, [0, 1, 2], [0, 1, 2]), # Newaxis with vector indexing
False,
False,
),
], ],
) )
@pytest.mark.parametrize("inplace", (False, True)) @pytest.mark.parametrize("inplace", (False, True))
...@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor( ...@@ -435,9 +460,7 @@ def test_AdvancedIncSubtensor(
inplace, inplace,
): ):
# Need rewrite to support certain forms of advanced indexing without object mode # Need rewrite to support certain forms of advanced indexing without object mode
# Use inplace_mode when testing inplace operations to preserve inplace flag mode = numba_mode.including("specialize")
base_mode = numba_inplace_mode if inplace else numba_mode
mode = base_mode.including("specialize")
x_pt = pt.as_tensor(x).type("x") x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y") y_pt = pt.as_tensor(y).type("y")
...@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor( ...@@ -491,3 +514,22 @@ def test_AdvancedIncSubtensor(
x_orig = x.copy() x_orig = x.copy()
fn(x, y) fn(x, y)
assert not np.all(x == x_orig) assert not np.all(x == x_orig)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
# This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
# After which we can add these parametrizations to the relevant tests above
x = pt.matrix("x")
out = x[None, [0, 1, 2], [0, 1, 2]]
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
out = x[None, [0, 1, 2], [0, 1, 2]].inc(5)
with pytest.warns(
UserWarning,
match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
):
compare_numba_and_py([x], [out], [np.random.normal(size=(4, 4))])
...@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug(): ...@@ -1642,15 +1642,9 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10): # with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph(fgraph, include=("inplace",)) rewrite_graph(fgraph, include=("inplace",))
# Save original value to restore later
original_value = pytensor.config.tensor__insert_inplace_optimizer_validate_nb
try:
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1 pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
with pytest.warns( with pytest.warns(
FutureWarning, FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated", match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
): ):
rewrite_graph(fgraph, include=("inplace",)) rewrite_graph(fgraph, include=("inplace",))
finally:
# Restore original value to avoid affecting other tests
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = original_value
...@@ -52,6 +52,7 @@ from pytensor.tensor.type import ( ...@@ -52,6 +52,7 @@ from pytensor.tensor.type import (
tensor4, tensor4,
vector, vector,
) )
from pytensor.tensor.type_other import make_slice
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.unittest_tools import create_pytensor_param from tests.unittest_tools import create_pytensor_param
...@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices(): ...@@ -1700,11 +1701,11 @@ def test_local_uint_constant_indices():
assert isinstance(new_index, Constant) assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
# `AdvancedSubtensor`, two indices, one slice, convert # `AdvancedSubtensor`, two indices, one symbolic slice, convert
x = pt.matrix("x") x = pt.matrix("x")
indices = ( indices = (
pt.as_tensor_variable(np.array([1], np.int64)), pt.as_tensor_variable(np.array(1, np.int64)),
slice(None, 10), make_slice(slice(None, 10)),
) )
z = x[indices] z = x[indices]
...@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices(): ...@@ -1791,7 +1792,7 @@ def test_local_uint_constant_indices():
z_fn = pytensor.function([x], z, mode=mode) z_fn = pytensor.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1)) assert isinstance(subtensor_node.op, AdvancedSubtensor)
new_index = subtensor_node.inputs[1] new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant) assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8" assert new_index.type.dtype == "uint8"
...@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor: ...@@ -1842,6 +1843,7 @@ class TestBlockwiseIncSubtensor:
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out) fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn) assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype) test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype) test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y)) np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
...@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor: ...@@ -1946,7 +1948,15 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"basic_idx", "basic_idx",
[True, False], [
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
ids=["basic_idx", "adv_idx"], ids=["basic_idx", "adv_idx"],
) )
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor: ...@@ -1963,7 +1973,7 @@ class TestBlockwiseIncSubtensor:
core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,)) core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,))
# The empty slice before core_idx, will lead to a transposition of the advanced view # The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with a new arange slice on the batched dimensions. # once it is paired with an new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing # That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out = core_a[0, :, core_idx].set(core_v) core_out = core_a[0, :, core_idx].set(core_v)
......
...@@ -32,6 +32,7 @@ from pytensor.tensor import ( ...@@ -32,6 +32,7 @@ from pytensor.tensor import (
lscalars, lscalars,
matrix, matrix,
shape, shape,
slicetype,
specify_shape, specify_shape,
tensor, tensor,
tensor3, tensor3,
...@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift: ...@@ -556,7 +557,7 @@ class TestLocalSubtensorSpecifyShapeLift:
( (
matrix(), matrix(),
(iscalar(), iscalar()), (iscalar(), iscalar()),
(slice(iscalar(), iscalar(), iscalar()),), (slicetype(),),
), ),
( (
matrix(), matrix(),
...@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant(): ...@@ -788,12 +789,12 @@ def test_local_subtensor_shape_constant():
(lambda x: x[:, [0, 1]][0], True), (lambda x: x[:, [0, 1]][0], True),
(lambda x: x[:, [0, 1], [0, 0]][1:], True), (lambda x: x[:, [0, 1], [0, 0]][1:], True),
(lambda x: x[:, [[0, 1], [0, 0]]][1:], True), (lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
(lambda x: x[:, None, [0, 1]][0], True),
# Not supported, basic indexing on advanced indexing dim # Not supported, basic indexing on advanced indexing dim
(lambda x: x[[0, 1]][0], False), (lambda x: x[[0, 1]][0], False),
# Not supported, basic indexing on the right of advanced indexing # Not implemented, basic indexing on the right of advanced indexing
(lambda x: x[[0, 1]][:, 0], False), (lambda x: x[[0, 1]][:, 0], False),
# Not implemented, complex flavors of advanced indexing # Not implemented, complex flavors of advanced indexing
(lambda x: x[:, None, [0, 1]][0], False),
(lambda x: x[:, 5:, [0, 1]][0], False), (lambda x: x[:, 5:, [0, 1]][0], False),
(lambda x: x[:, :, np.array([True, False, False])][0], False), (lambda x: x[:, :, np.array([True, False, False])][0], False),
(lambda x: x[[0, 1], :, [0, 1]][:, 0], False), (lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
......
...@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import ( ...@@ -31,8 +31,6 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback, vectorize_node_fallback,
) )
from pytensor.tensor.nlinalg import MatrixInverse, eig from pytensor.tensor.nlinalg import MatrixInverse, eig
from pytensor.tensor.random import normal
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import ( from pytensor.tensor.slinalg import (
...@@ -116,18 +114,16 @@ def test_vectorize_blockwise(): ...@@ -116,18 +114,16 @@ def test_vectorize_blockwise():
def test_vectorize_node_fallback_unsupported_type(): def test_vectorize_node_fallback_unsupported_type():
rng = default_rng() x = tensor("x", shape=(2, 6))
node = normal(rng=rng).owner node = x[:, [0, 2, 4]].owner
with pytest.raises( with pytest.raises(
NotImplementedError, NotImplementedError,
match=re.escape( match=re.escape(
'Cannot vectorize node normal_rv{"(),()->()"}(' "Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
"DefaultGeneratorMakerOp.0, NoneConst{None}, 0.0, 1.0)"
" with input DefaultGeneratorMakerOp.0 of type RandomGeneratorType"
), ),
): ):
vectorize_node_fallback(node.op, node, *node.inputs) vectorize_node_fallback(node.op, node, node.inputs)
def check_blockwise_runtime_broadcasting(mode): def check_blockwise_runtime_broadcasting(mode):
......
...@@ -4,8 +4,30 @@ import pytensor ...@@ -4,8 +4,30 @@ import pytensor
from pytensor import as_symbolic from pytensor import as_symbolic
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.tensor.math import argmax from pytensor.tensor.math import argmax
from pytensor.tensor.type import vector from pytensor.tensor.type import iscalar, vector
from pytensor.tensor.type_other import NoneConst, NoneTypeT from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
def test_SliceType():
st = SliceType()
assert st == st.clone()
def test_make_slice_merge():
# In the past, this was crahsing during compilation.
i = iscalar()
s1 = make_slice(0, i)
s2 = make_slice(0, i)
f = pytensor.function([i], [s1, s2])
nodes = f.maker.fgraph.apply_nodes
assert len([n for n in nodes if isinstance(n.op, MakeSlice)]) == 1
def test_none_Constant(): def test_none_Constant():
...@@ -25,6 +47,8 @@ def test_none_Constant(): ...@@ -25,6 +47,8 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past. # This trigger equals that returned the wrong answer in the past.
import pickle import pickle
import pytensor
x = vector("x") x = vector("x")
y = argmax(x) y = argmax(x)
kwargs = {} kwargs = {}
...@@ -36,18 +60,11 @@ def test_none_Constant(): ...@@ -36,18 +60,11 @@ def test_none_Constant():
def test_as_symbolic(): def test_as_symbolic():
# Remove this when xtensor is not using symbolic slices
from pytensor.tensor.type import iscalar
from pytensor.tensor.type_other import SliceConstant, slicetype
res = as_symbolic(None) res = as_symbolic(None)
assert res is NoneConst assert res is NoneConst
res = as_symbolic(slice(iscalar()))
assert res.owner.op == make_slice
res = as_symbolic(slice(1, 2)) res = as_symbolic(slice(1, 2))
assert isinstance(res, SliceConstant) assert isinstance(res, SliceConstant)
assert res.type == slicetype
assert res.data == slice(1, 2)
i = iscalar()
res = as_symbolic(slice(i))
assert res.owner is not None
...@@ -35,7 +35,7 @@ from pytensor.tensor.type import ( ...@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar, scalar,
tensor3, tensor3,
) )
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.tensor.variable import ( from pytensor.tensor.variable import (
DenseTensorConstant, DenseTensorConstant,
DenseTensorVariable, DenseTensorVariable,
...@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor(): ...@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z = x[:, i] z = x[:, i]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [AdvancedSubtensor] assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[..., i, None] z = x[..., i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [DimShuffle, AdvancedSubtensor] assert op_types == [MakeSlice, AdvancedSubtensor]
z = x[i, None] z = x[i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])] op_types = [type(node.op) for node in io_toposort([x, i], [z])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论