提交 db7fa079 authored 作者: Jaan Erik Pihel 提交者: Ricardo Vieira

Refactor AdvancedSubtensor

- newaxis is handled as explicit DimShuffle on the inputs - slices are encoded internally, so the Ops only take numerical inputs Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
上级 87470065
......@@ -771,9 +771,9 @@ class DestroyHandler(Bookkeeper):
}
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", []
app.op, "destroyhandler_tolerate_aliased", ()
)
assert isinstance(tolerate_aliased, list)
assert isinstance(tolerate_aliased, tuple | list)
ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
}
......
......@@ -8,7 +8,6 @@ from pytensor.tensor.subtensor import (
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice
BOOLEAN_MASK_ERROR = """JAX does not support resizing arrays with boolean
......@@ -35,10 +34,8 @@ slice length.
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -48,10 +45,9 @@ def jax_funcify_Subtensor(op, node, **kwargs):
@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False):
def jax_fn(x, indices, y):
......@@ -62,7 +58,7 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -73,29 +69,3 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
return jax_fn(x, indices, y)
return incsubtensor
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, node, **kwargs):
    """Return a JAX implementation of ``AdvancedIncSubtensor``.

    Depending on ``op.set_instead_of_inc``, the produced callable either
    overwrites (``.set``) or accumulates into (``.add``) the advanced-indexed
    positions of ``x``, using JAX's functional index-update syntax.
    """
    overwrite = getattr(op, "set_instead_of_inc", False)

    def jax_fn(x, indices, y):
        updater = x.at[indices]
        return updater.set(y) if overwrite else updater.add(y)

    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
        # The index variables arrive flattened; JAX takes them as a tuple.
        return jax_fn(x, ilist, y)

    return advancedincsubtensor
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
    """Return a JAX implementation of ``MakeSlice``: build a Python ``slice``."""

    def makeslice(*bounds):
        # ``slice`` itself handles the (stop) / (start, stop) /
        # (start, stop, step) arities.
        return slice(*bounds)

    return makeslice
......@@ -10,15 +10,14 @@ from pytensor.tensor.subtensor import (
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice
@mlx_funcify.register(Subtensor)
def mlx_funcify_Subtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists):
indices = indices_from_subtensor([int(element) for element in ilists], idx_list)
indices = indices_from_subtensor(
[int(element) for element in ilists], op.idx_list
)
if len(indices) == 1:
indices = indices[0]
......@@ -30,10 +29,8 @@ def mlx_funcify_Subtensor(op, node, **kwargs):
@mlx_funcify.register(AdvancedSubtensor)
@mlx_funcify.register(AdvancedSubtensor1)
def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
def advanced_subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
indices = indices_from_subtensor(ilists, op.idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -45,8 +42,6 @@ def mlx_funcify_AdvancedSubtensor(op, node, **kwargs):
@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False):
def mlx_fn(x, indices, y):
......@@ -63,7 +58,7 @@ def mlx_funcify_IncSubtensor(op, node, **kwargs):
x[indices] += y
return x
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=idx_list):
def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
......@@ -95,11 +90,3 @@ def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return mlx_fn(x, ilist, y)
return advancedincsubtensor
@mlx_funcify.register(MakeSlice)
def mlx_funcify_MakeSlice(op, **kwargs):
    """Return an MLX implementation of ``MakeSlice``: build a Python ``slice``."""

    def makeslice(*bounds):
        # Delegate directly to the builtin ``slice`` constructor.
        return slice(*bounds)

    return makeslice
......@@ -9,7 +9,6 @@ from pytensor.tensor.subtensor import (
Subtensor,
indices_from_subtensor,
)
from pytensor.tensor.type_other import MakeSlice, SliceType
def check_negative_steps(indices):
......@@ -47,23 +46,11 @@ def pytorch_funcify_Subtensor(op, node, **kwargs):
return subtensor
@pytorch_funcify.register(MakeSlice)
def pytorch_funcify_makeslice(op, **kwargs):
    """Return a PyTorch implementation of ``MakeSlice``."""

    def makeslice(start, stop, step):
        # Torch does not like numpy integers in indexing slices,
        # so coerce every non-None bound to a plain Python int.
        bounds = (
            None if bound is None else int(bound) for bound in (start, stop, step)
        )
        return slice(*bounds)

    return makeslice
@pytorch_funcify.register(AdvancedSubtensor1)
@pytorch_funcify.register(AdvancedSubtensor)
def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
def advsubtensor(x, *indices):
indices = indices_from_subtensor(indices, op.idx_list)
check_negative_steps(indices)
return x[indices]
......@@ -102,12 +89,14 @@ def pytorch_funcify_IncSubtensor(op, node, **kwargs):
@pytorch_funcify.register(AdvancedIncSubtensor)
@pytorch_funcify.register(AdvancedIncSubtensor1)
def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
idx_list = op.idx_list
inplace = op.inplace
ignore_duplicates = getattr(op, "ignore_duplicates", False)
if op.set_instead_of_inc:
def adv_set_subtensor(x, y, *indices):
def adv_set_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
......@@ -120,7 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
elif ignore_duplicates:
def adv_inc_subtensor_no_duplicates(x, y, *indices):
def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
......@@ -132,13 +122,14 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
return adv_inc_subtensor_no_duplicates
else:
if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
if any(isinstance(entry, slice) for entry in idx_list):
raise NotImplementedError(
"IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch"
)
def adv_inc_subtensor(x, y, *indices):
# Not needed because slices aren't supported
def adv_inc_subtensor(x, y, *flattened_indices):
indices = indices_from_subtensor(flattened_indices, idx_list)
# Not needed because slices aren't supported in this path
# check_negative_steps(indices)
if not inplace:
x = x.clone()
......
......@@ -72,9 +72,9 @@ from pytensor.tensor.shape import shape
from pytensor.tensor.subtensor import (
IncSubtensor,
Subtensor,
basic_subtensor,
get_canonical_form_slice,
get_idx_list,
get_slice_elements,
set_subtensor,
)
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -1211,7 +1211,7 @@ def _is_default_scan_buffer(final_buffer: TensorVariable, taps: int) -> bool:
if not (
isinstance(op, IncSubtensor)
and op.set_instead_of_inc
and op.idx_list == [slice(None, ps.int64)]
and op.idx_list == (slice(None, 0),)
):
return False
......@@ -1389,12 +1389,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
else:
# 2.3.1 extract idx list of subtensor
this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
# if unable to extract idx_list
# => outputs needs all its intermediate values
global_nsteps = None
slices[i] = None
break
# 2.3.2 extract the begin/end of the first dimension
if i >= op_info.n_mit_mot:
......@@ -1487,9 +1481,6 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
break
else:
this_slice = get_idx_list(cl.inputs, cl.op.idx_list)
if this_slice is None:
store_steps[i] = 0
break
if isinstance(this_slice[0], slice):
start = this_slice[0].start
......@@ -1711,16 +1702,9 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
else:
fslice = sanitize(cnf_slice[0])
nw_slice = (fslice, *old_slices[1:])
nw_pos = inv_compress_map[idx]
subtens = Subtensor(nw_slice)
# slice inputs
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
new_o = basic_subtensor(new_outs[nw_pos], fslice, *old_slices[1:])
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
replaced_outs.append(idx)
......@@ -1771,11 +1755,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
)
nw_slice = (sanitize(position), *old_slices[1:])
subtens = Subtensor(nw_slice)
sl_ins = get_slice_elements(
nw_slice, lambda entry: isinstance(entry, Variable)
)
new_o = cast(TensorVariable, subtens(new_outs[nw_pos], *sl_ins))
new_o = basic_subtensor(new_outs[nw_pos], *nw_slice)
if new_o.ndim > 0:
new_o = new_o[:: cnf_slice[1]]
old_new += [(old, new_o)]
......
......@@ -29,7 +29,7 @@ from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.rewriting.db import EquilibriumDB
from pytensor.graph.type import HasShape, Type
from pytensor.graph.type import HasShape
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
......@@ -433,7 +433,7 @@ def _get_underlying_scalar_constant_value(
var.ndim == 1 for var in v.owner.inputs[0].owner.inputs[1:]
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
......@@ -467,7 +467,7 @@ def _get_underlying_scalar_constant_value(
and len(v.owner.op.idx_list) == 1
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
......@@ -488,7 +488,7 @@ def _get_underlying_scalar_constant_value(
op = owner.op
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, Type):
if isinstance(idx, int):
idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
......
......@@ -23,7 +23,7 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor,
)
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.type_other import NoneTypeT
def is_rv_used_in_graph(base_rv, node, fgraph):
......@@ -237,20 +237,15 @@ def local_subtensor_rv_lift(fgraph, node):
return False
# Parse indices
if isinstance(subtensor_op, Subtensor):
if isinstance(subtensor_op, Subtensor | AdvancedSubtensor):
indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list)
else:
indices = node.inputs[1:]
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
# If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
if any(
is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT)
for idx in indices
):
return False
# The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
# TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem
# (e.g., x[[0],] is equivalent to x[0] - can only index one entry, won't lead to duplicates)
if any(is_nd_advanced_idx(idx, integer_dtypes) for idx in indices):
return False
# Check that indexing does not act on support dims
batch_ndims = rv_op.batch_ndim(rv_node)
......@@ -268,10 +263,7 @@ def local_subtensor_rv_lift(fgraph, node):
non_bool_indices[batch_ndims:],
)
for idx in supp_indices:
if not (
isinstance(idx.type, SliceType)
and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs)
):
if idx != slice(None):
return False
n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs]
......@@ -331,7 +323,7 @@ def local_subtensor_rv_lift(fgraph, node):
# Broadcasted dim
if curr_dim in bcast_param_dims:
# Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx, slice) or isinstance(idx.type, SliceType):
if isinstance(idx, slice):
batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing
else:
......
......@@ -17,7 +17,6 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.traversal import ancestors
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
from pytensor.scalar import ScalarType
from pytensor.tensor.basic import (
MakeVector,
as_tensor_variable,
......@@ -842,13 +841,16 @@ def _is_shape_i_of_x(
if isinstance(var.owner.op, Shape_i):
return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore
# Match Subtensor((ScalarType,))(Shape(input), i)
# Match Subtensor((int,))(Shape(input), i) - single integer index into shape
if isinstance(var.owner.op, Subtensor):
idx_entry = (
var.owner.op.idx_list[0] if len(var.owner.op.idx_list) == 1 else None
)
return (
# Check we have integer indexing operation
# (and not slice or multiple indexing)
len(var.owner.op.idx_list) == 1
and isinstance(var.owner.op.idx_list[0], ScalarType)
and isinstance(idx_entry, int)
# Check we are indexing on the shape of x
and var.owner.inputs[0].owner is not None
and isinstance(var.owner.inputs[0].owner.op, Shape)
......
......@@ -8,7 +8,6 @@ from pytensor import Variable
from pytensor.compile import optdb
from pytensor.graph import Constant, FunctionGraph, node_rewriter, vectorize_graph
from pytensor.graph.rewriting.basic import NodeRewriter, copy_stack_trace
from pytensor.scalar import basic as ps
from pytensor.tensor.basic import (
Alloc,
Join,
......@@ -31,7 +30,7 @@ from pytensor.tensor.rewriting.basic import (
register_stabilize,
)
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
from pytensor.tensor.rewriting.subtensor import register_useless
from pytensor.tensor.shape import (
Shape,
SpecifyShape,
......@@ -50,7 +49,6 @@ from pytensor.tensor.subtensor import (
indices_from_subtensor,
)
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable
......@@ -71,7 +69,7 @@ def _axis_is_indexed_by_basic_index(
) -> bool:
if isinstance(axis, int):
axis = (axis,)
return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
return any(ax < len(idxs) and not idxs[ax] == slice(None) for ax in axis)
def _lift_subtensor_non_axis(
......@@ -83,7 +81,7 @@ def _lift_subtensor_non_axis(
old_subtensor_variable: TensorVariable,
) -> None | list[TensorVariable]:
# Apply generic subtensor lift rewrite along "non-axis" dimensions
real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
real_indices = [idx for idx in idx_tuple if not idx == slice(None)]
if len(real_indices) > 1 and variable.type.ndim > 1:
# Split the subtensor
idx_to_keep = idx_tuple[axis]
......@@ -206,7 +204,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
if len(idx_tuple) > batch_ndim:
# Indexing on core dimensions of Blockwise. We split the indices and lift the batch ones only
batch_indices, core_indices = idx_tuple[:batch_ndim], idx_tuple[batch_ndim:]
if all(is_full_slice(idx) for idx in batch_indices):
if all(idx == slice(None) for idx in batch_indices):
# No batch indices, nothing to do
return None
elem_with_batch_indices = elem[batch_indices]
......@@ -240,7 +238,7 @@ def local_subtensor_of_batch_dims(fgraph, node):
strict=False,
)
):
if is_full_slice(dim_idx):
if dim_idx == slice(None):
# Full slice can be safely applied to all inputs
continue
......@@ -429,7 +427,7 @@ def local_subtensor_of_expand_dims(fgraph, node):
if i in expanded_axes:
if isinstance(idx_item, slice):
# Slice could be keeping or dropping this dimension
if is_full_slice(idx_item):
if idx_item == slice(None):
# A None slice, always keeps the dimension.
# We skip the index, and later introduce the needed expand_dim
continue
......@@ -648,10 +646,7 @@ def local_subtensor_SpecifyShape_lift(fgraph, node):
indices = get_idx_list(node.inputs, node.op.idx_list)
if any(
isinstance(index, slice) or isinstance(getattr(index, "type", None), SliceType)
for index in indices
):
if any(isinstance(index, slice) for index in indices):
return False
new_obj_arg = obj_arg[indices]
......@@ -702,15 +697,12 @@ def local_subtensor_make_vector(fgraph, node):
(idx,) = idxs
if isinstance(idx, ps.ScalarType | TensorType):
old_idx, idx = idx, node.inputs[1]
assert idx.type.is_super(old_idx)
if isinstance(idx, int):
idx = node.inputs[1]
elif isinstance(node.op, AdvancedSubtensor1):
idx = node.inputs[1]
if isinstance(idx, int | np.integer):
return [x.owner.inputs[idx]]
elif isinstance(idx, Variable):
if isinstance(idx, Variable):
if idx.ndim == 0:
try:
v = get_underlying_scalar_constant_value(
......@@ -833,8 +825,6 @@ def local_subtensor_shape_constant(fgraph, node):
except NotScalarConstantError:
return False
assert idx_val != np.newaxis
if not isinstance(shape_arg.type, TensorType):
return False
......@@ -871,22 +861,24 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
# AdvancedSubtensor involves a full_copy, so we don't want to do it twice
return None
x, *adv_idxs = adv_subtensor.owner.inputs
x, *adv_index_vars = adv_subtensor.owner.inputs
adv_idxs = indices_from_subtensor(adv_index_vars, adv_subtensor.owner.op.idx_list)
# Advanced indexing is a minefield, avoid all cases except for consecutive integer indices
if any(
(
isinstance(adv_idx.type, NoneTypeT)
or (isinstance(adv_idx.type, TensorType) and adv_idx.type.dtype == "bool")
or (isinstance(adv_idx.type, SliceType) and not is_full_slice(adv_idx))
if (
not all(
(
(isinstance(adv_idx, TensorVariable) and adv_idx.type.dtype != "bool")
or (isinstance(adv_idx, slice) and adv_idx == slice(None))
)
for adv_idx in adv_idxs
)
for adv_idx in adv_idxs
) or _non_consecutive_adv_indexing(adv_idxs):
return None
for first_adv_idx_dim, adv_idx in enumerate(adv_idxs):
# We already made sure there were only None slices besides integer indexes
if isinstance(adv_idx.type, TensorType):
if isinstance(adv_idx, TensorVariable):
break
else: # no-break
# Not sure if this should ever happen, but better safe than sorry
......@@ -909,7 +901,7 @@ def local_subtensor_of_adv_subtensor(fgraph, node):
copy_stack_trace([basic_subtensor, adv_subtensor], x_indexed)
x_after_index_lift = expand_dims(x_indexed, dropped_dims)
x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_idxs)
x_after_adv_idx = adv_subtensor.owner.op(x_after_index_lift, *adv_index_vars)
copy_stack_trace([basic_subtensor, adv_subtensor], x_after_adv_idx)
new_out = squeeze(x_after_adv_idx[basic_idxs_kept], dropped_dims)
......
......@@ -38,7 +38,7 @@ from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import Min, neg
from pytensor.tensor.rewriting.basic import register_uncanonicalize
from pytensor.tensor.shape import Reshape, reshape
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
@register_uncanonicalize
......@@ -193,60 +193,42 @@ def local_dimshuffle_subtensor(fgraph, node):
if not all(broadcastable[i] for i in missing_dims):
return False
# create a new idx_list for a new Subtensor object
# have to loop on idx_list and inputs
# inputs has the length of sum of non None elements of idx_list
# (check in slice!).
# len(missing_dims) can be < len(idx_list), this happens if
# tensor was indexed such as x[scalar, :, :], check that as well
new_idx_list = list(input_.owner.op.idx_list)
new_inputs = [input_.owner.inputs[0]]
# create a new index tuple for a new Subtensor
# Reconstruct the full indices from the subtensor node, then replace
# dimensions that are being dropped by dimshuffle with scalar index 0
x = input_.owner.inputs[0]
indices = list(
indices_from_subtensor(
input_.owner.inputs[1:], input_.owner.op.idx_list
)
)
zero = constant(0)
j = 0
slice_i = -1
subtensor_removed_dims = 0
for i, idx in enumerate(input_.owner.op.idx_list):
# Track which output dimension each index corresponds to
# Scalar indices remove dimensions, slices keep them
output_dim = 0
for i, idx in enumerate(indices):
if isinstance(idx, slice):
slice_i += 1
if slice_i in missing_dims:
# Missing dim is a slice(None), remove by indexing by 0
# This slice produces an output dimension
if output_dim in missing_dims:
# This output dimension is being dropped, so replace slice with scalar
if idx == slice(None):
new_idx_list[i] = zero
new_inputs += [zero]
# Missing dim is an ordinary slice with known output dim length of 1
# Remove by indexing by start
indices[i] = zero
else:
if idx.start is None:
start = zero
else:
start = input_.owner.inputs[1 + j]
j += 1
new_idx_list[i] = start
new_inputs += [start]
# Ignore useless stop and step input if there is one
for slice_attr in ("stop", "step"):
if getattr(idx, slice_attr) is not None:
j += 1
# Keep non-dropped slice inputs
else:
for slice_attr in ("start", "stop", "step"):
if getattr(idx, slice_attr) is not None:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
# Keep non-dropped non-slice inputs
# Use the start of the slice (or 0 if None)
indices[i] = idx.start if idx.start is not None else zero
output_dim += 1
# Scalar indices don't contribute to output dimensions
# Handle trailing dimensions that weren't explicitly indexed
for input_dim in range(len(indices), x.ndim):
if output_dim in missing_dims:
# This unindexed dimension is being dropped, index with 0
indices.append(zero)
else:
new_inputs += [input_.owner.inputs[1 + j]]
j += 1
subtensor_removed_dims += 1
# Verify the trailing dimensions the subtensor didn't look at.
for idx in range(len(input_.owner.op.idx_list), new_inputs[0].ndim):
if (idx - subtensor_removed_dims) in missing_dims:
while len(new_idx_list) < idx:
new_idx_list.append(slice(None))
new_idx_list.append(zero)
new_inputs.append(zero)
return [Subtensor(new_idx_list)(*new_inputs)]
# This unindexed dimension is kept, index with slice(None)
indices.append(slice(None))
output_dim += 1
return [x[tuple(indices)]]
return False
......@@ -15,9 +15,8 @@ from pytensor.scalar import (
ComplexError,
)
from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import hash_from_ndarray
......@@ -455,15 +454,14 @@ class _tensor_py_operators:
elif not isinstance(args, tuple):
args = (args,)
# Count the dimensions, check for bools and find ellipses.
ellipses = []
index_dim_count = 0
for i, arg in enumerate(args):
if arg is np.newaxis or arg is NoneConst:
# no increase in index_dim_count
if arg is None or (
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
pass
elif arg is Ellipsis:
# no increase in index_dim_count
ellipses.append(i)
elif (
isinstance(arg, np.ndarray | Variable)
......@@ -505,6 +503,41 @@ class _tensor_py_operators:
self.ndim - index_dim_count
)
if any(
arg is None
or (isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT))
for arg in args
):
expansion_axes = []
new_args = []
# Track dims consumed by args and inserted `None`s after ellipsis
counter = 0
nones = 0
for arg in args:
if arg is None or (
isinstance(arg, Variable) and isinstance(arg.type, NoneTypeT)
):
expansion_axes.append(counter + nones) # Expand here
nones += 1
new_args.append(slice(None))
else:
new_args.append(arg)
consumed = 1
if hasattr(arg, "dtype") and arg.dtype == "bool":
consumed = arg.ndim
counter += consumed
expanded = pt.expand_dims(self, expansion_axes)
if all(
isinstance(arg, slice)
and arg.start is None
and arg.stop is None
and arg.step is None
for arg in new_args
):
return expanded
return expanded[tuple(new_args)]
def is_empty_array(val):
return (isinstance(val, tuple | list) and len(val) == 0) or (
isinstance(val, np.ndarray) and val.size == 0
......@@ -520,74 +553,16 @@ class _tensor_py_operators:
for inp in args
)
# Determine if advanced indexing is needed or not. The logic is
# already in `index_vars_to_types`: if it succeeds, standard indexing is
# used; if it fails with `AdvancedIndexingError`, advanced indexing is
# used
advanced = False
for i, arg in enumerate(args):
if includes_bool(arg):
advanced = True
break
if arg is not np.newaxis and arg is not NoneConst:
try:
pt.subtensor.index_vars_to_types(arg)
except AdvancedIndexingError:
if advanced:
break
else:
advanced = True
if advanced:
return pt.subtensor.advanced_subtensor(self, *args)
if all(
(
isinstance(arg, slice | int | float | np.number)
or (hasattr(arg, "ndim") and arg.ndim == 0 and arg.dtype != "bool")
)
for arg in args
):
return pt.subtensor.basic_subtensor(self, *args)
else:
if np.newaxis in args or NoneConst in args:
# `np.newaxis` (i.e. `None`) in NumPy indexing mean "add a new
# broadcastable dimension at this location". Since PyTensor adds
# new broadcastable dimensions via the `DimShuffle` `Op`, the
# following code uses said `Op` to add one of the new axes and
# then uses recursion to apply any other indices and add any
# remaining new axes.
counter = 0
pattern = []
new_args = []
for arg in args:
if arg is np.newaxis or arg is NoneConst:
pattern.append("x")
new_args.append(slice(None, None, None))
else:
pattern.append(counter)
counter += 1
new_args.append(arg)
pattern.extend(list(range(counter, self.ndim)))
view = self.dimshuffle(pattern)
full_slices = True
for arg in new_args:
# We can't do arg == slice(None, None, None) as in
# Python 2.7, this call __lt__ if we have a slice
# with some symbolic variable.
if not (
isinstance(arg, slice)
and (arg.start is None or arg.start is NoneConst)
and (arg.stop is None or arg.stop is NoneConst)
and (arg.step is None or arg.step is NoneConst)
):
full_slices = False
if full_slices:
return view
else:
return view.__getitem__(tuple(new_args))
else:
return pt.subtensor.Subtensor(args)(
self,
*pt.subtensor.get_slice_elements(
args, lambda entry: isinstance(entry, Variable)
),
)
return pt.subtensor.advanced_subtensor(self, *args)
def __setitem__(self, key, value):
raise TypeError(
......
......@@ -2,9 +2,10 @@ from itertools import zip_longest
from pytensor import as_symbolic
from pytensor.graph import Constant, node_rewriter
from pytensor.tensor import TensorType, arange, specify_shape
from pytensor.tensor import arange, specify_shape
from pytensor.tensor.subtensor import _non_consecutive_adv_indexing, inc_subtensor
from pytensor.tensor.type_other import NoneTypeT, SliceType
from pytensor.tensor.variable import TensorVariable
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.indexing import Index, IndexUpdate, index
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
......@@ -106,7 +107,7 @@ def _lower_index(node):
# We can use basic indexing directly if no other index acts on this dimension
# This is an optimization that avoids creating an unnecessary arange tensor
# and facilitates the use of the specialized AdvancedSubtensor1 when possible
aligned_idxs.append(idx)
aligned_idxs.append(to_basic_idx(idx))
basic_idx_axis.append(out_dims.index(x_dim))
else:
# Otherwise we need to convert the basic index into an equivalent advanced indexing
......@@ -131,7 +132,7 @@ def _lower_index(node):
if basic_idx_axis:
aligned_idxs = [
idx.squeeze(axis=basic_idx_axis)
if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
if (isinstance(idx, TensorVariable) and idx.type.ndim > 0)
else idx
for idx in aligned_idxs
]
......
......@@ -26,9 +26,7 @@ from pytensor.graph.rewriting.unify import LiteralString, OpPattern
from pytensor.raise_op import assert_op
from pytensor.tensor.math import Dot, add, dot, exp
from pytensor.tensor.rewriting.basic import constant_folding
from pytensor.tensor.subtensor import AdvancedSubtensor
from pytensor.tensor.type import matrix, values_eq_approx_always_true, vector
from pytensor.tensor.type_other import MakeSlice, SliceConstant, slicetype
from tests.graph.utils import (
MyOp,
MyType,
......@@ -629,21 +627,6 @@ def test_pre_constant_merge():
assert res == [o2]
assert o2.owner.inputs[2] is c2
# What is this supposed to test?
ms = MakeSlice()(1)
res = pre_constant_merge(empty_fgraph, [ms])
assert res == [ms]
const_slice = SliceConstant(type=slicetype, data=slice(1, None, 2))
assert isinstance(const_slice, Constant)
adv = AdvancedSubtensor()(matrix(), [2, 3], const_slice)
res = pre_constant_merge(empty_fgraph, adv)
assert res == [adv]
def test_pre_greedy_node_rewriter():
empty_fgraph = FunctionGraph([], [])
......@@ -679,15 +662,6 @@ def test_pre_greedy_node_rewriter():
assert cst.owner.inputs[0] is o1
assert cst.owner.inputs[4] is cst.owner.inputs[0]
# What exactly is this supposed to test?
ms = MakeSlice()(1)
cst = pre_greedy_node_rewriter(empty_fgraph, [constant_folding], ms)
assert isinstance(cst, SliceConstant)
# Make sure constant of slice signature is hashable.
assert isinstance(hash(cst.signature()), int)
@pytest.mark.parametrize("tracks", [True, False])
@pytest.mark.parametrize("out_pattern", [(op2, "x"), "x", 1.0])
......
......@@ -225,6 +225,37 @@ def test_jax_IncSubtensor():
compare_jax_and_py([], [out_pt], [])
@pytest.mark.parametrize(
    "func", (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1)
)
def test_jax_AdvancedIncSubtensor1_runtime_broadcast(func):
    """Test that JAX backend checks for runtime broadcasting in AdvancedIncSubtensor1.

    JAX silently broadcasts when using .at[].set() or .at[].add(), but PyTensor
    requires explicit broadcastable dimensions. This test ensures we raise the same
    error as the Python/C backend when runtime broadcasting would occur.
    """
    from pytensor import function

    y = pt.matrix("y", dtype="float64", shape=(None, None))
    repeated_idxs = np.repeat(np.arange(10), 2)  # 20 indices
    out = func(pt.zeros((10, 5)), y, repeated_idxs)
    fn = function([y], out, mode="JAX")

    # A correctly shaped update value is accepted.
    fn(np.ones((20, 5)))

    # Broadcasting along either dimension at runtime must be rejected,
    # matching the Python/C backend behavior.
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        fn(np.ones((1, 5)))
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        fn(np.ones((20, 1)))
def test_jax_IncSubtensor_boolean_indexing_reexpressible():
"""Setting or incrementing values with boolean indexing.
......
......@@ -187,27 +187,6 @@ def test_mlx_inplace_variants():
compare_mlx_and_py([], [out_pt], [])
@pytest.mark.xfail(
    reason="MLX slice indices must be integers or None, dynamic slices not supported"
)
def test_mlx_MakeSlice():
    """A MakeSlice built from symbolic scalars should index like a Python slice."""
    bounds = [pt.iscalar(name) for name in ("start", "stop", "step")]

    # Build a symbolic slice from the three scalar bounds.
    dynamic_slice = pt_subtensor.MakeSlice()(*bounds)

    # Index a small constant vector with the symbolic slice.
    data = pt.constant(np.arange(10, dtype=np.float32))
    sliced = data[dynamic_slice]

    compare_mlx_and_py(bounds, [sliced], [1, 8, 2])
def test_mlx_subtensor_edge_cases():
"""Test edge cases and boundary conditions."""
# Empty slices - use constant array
......
......@@ -3,9 +3,7 @@ import contextlib
import numpy as np
import pytest
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor import Mode, as_symbolic
from pytensor.tensor import as_tensor
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -20,51 +18,16 @@ from pytensor.tensor.subtensor import (
inc_subtensor,
set_subtensor,
)
from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
from tests.link.numba.test_basic import (
compare_numba_and_py,
numba_inplace_mode,
numba_mode,
)
rng = np.random.default_rng(sum(map(ord, "Numba subtensors")))
@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}")
@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}")
@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}")
def test_slice(start, stop, step):
    """Symbolic slices evaluate to plain Python slices under the numba linker.

    The sentinel "x" marks a bound that is supplied symbolically (it evaluates
    to -5); any other bound is a constant baked into the slice.
    """
    x = ps.int64("x")

    def resolve(bound):
        # Substitute the symbolic scalar for the "x" placeholder.
        return x if bound == "x" else bound

    sym_slice = as_symbolic(slice(resolve(start), resolve(stop), resolve(step)))

    no_opt_mode = Mode(linker="numba", optimizer=None)
    evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode)
    assert isinstance(evaled_slice, slice)

    if start == "x":
        assert evaled_slice.start == -5
    elif start is None and (evaled_slice.step is None or evaled_slice.step > 0):
        # Numba can convert to 0 (and sometimes does) in this case
        assert evaled_slice.start in (None, 0)
    else:
        assert evaled_slice.start == start

    if stop == "x":
        assert evaled_slice.stop == -5
    else:
        assert evaled_slice.stop == stop

    if step == "x":
        assert evaled_slice.step == -5
    elif step is None:
        # Numba can convert to 1 (and sometimes does) in this case
        assert evaled_slice.step in (None, 1)
    else:
        assert evaled_slice.step == step
@pytest.mark.parametrize(
"x, indices",
[
......@@ -182,6 +145,11 @@ def test_AdvancedSubtensor1_out_of_bounds():
as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5))),
([[1, 2], [2, 1]], slice(1, None), [[0, 0], [0, 0]]),
),
# Newaxis with vector indexing
(
as_tensor(np.arange(4 * 4).reshape((4, 4))),
(None, [0, 1, 2], [0, 1, 2]),
),
],
)
@pytest.mark.filterwarnings("error") # Raise if we did not expect objmode to be needed
......@@ -447,6 +415,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
False,
False,
),
(
np.arange(4 * 4).reshape((4, 4)),
np.array(5), # Broadcasted scalar value
(None, [0, 1, 2], [0, 1, 2]), # Newaxis with vector indexing
False,
False,
),
],
)
@pytest.mark.parametrize("inplace", (False, True))
......@@ -460,7 +435,9 @@ def test_AdvancedIncSubtensor(
inplace,
):
# Need rewrite to support certain forms of advanced indexing without object mode
mode = numba_mode.including("specialize")
# Use inplace_mode when testing inplace operations to preserve inplace flag
base_mode = numba_inplace_mode if inplace else numba_mode
mode = base_mode.including("specialize")
x_pt = pt.as_tensor(x).type("x")
y_pt = pt.as_tensor(y).type("y")
......@@ -514,22 +491,3 @@ def test_AdvancedIncSubtensor(
x_orig = x.copy()
fn(x, y)
assert not np.all(x == x_orig)
def test_advanced_indexing_with_newaxis_fallback_obj_mode():
    # This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
    # After which we can add these parametrizations to the relevant tests above
    x = pt.matrix("x")

    # Advanced read with a leading newaxis falls back to object mode.
    read_out = x[None, [0, 1, 2], [0, 1, 2]]
    with pytest.warns(
        UserWarning,
        match=r"Numba will use object mode to run AdvancedSubtensor's perform method",
    ):
        compare_numba_and_py([x], [read_out], [np.random.normal(size=(4, 4))])

    # The increment counterpart triggers the same fallback.
    inc_out = x[None, [0, 1, 2], [0, 1, 2]].inc(5)
    with pytest.warns(
        UserWarning,
        match=r"Numba will use object mode to run AdvancedIncSubtensor's perform method",
    ):
        compare_numba_and_py([x], [inc_out], [np.random.normal(size=(4, 4))])
......@@ -1642,9 +1642,15 @@ def test_InplaceElemwiseOptimizer_bug():
# with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10):
rewrite_graph(fgraph, include=("inplace",))
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
with pytest.warns(
FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
# Save original value to restore later
original_value = pytensor.config.tensor__insert_inplace_optimizer_validate_nb
try:
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1
with pytest.warns(
FutureWarning,
match="tensor__insert_inplace_optimizer_validate_nb config is deprecated",
):
rewrite_graph(fgraph, include=("inplace",))
finally:
# Restore original value to avoid affecting other tests
pytensor.config.tensor__insert_inplace_optimizer_validate_nb = original_value
......@@ -52,7 +52,6 @@ from pytensor.tensor.type import (
tensor4,
vector,
)
from pytensor.tensor.type_other import make_slice
from tests import unittest_tools as utt
from tests.unittest_tools import create_pytensor_param
......@@ -1701,11 +1700,11 @@ def test_local_uint_constant_indices():
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
# `AdvancedSubtensor`, two indices, one symbolic slice, convert
# `AdvancedSubtensor`, two indices, one slice, convert
x = pt.matrix("x")
indices = (
pt.as_tensor_variable(np.array(1, np.int64)),
make_slice(slice(None, 10)),
pt.as_tensor_variable(np.array([1], np.int64)),
slice(None, 10),
)
z = x[indices]
......@@ -1792,7 +1791,7 @@ def test_local_uint_constant_indices():
z_fn = pytensor.function([x], z, mode=mode)
subtensor_node = z_fn.maker.fgraph.outputs[0].owner
assert isinstance(subtensor_node.op, AdvancedSubtensor)
assert isinstance(subtensor_node.op, (AdvancedSubtensor, AdvancedSubtensor1))
new_index = subtensor_node.inputs[1]
assert isinstance(new_index, Constant)
assert new_index.type.dtype == "uint8"
......@@ -1843,7 +1842,6 @@ class TestBlockwiseIncSubtensor:
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
fn, ref_fn = self.compile_fn_and_ref([x, y], out)
assert self.has_blockwise(ref_fn)
assert not self.has_blockwise(fn)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
np.testing.assert_allclose(fn(test_x, test_y), ref_fn(test_x, test_y))
......@@ -1948,15 +1946,7 @@ class TestBlockwiseIncSubtensor:
@pytest.mark.parametrize(
"basic_idx",
[
True,
pytest.param(
False,
marks=pytest.mark.xfail(
reason="AdvancedIncSubtensor with slices can't be blockwise"
),
),
],
[True, False],
ids=["basic_idx", "adv_idx"],
)
@pytest.mark.parametrize(
......@@ -1973,7 +1963,7 @@ class TestBlockwiseIncSubtensor:
core_idx = pt.tensor("idx", dtype=int, shape=() if basic_idx else (2,))
# The empty slice before core_idx, will lead to a transposition of the advanced view
# once it is paired with an new arange slice on the batched dimensions.
# once it is paired with a new arange slice on the batched dimensions.
# That's why core_v is (2, 3), and not (3, 2), in the case of advanced indexing
core_out = core_a[0, :, core_idx].set(core_v)
......
......@@ -32,7 +32,6 @@ from pytensor.tensor import (
lscalars,
matrix,
shape,
slicetype,
specify_shape,
tensor,
tensor3,
......@@ -557,7 +556,7 @@ class TestLocalSubtensorSpecifyShapeLift:
(
matrix(),
(iscalar(), iscalar()),
(slicetype(),),
(slice(iscalar(), iscalar(), iscalar()),),
),
(
matrix(),
......@@ -789,12 +788,12 @@ def test_local_subtensor_shape_constant():
(lambda x: x[:, [0, 1]][0], True),
(lambda x: x[:, [0, 1], [0, 0]][1:], True),
(lambda x: x[:, [[0, 1], [0, 0]]][1:], True),
(lambda x: x[:, None, [0, 1]][0], True),
# Not supported, basic indexing on advanced indexing dim
(lambda x: x[[0, 1]][0], False),
# Not implemented, basic indexing on the right of advanced indexing
# Not supported, basic indexing on the right of advanced indexing
(lambda x: x[[0, 1]][:, 0], False),
# Not implemented, complex flavors of advanced indexing
(lambda x: x[:, None, [0, 1]][0], False),
(lambda x: x[:, 5:, [0, 1]][0], False),
(lambda x: x[:, :, np.array([True, False, False])][0], False),
(lambda x: x[[0, 1], :, [0, 1]][:, 0], False),
......
......@@ -31,6 +31,8 @@ from pytensor.tensor.blockwise import (
vectorize_node_fallback,
)
from pytensor.tensor.nlinalg import MatrixInverse, eig
from pytensor.tensor.random import normal
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.signal import convolve1d
from pytensor.tensor.slinalg import (
......@@ -114,16 +116,18 @@ def test_vectorize_blockwise():
def test_vectorize_node_fallback_unsupported_type():
x = tensor("x", shape=(2, 6))
node = x[:, [0, 2, 4]].owner
rng = default_rng()
node = normal(rng=rng).owner
with pytest.raises(
NotImplementedError,
match=re.escape(
"Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice"
'Cannot vectorize node normal_rv{"(),()->()"}('
"DefaultGeneratorMakerOp.0, NoneConst{None}, 0.0, 1.0)"
" with input DefaultGeneratorMakerOp.0 of type RandomGeneratorType"
),
):
vectorize_node_fallback(node.op, node, node.inputs)
vectorize_node_fallback(node.op, node, *node.inputs)
def check_blockwise_runtime_broadcasting(mode):
......
......@@ -4,30 +4,8 @@ import pytensor
from pytensor import as_symbolic
from pytensor.graph.basic import Constant
from pytensor.tensor.math import argmax
from pytensor.tensor.type import iscalar, vector
from pytensor.tensor.type_other import (
MakeSlice,
NoneConst,
NoneTypeT,
SliceConstant,
SliceType,
make_slice,
)
def test_SliceType():
    """A SliceType instance compares equal to its clone."""
    slice_type = SliceType()
    assert slice_type == slice_type.clone()
def test_make_slice_merge():
    """Identical MakeSlice applications are merged into a single node.

    Regression test: in the past, this was crashing during compilation.
    """
    stop = iscalar()
    first = make_slice(0, stop)
    second = make_slice(0, stop)
    fn = pytensor.function([stop], [first, second])
    make_slice_nodes = [
        node
        for node in fn.maker.fgraph.apply_nodes
        if isinstance(node.op, MakeSlice)
    ]
    assert len(make_slice_nodes) == 1
from pytensor.tensor.type import vector
from pytensor.tensor.type_other import NoneConst, NoneTypeT
def test_none_Constant():
......@@ -47,8 +25,6 @@ def test_none_Constant():
# This trigger equals that returned the wrong answer in the past.
import pickle
import pytensor
x = vector("x")
y = argmax(x)
kwargs = {}
......@@ -60,11 +36,18 @@ def test_none_Constant():
def test_as_symbolic():
    # Remove this when xtensor is not using symbolic slices
    from pytensor.tensor.type import iscalar
    from pytensor.tensor.type_other import SliceConstant, slicetype

    # `None` converts to the shared NoneConst singleton.
    res = as_symbolic(None)
    assert res is NoneConst

    # NOTE(review): `make_slice` is not imported anywhere in this view; this
    # looks like a pre-refactor line interleaved by the diff rendering —
    # confirm whether it should still be here.
    res = as_symbolic(slice(iscalar()))
    assert res.owner.op == make_slice

    # A slice of plain Python ints becomes a SliceConstant of type `slicetype`
    # carrying the original slice object as its data.
    res = as_symbolic(slice(1, 2))
    assert isinstance(res, SliceConstant)
    assert res.type == slicetype
    assert res.data == slice(1, 2)

    # A slice holding a symbolic scalar yields a graph-built (owned) variable.
    i = iscalar()
    res = as_symbolic(slice(i))
    assert res.owner is not None
......@@ -35,7 +35,7 @@ from pytensor.tensor.type import (
scalar,
tensor3,
)
from pytensor.tensor.type_other import MakeSlice, NoneConst
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.variable import (
DenseTensorConstant,
DenseTensorVariable,
......@@ -232,11 +232,11 @@ def test__getitem__AdvancedSubtensor():
z = x[:, i]
op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor]
assert op_types == [AdvancedSubtensor]
z = x[..., i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])]
assert op_types == [MakeSlice, AdvancedSubtensor]
assert op_types == [DimShuffle, AdvancedSubtensor]
z = x[i, None]
op_types = [type(node.op) for node in io_toposort([x, i], [z])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论