Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d8501d14
提交
d8501d14
authored
12月 07, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
12月 14, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Numba AdvancedIndexing: Complete support for integer (and mixed basic) advanced indexing
When default `ignore_updates=True` for inc_subtensor, and boolean indices were rewritten during specialize
上级
fe10f960
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
417 行增加
和
454 行删除
+417
-454
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+344
-198
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+19
-179
test_subtensor.py
tests/link/numba/test_subtensor.py
+54
-77
没有找到文件。
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
d8501d14
import
operator
import
sys
from
hashlib
import
sha256
from
textwrap
import
dedent
,
indent
import
numba
import
numpy
as
np
...
...
@@ -14,13 +15,13 @@ from pytensor.link.numba.cache import (
compile_numba_function_src
,
)
from
pytensor.link.numba.dispatch.basic
import
(
create_tuple_string
,
generate_fallback_impl
,
register_funcify_and_cache_key
,
register_funcify_default_op_cache_key
,
)
from
pytensor.link.numba.dispatch.compile_ops
import
numba_deepcopy
from
pytensor.tensor
import
TensorType
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
...
...
@@ -29,7 +30,7 @@ from pytensor.tensor.subtensor import (
IncSubtensor
,
Subtensor
,
)
from
pytensor.tensor.type_other
import
MakeSlice
,
NoneTypeT
,
SliceType
from
pytensor.tensor.type_other
import
MakeSlice
,
NoneTypeT
def
slice_new
(
self
,
start
,
stop
,
step
):
...
...
@@ -243,14 +244,6 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
else
:
_x
,
_y
,
*
idxs
=
node
.
inputs
basic_idxs
=
[
idx
for
idx
in
idxs
if
(
isinstance
(
idx
.
type
,
NoneTypeT
)
or
(
isinstance
(
idx
.
type
,
SliceType
)
and
not
is_full_slice
(
idx
))
)
]
adv_idxs
=
[
{
"axis"
:
i
,
...
...
@@ -262,248 +255,401 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
if
isinstance
(
idx
.
type
,
TensorType
)
]
# Special implementation for consecutive integer vector indices
must_ignore_duplicates
=
(
isinstance
(
op
,
AdvancedIncSubtensor
)
and
not
op
.
set_instead_of_inc
and
op
.
ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
and
not
all
(
adv_idx
[
"ndim"
]
==
0
or
adv_idx
[
"dtype"
]
==
"bool"
for
adv_idx
in
adv_idxs
)
)
# Special implementation for integer indices that respects duplicates
if
(
not
basic_idxs
and
len
(
adv_idxs
)
>=
2
# Must be integer vectors
# Todo: we could allow shape=(1,) if this is the shape of x
and
all
(
(
adv_idx
[
"bcast"
]
==
(
False
,)
and
adv_idx
[
"dtype"
]
!=
"bool"
)
for
adv_idx
in
adv_idxs
)
# Must be consecutive
and
not
op
.
non_consecutive_adv_indexing
(
node
)
not
must_ignore_duplicates
and
len
(
adv_idxs
)
>=
1
and
all
(
adv_idx
[
"dtype"
]
!=
"bool"
for
adv_idx
in
adv_idxs
)
# Implementation does not support newaxis
and
not
any
(
isinstance
(
idx
.
type
,
NoneTypeT
)
for
idx
in
idxs
)
):
return
numba_funcify_multiple_integer_vector
_indexing
(
op
,
node
,
**
kwargs
)
return
vector_integer_advanced
_indexing
(
op
,
node
,
**
kwargs
)
# Other cases not natively supported by Numba (fallback to obj-mode)
if
(
# Numba does not support indexes with more than one dimension
any
(
idx
[
"ndim"
]
>
1
for
idx
in
adv_idxs
)
# Nor multiple vector indexes
or
sum
(
idx
[
"ndim"
]
>
0
for
idx
in
adv_idxs
)
>
1
# The default PyTensor implementation does not handle duplicate indices correctly
or
(
must_respect_duplicates
=
(
isinstance
(
op
,
AdvancedIncSubtensor
)
and
not
op
.
set_instead_of_inc
and
not
(
op
.
ignore_duplicates
and
not
op
.
ignore_duplicates
# Only vector integer indices can have "duplicates", not scalars or boolean vectors
or
all
(
adv_idx
[
"ndim"
]
==
0
or
adv_idx
[
"dtype"
]
==
"bool"
for
adv_idx
in
adv_idxs
)
and
not
all
(
adv_idx
[
"ndim"
]
==
0
or
adv_idx
[
"dtype"
]
==
"bool"
for
adv_idx
in
adv_idxs
)
)
# Cases natively supported by Numba
if
(
# Numba indexing, like Numpy, ignores duplicates in update
not
must_respect_duplicates
# Numba does not support indexes with more than one dimension
and
not
any
(
idx
[
"ndim"
]
>
1
for
idx
in
adv_idxs
)
# Nor multiple vector indexes
and
not
sum
(
idx
[
"ndim"
]
>
0
for
idx
in
adv_idxs
)
>
1
):
return
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
)
# Otherwise fallback to obj_mode
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
),
subtensor_op_cache_key
(
op
,
func
=
"fallback_impl"
)
# What's left should all be supported natively by numba
return
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
)
def
_broadcasted_to
(
x_bcast
:
tuple
[
bool
,
...
],
to_bcast
:
tuple
[
bool
,
...
]):
# Check that x is not broadcasted to y based on broadcastable info
if
len
(
x_bcast
)
<
len
(
to_bcast
):
return
True
for
x_bcast_dim
,
to_bcast_dim
in
zip
(
x_bcast
,
to_bcast
,
strict
=
True
):
if
x_bcast_dim
and
not
to_bcast_dim
:
return
True
return
False
@register_funcify_and_cache_key
(
AdvancedIncSubtensor1
)
def
numba_funcify_AdvancedIncSubtensor1
(
op
,
node
,
**
kwargs
):
return
vector_integer_advanced_indexing
(
op
,
node
=
node
,
**
kwargs
)
def
numba_funcify_multiple_integer_vector
_indexing
(
op
:
AdvancedSubtensor
|
AdvancedIncSubtensor
,
node
,
**
kwargs
def
vector_integer_advanced
_indexing
(
op
:
AdvancedSubtensor
1
|
AdvancedSubtensor
|
AdvancedIncSubtensor
,
node
,
**
kwargs
):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
if
isinstance
(
op
,
AdvancedSubtensor
):
idxs
=
node
.
inputs
[
1
:]
else
:
idxs
=
node
.
inputs
[
2
:]
"""Implement all forms of advanced indexing (and assignment) that combine basic and vector integer indices.
first_axis
=
next
(
i
for
i
,
idx
in
enumerate
(
idxs
)
if
isinstance
(
idx
.
type
,
TensorType
)
)
try
:
after_last_axis
=
next
(
i
for
i
,
idx
in
enumerate
(
idxs
[
first_axis
:],
start
=
first_axis
)
if
not
isinstance
(
idx
.
type
,
TensorType
)
It does not support `newaxis` in basic indices
It handles += like `np.add.at` would, accumulating add for duplicate indices.
Examples
--------
Codegen for an AdvancedSubtensor, with non-consecutive matrix indices, and a slice(1, None) basic index
.. code-block:: python
# AdvancedSubtensor [id A] <Tensor3(int64, shape=(2, 2, 3))>
# ├─ <Tensor3(int64, shape=(3, 4, 5))> [id B] <Tensor3(int64, shape=(3, 4, 5))>
# ├─ [[1 2] [2 1]] [id C] <Matrix(uint8, shape=(2, 2))>
# ├─ SliceConstant{1, None, None} [id D] <slice>
# └─ [[0 0] [0 0]] [id E] <Matrix(uint8, shape=(2, 2))>
def advanced_integer_vector_indexing(x, idx0, idx1, idx2):
# Move advanced indexed dims to the front (if needed)
x_adv_dims_front = x.transpose((0, 2, 1))
# Perform basic indexing once (if needed)
basic_indexed_x = x_adv_dims_front[:, :, idx1]
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape)
(idx0, idx2) = (
np.broadcast_to(idx0, adv_idx_shape),
np.broadcast_to(idx2, adv_idx_shape),
)
except
StopIteration
:
after_last_axis
=
len
(
idxs
)
last_axis
=
after_last_axis
-
1
vector_indices
=
idxs
[
first_axis
:
after_last_axis
]
assert
all
(
v
.
type
.
broadcastable
==
(
False
,)
for
v
in
vector_indices
)
y_is_broadcasted
=
False
# Create output buffer
adv_idx_size = idx0.size
basic_idx_shape = basic_indexed_x.shape[2:]
out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype)
if
isinstance
(
op
,
AdvancedSubtensor
):
# Index over tuples of raveled advanced indices and write to output buffer
for i, scalar_idxs in enumerate(zip(idx0.ravel(), idx2.ravel())):
out_buffer[i] = basic_indexed_x[scalar_idxs]
@numba_basic.numba_njit
def
advanced_subtensor_multiple_vector
(
x
,
*
idxs
):
none_slices
=
idxs
[:
first_axis
]
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
idx_shape
=
vec_idxs
[
0
]
.
shape
shape_bef
=
x_shape
[:
first_axis
]
shape_aft
=
x_shape
[
after_last_axis
:]
out_shape
=
(
*
shape_bef
,
*
idx_shape
,
*
shape_aft
)
out_buffer
=
np
.
empty
(
out_shape
,
dtype
=
x
.
dtype
)
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
out_buffer
[(
*
none_slices
,
i
)]
=
x
[(
*
none_slices
,
*
scalar_idxs
)]
# Unravel out_buffer (if needed)
out_buffer = out_buffer.reshape((*adv_idx_shape, *basic_idx_shape))
# Move advanced output indexing group to its final position (if needed) and return
return out_buffer
ret_func
=
advanced_subtensor_multiple_vector
else
:
inplace
=
op
.
inplace
Codegen for similar AdvancedSetSubtensor
# Check if y must be broadcasted
# Includes the last integer vector index,
x
,
y
=
node
.
inputs
[:
2
]
indexed_bcast_dims
=
(
*
x
.
type
.
broadcastable
[:
first_axis
],
*
x
.
type
.
broadcastable
[
last_axis
:],
)
y_is_broadcasted
=
_broadcasted_to
(
y
.
type
.
broadcastable
,
indexed_bcast_dims
)
.. code-block::python
if
op
.
set_instead_of_inc
:
AdvancedSetSubtensor [id A] <Tensor3(int64, shape=(3, 4, 5))>
├─ x [id B] <Tensor3(int64, shape=(3, 4, 5))>
├─ y [id C] <Matrix(int64, shape=(2, 4))>
├─ [1 2] [id D] <Vector(uint8, shape=(2,))>
├─ SliceConstant{None, None, None} [id E] <slice>
└─ [3 4] [id F] <Vector(uint8, shape=(2,))>
@numba_basic.numba_njit
def
advanced_set_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
def set_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2):
# Expand dims of y explicitly (if needed)
y = y
if
inplace
:
out
=
x
else
:
out
=
x
.
copy
()
# Copy x (if not inplace)
x = x.copy()
if
y_is_broadcasted
:
y
=
np
.
broadcast_to
(
y
,
x_shape
[:
first_axis
]
+
x_shape
[
last_axis
:])
# Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = x.transpose((0, 2, 1))
for
outer
in
np
.
ndindex
(
x_shape
[:
first_axis
]):
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
out
[(
*
outer
,
*
scalar_idxs
)]
=
y
[(
*
outer
,
i
)]
return
out
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = x_adv_dims_front[:, :, idx1]
ret_func
=
advanced_set_subtensor_multiple_vector
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx0.shape, idx2.shape)
(idx0, idx2) = (np.broadcast_to(idx0, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape))
else
:
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = y
@numba_basic.numba_njit
def
advanced_inc_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
# Broadcast y to the shape of each assignment/update
adv_idx_shape = idx0.shape
basic_idx_shape = basic_indexed_x.shape[2:
]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
if
inplace
:
out
=
x
else
:
out
=
x
.
copy
()
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
if
y_is_broadcasted
:
y
=
np
.
broadcast_to
(
y
,
x_shape
[:
first_axis
]
+
x_shape
[
last_axis
:])
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip(idx0, idx2)):
basic_indexed_x[scalar_idxs] = y_bcast[i]
for
outer
in
np
.
ndindex
(
x_shape
[:
first_axis
]):
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
out
[(
*
outer
,
*
scalar_idxs
)]
+=
y
[(
*
outer
,
i
)]
return
out
# Return the original x, with the entries updated
return x
ret_func
=
advanced_inc_subtensor_multiple_vector
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"multiple_integer_vector_indexing"
,
y_is_broadcasted
=
y_is_broadcasted
,
first_axis
=
first_axis
,
last_axis
=
last_axis
,
)
return
ret_func
,
cache_key
Codegen for an AdvancedIncSubtensor, with two contiguous advanced groups not in the leading axis
.. code-block::python
@register_funcify_and_cache_key
(
AdvancedIncSubtensor1
)
def
numba_funcify_AdvancedIncSubtensor1
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
set_instead_of_inc
=
op
.
set_instead_of_inc
x
,
vals
,
_idxs
=
node
.
inputs
broadcast_with_index
=
vals
.
type
.
ndim
<
x
.
type
.
ndim
or
vals
.
type
.
broadcastable
[
0
]
# TODO: Add runtime_broadcast check
if
set_instead_of_inc
:
if
broadcast_with_index
:
@numba_basic.numba_njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
if
val
.
ndim
==
x
.
ndim
:
core_val
=
val
[
0
]
elif
val
.
ndim
==
0
:
# Workaround for https://github.com/numba/numba/issues/9573
core_val
=
val
.
item
()
else
:
core_val
=
val
AdvancedIncSubtensor [id A] <Tensor3(int64, shape=(3, 4, 5))>
├─ x [id B] <Tensor3(int64, shape=(3, 4, 5))>
├─ y [id C] <Matrix(int64, shape=(2, 2))>
├─ SliceConstant{1, None, None} [id D] <slice>
├─ [1 2] [id E] <Vector(uint8, shape=(2,))>
└─ [3 4] [id F] <Vector(uint8, shape=(2,))>
def inc_advanced_integer_vector_indexing(x, y, idx0, idx1, idx2):
# Expand dims of y explicitly (if needed)
y = y
# Copy x (if not inplace)
x = x.copy()
# Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = x.transpose((1, 2, 0))
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = x_adv_dims_front[:, :, idx0]
# Broadcast indices
adv_idx_shape = np.broadcast_shapes(idx1.shape, idx2.shape)
(idx1, idx2) = (np.broadcast_to(idx1, adv_idx_shape), np.broadcast_to(idx2, adv_idx_shape))
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = y.transpose((1, 0))
# Broadcast y to the shape of each assignment/update
adv_idx_shape = idx1.shape
basic_idx_shape = basic_indexed_x.shape[2:]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = y_bcast
for
idx
in
idxs
:
x
[
idx
]
=
core_val
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip(idx1, idx2)):
basic_indexed_x[scalar_idxs] += y_bcast[i]
# Return the original x, with the entries updated
return x
"""
if
isinstance
(
op
,
AdvancedSubtensor1
|
AdvancedSubtensor
):
x
,
*
idxs
=
node
.
inputs
else
:
x
,
y
,
*
idxs
=
node
.
inputs
[
out
]
=
node
.
outputs
@numba_basic.numba_njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
if
not
len
(
idxs
)
==
len
(
vals
):
raise
ValueError
(
"The number of indices and values must match."
)
# no strict argument because incompatible with numba
for
idx
,
val
in
zip
(
idxs
,
vals
):
x
[
idx
]
=
val
return
x
adv_indices_pos
=
tuple
(
i
for
i
,
idx
in
enumerate
(
idxs
)
if
isinstance
(
idx
.
type
,
TensorType
)
)
assert
adv_indices_pos
# Otherwise it's just basic indexing
basic_indices_pos
=
tuple
(
i
for
i
,
idx
in
enumerate
(
idxs
)
if
not
isinstance
(
idx
.
type
,
TensorType
)
)
explicit_basic_indices_pos
=
(
*
basic_indices_pos
,
*
range
(
len
(
idxs
),
x
.
type
.
ndim
))
# Create index signature and split them among basic and advanced
idx_signature
=
", "
.
join
(
f
"idx{i}"
for
i
in
range
(
len
(
idxs
)))
adv_indices
=
[
f
"idx{i}"
for
i
in
adv_indices_pos
]
basic_indices
=
[
f
"idx{i}"
for
i
in
basic_indices_pos
]
# Define transpose axis so that advanced indexing dims are on the front
adv_axis_front_order
=
(
*
adv_indices_pos
,
*
explicit_basic_indices_pos
)
adv_axis_front_transpose_needed
=
adv_axis_front_order
!=
tuple
(
range
(
x
.
ndim
))
adv_idx_ndim
=
max
(
idxs
[
i
]
.
ndim
for
i
in
adv_indices_pos
)
# Helper needed for basic indexing after moving advanced indices to the front
basic_indices_with_none_slices
=
", "
.
join
(
(
*
((
":"
,)
*
len
(
adv_indices
)),
*
basic_indices
)
)
# Position of the first advanced index dimension after indexing the array
if
(
np
.
diff
(
adv_indices_pos
)
>
1
)
.
any
():
# If not consecutive, it's always at the front
out_adv_axis_pos
=
0
else
:
if
broadcast_with_index
:
@numba_basic.numba_njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
if
val
.
ndim
==
x
.
ndim
:
core_val
=
val
[
0
]
elif
val
.
ndim
==
0
:
# Workaround for https://github.com/numba/numba/issues/9573
core_val
=
val
.
item
()
# Otherwise wherever the first advanced index is located
out_adv_axis_pos
=
adv_indices_pos
[
0
]
to_tuple
=
create_tuple_string
# alias to make code more readable below
if
isinstance
(
op
,
AdvancedSubtensor1
|
AdvancedSubtensor
):
# Define transpose axis on the output to restore original meaning
# After (potentially) having transposed advanced indexing dims to the front unlike numpy
_final_axis_order
=
list
(
range
(
adv_idx_ndim
,
out
.
type
.
ndim
))
for
i
in
range
(
adv_idx_ndim
):
_final_axis_order
.
insert
(
out_adv_axis_pos
+
i
,
i
)
final_axis_order
=
tuple
(
_final_axis_order
)
del
_final_axis_order
final_axis_transpose_needed
=
final_axis_order
!=
tuple
(
range
(
out
.
type
.
ndim
))
func_name
=
"advanced_integer_vector_indexing"
codegen
=
dedent
(
f
"""
def {func_name}(x, {idx_signature}):
# Move advanced indexed dims to the front (if needed)
x_adv_dims_front = {f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"}
# Perform basic indexing once (if needed)
basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"}
"""
)
if
len
(
adv_indices
)
>
1
:
codegen
+=
indent
(
dedent
(
f
"""
# Broadcast indices
adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])}
{to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])}
"""
),
" "
*
4
,
)
codegen
+=
indent
(
dedent
(
f
"""
# Create output buffer
adv_idx_size = {adv_indices[0]}.size
basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:]
out_buffer = np.empty((adv_idx_size, *basic_idx_shape), dtype=x.dtype)
# Index over tuples of raveled advanced indices and write to output buffer
for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 else adv_indices)}):
out_buffer[i] = basic_indexed_x[scalar_idxs]
# Unravel out_buffer (if needed)
out_buffer = {f"out_buffer.reshape((*{adv_indices[0]}.shape, *basic_idx_shape))" if adv_idx_ndim != 1 else "out_buffer"}
# Move advanced output indexing group to its final position (if needed) and return
return {f"out_buffer.transpose({final_axis_order})" if final_axis_transpose_needed else "out_buffer"}
"""
),
" "
*
4
,
)
else
:
core_val
=
val
# Make implicit dims of y explicit to simplify code
# Numba doesn't support `np.expand_dims` with multiple axis, so we use indexing with newaxis
indexed_ndim
=
x
[
tuple
(
idxs
)]
.
type
.
ndim
y_expand_dims
=
[
":"
]
*
y
.
type
.
ndim
y_implicit_dims
=
range
(
indexed_ndim
-
y
.
type
.
ndim
)
for
axis
in
y_implicit_dims
:
y_expand_dims
.
insert
(
axis
,
"None"
)
# We transpose the advanced dimensions of x to the front for indexing
# We may have to do the same for y
# Note that if there are non-contiguous advanced indices,
# y must already be aligned with the indices jumping to the front
y_adv_axis_front_order
=
tuple
(
range
(
# Position of the first advanced axis after indexing
out_adv_axis_pos
,
# Position of the last advanced axis after indexing
out_adv_axis_pos
+
adv_idx_ndim
,
)
)
y_order
=
tuple
(
range
(
indexed_ndim
))
y_adv_axis_front_order
=
(
*
y_adv_axis_front_order
,
# Basic indices, after explicit_expand_dims
*
(
o
for
o
in
y_order
if
o
not
in
y_adv_axis_front_order
),
)
y_adv_axis_front_transpose_needed
=
y_adv_axis_front_order
!=
y_order
for
idx
in
idxs
:
x
[
idx
]
+=
core_val
return
x
func_name
=
f
"{'set' if op.set_instead_of_inc else 'inc'}_advanced_integer_vector_indexing"
codegen
=
dedent
(
f
"""
def {func_name}(x, y, {idx_signature}):
# Expand dims of y explicitly (if needed)
y = {f"y[{', '.join(y_expand_dims)},]" if y_implicit_dims else "y"}
else
:
# Copy x (if not inplace)
x = {"x" if op.inplace else "x.copy()"}
@numba_basic.numba_njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
):
if
not
len
(
idxs
)
==
len
(
vals
):
raise
ValueError
(
"The number of indices and values must match."
)
# no strict argument because unsupported by numba
# TODO: this doesn't come up in tests
for
idx
,
val
in
zip
(
idxs
,
vals
):
x
[
idx
]
+=
val
# Move advanced indexed dims to the front (if needed)
# This will remain a view of x
x_adv_dims_front = {f"x.transpose({adv_axis_front_order})" if adv_axis_front_transpose_needed else "x"}
# Perform basic indexing once (if needed)
# This will remain a view of x
basic_indexed_x = {f"x_adv_dims_front[{basic_indices_with_none_slices}]" if basic_indices else "x_adv_dims_front"}
"""
)
if
len
(
adv_indices
)
>
1
:
codegen
+=
indent
(
dedent
(
f
"""
# Broadcast indices
adv_idx_shape = np.broadcast_shapes{to_tuple([f"{idx}.shape" for idx in adv_indices])}
{to_tuple(adv_indices)} = {to_tuple([f"np.broadcast_to({idx}, adv_idx_shape)" for idx in adv_indices])}
"""
),
" "
*
4
,
)
codegen
+=
indent
(
dedent
(
f
"""
# Move advanced indexed dims to the front (if needed)
y_adv_dims_front = {f"y.transpose({y_adv_axis_front_order})" if y_adv_axis_front_transpose_needed else "y"}
# Broadcast y to the shape of each assignment/update
adv_idx_shape = {adv_indices[0]}.shape
basic_idx_shape = basic_indexed_x.shape[{len(adv_indices)}:]
y_bcast = np.broadcast_to(y_adv_dims_front, (*adv_idx_shape, *basic_idx_shape))
# Ravel the advanced dims (if needed)
# Note that numba reshape only supports C-arrays, so we ravel before reshape
y_bcast = {"y_bcast.ravel().reshape((-1, *basic_idx_shape))" if adv_idx_ndim != 1 else "y_bcast"}
# Index over tuples of raveled advanced indices and update buffer
for i, scalar_idxs in enumerate(zip{to_tuple([f"{idx}.ravel()" for idx in adv_indices] if adv_idx_ndim != 1 else adv_indices)}):
basic_indexed_x[scalar_idxs] {"=" if op.set_instead_of_inc else "+="} y_bcast[i]
# Return the original x, with the entries updated
return x
"""
),
" "
*
4
,
)
cache_key
=
subtensor_op_cache_key
(
op
,
func
=
"numba_funcify_advancedincsubtensor1"
,
broadcast_with_index
=
broadcast_with_index
,
codegen
=
codegen
,
)
if
inplace
:
return
advancedincsubtensor1_inplace
,
cache_key
else
:
@numba_basic.numba_njit
def
advancedincsubtensor1
(
x
,
vals
,
idxs
):
x
=
x
.
copy
()
return
advancedincsubtensor1_inplace
(
x
,
vals
,
idxs
)
return
advancedincsubtensor1
,
cache_key
ret_func
=
numba_basic
.
numba_njit
(
compile_numba_function_src
(
codegen
,
function_name
=
func_name
,
global_env
=
globals
(),
cache_key
=
cache_key
,
)
)
return
ret_func
,
cache_key
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
d8501d14
...
...
@@ -83,7 +83,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor
,
indices_from_subtensor
,
)
from
pytensor.tensor.type
import
TensorType
,
integer_dtypes
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
...
@@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node):
@node_rewriter
(
tracks
=
[
AdvancedSubtensor
,
AdvancedIncSubtensor
])
def
ravel_multidimensional_bool_idx
(
fgraph
,
node
):
"""Convert
multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
def
bool_idx_to_nonzero
(
fgraph
,
node
):
"""Convert
boolean indexing into equivalent vector boolean index, supported by our dispatch
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
"""
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
x
,
*
idxs
=
node
.
inputs
else
:
x
,
y
,
*
idxs
=
node
.
inputs
if
any
(
(
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
type
.
dtype
in
integer_dtypes
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
)
for
idx
in
idxs
):
# Get out if there are any other advanced indexes or np.newaxis
return
None
bool_idxs
=
[
(
i
,
idx
)
bool_pos
=
{
i
for
i
,
idx
in
enumerate
(
idxs
)
if
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
dtype
==
"bool"
)
]
if
len
(
bool_idxs
)
!=
1
:
# Get out if there are no or multiple boolean idxs
return
None
}
[(
bool_idx_pos
,
bool_idx
)]
=
bool_idxs
bool_idx_ndim
=
bool_idx
.
type
.
ndim
if
bool_idx
.
type
.
ndim
<
2
:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
if
not
bool_pos
:
return
None
x_shape
=
x
.
shape
raveled_x
=
x
.
reshape
(
(
*
x_shape
[:
bool_idx_pos
],
-
1
,
*
x_shape
[
bool_idx_pos
+
bool_idx_ndim
:])
)
raveled_bool_idx
=
bool_idx
.
ravel
()
new_idxs
=
list
(
idxs
)
new_idxs
[
bool_idx_pos
]
=
raveled_bool_idx
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
new_out
=
node
.
op
(
raveled_x
,
*
new_idxs
)
new_idxs
=
[]
for
i
,
idx
in
enumerate
(
idxs
):
if
i
in
bool_pos
:
new_idxs
.
extend
(
idx
.
nonzero
())
else
:
# The dimensions of y that correspond to the boolean indices
# must already be raveled in the original graph, so we don't need to do anything to it
new_out
=
node
.
op
(
raveled_x
,
y
,
*
new_idxs
)
# But we must reshape the output to math the original shape
new_out
=
new_out
.
reshape
(
x_shape
)
return
[
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)]
@node_rewriter
(
tracks
=
[
AdvancedSubtensor
,
AdvancedIncSubtensor
])
def
ravel_multidimensional_int_idx
(
fgraph
,
node
):
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers
x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
It also handles multiple integer indices, but only if they don't broadcast
x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
"""
op
=
node
.
op
non_consecutive_adv_indexing
=
op
.
non_consecutive_adv_indexing
(
node
)
is_inc_subtensor
=
isinstance
(
op
,
AdvancedIncSubtensor
)
if
is_inc_subtensor
:
x
,
y
,
*
idxs
=
node
.
inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if
non_consecutive_adv_indexing
or
(
y
.
type
.
broadcastable
!=
x
[
tuple
(
idxs
)]
.
type
.
broadcastable
):
return
None
else
:
x
,
*
idxs
=
node
.
inputs
if
any
(
(
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
type
.
dtype
==
"bool"
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
)
for
idx
in
idxs
):
# Get out if there are any other advanced indices or np.newaxis
return
None
int_idxs_and_pos
=
[
(
i
,
idx
)
for
i
,
idx
in
enumerate
(
idxs
)
if
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
dtype
in
integer_dtypes
)
]
if
not
int_idxs_and_pos
:
return
None
int_idxs_pos
,
int_idxs
=
zip
(
*
int_idxs_and_pos
,
strict
=
False
)
# strict=False because by definition it's true
first_int_idx_pos
=
int_idxs_pos
[
0
]
first_int_idx
=
int_idxs
[
0
]
first_int_idx_bcast
=
first_int_idx
.
type
.
broadcastable
if
any
(
int_idx
.
type
.
broadcastable
!=
first_int_idx_bcast
for
int_idx
in
int_idxs
):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return
None
int_idxs_ndim
=
len
(
first_int_idx_bcast
)
if
(
int_idxs_ndim
==
0
):
# This should be a basic indexing operation, rewrite elsewhere
return
None
int_idxs_need_raveling
=
int_idxs_ndim
>
1
if
not
(
int_idxs_need_raveling
or
non_consecutive_adv_indexing
):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return
None
# Reorder non-consecutive indices
if
non_consecutive_adv_indexing
:
assert
not
is_inc_subtensor
# Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition
=
list
(
int_idxs_pos
)
+
[
i
for
i
in
range
(
len
(
idxs
))
if
i
not
in
int_idxs_pos
]
idxs
=
tuple
(
idxs
[
a
]
for
a
in
transposition
)
x
=
x
.
transpose
(
transposition
)
first_int_idx_pos
=
0
del
int_idxs_pos
# Make sure they are not wrongly used
# Ravel multidimensional indices
if
int_idxs_need_raveling
:
idxs
=
list
(
idxs
)
for
idx_pos
,
int_idx
in
enumerate
(
int_idxs
,
start
=
first_int_idx_pos
):
idxs
[
idx_pos
]
=
int_idx
.
ravel
()
# Index with reordered and/or raveled indices
new_subtensor
=
x
[
tuple
(
idxs
)]
if
is_inc_subtensor
:
y_shape
=
tuple
(
y
.
shape
)
y_raveled_shape
=
(
*
y_shape
[:
first_int_idx_pos
],
-
1
,
*
y_shape
[
first_int_idx_pos
+
int_idxs_ndim
:],
)
y_raveled
=
y
.
reshape
(
y_raveled_shape
)
new_out
=
inc_subtensor
(
new_subtensor
,
y_raveled
,
set_instead_of_inc
=
op
.
set_instead_of_inc
,
ignore_duplicates
=
op
.
ignore_duplicates
,
inplace
=
op
.
inplace
,
)
new_idxs
.
append
(
idx
)
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
new_out
=
node
.
op
(
x
,
*
new_idxs
)
else
:
# Unravel advanced indexing dimensions
raveled_shape
=
tuple
(
new_subtensor
.
shape
)
unraveled_shape
=
(
*
raveled_shape
[:
first_int_idx_pos
],
*
first_int_idx
.
shape
,
*
raveled_shape
[
first_int_idx_pos
+
1
:],
)
new_out
=
new_subtensor
.
reshape
(
unraveled_shape
)
new_out
=
node
.
op
(
x
,
y
,
*
new_idxs
)
return
[
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)]
optdb
[
"specialize"
]
.
register
(
ravel_multidimensional_bool_idx
.
__name__
,
ravel_multidimensional_bool_idx
,
"numba"
,
use_db_name_as_tag
=
False
,
# Not included if only "specialize" is requested
)
optdb
[
"specialize"
]
.
register
(
ravel_multidimensional_int_idx
.
__name__
,
ravel_multidimensional_int_idx
,
bool_idx_to_nonzero
.
__name__
,
bool_idx_to_nonzero
,
"numba"
,
"shape_unsafe"
,
# It can mask invalid mask sizes
use_db_name_as_tag
=
False
,
# Not included if only "specialize" is requested
)
...
...
tests/link/numba/test_subtensor.py
浏览文件 @
d8501d14
...
...
@@ -109,115 +109,93 @@ def test_AdvancedSubtensor1_out_of_bounds():
@pytest.mark.parametrize
(
"x, indices
, objmode_needed
"
,
"x, indices"
,
[
# Single vector indexing
(supported natively by Numba)
# Single vector indexing
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
(
0
,
[
1
,
2
,
2
,
3
]),
False
,
),
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
(
np
.
array
([
True
,
False
,
False
])),
False
,
),
# Single multidimensional indexing
(supported after specialization rewrites)
# Single multidimensional indexing
(
as_tensor
(
np
.
arange
(
3
*
3
)
.
reshape
((
3
,
3
))),
(
np
.
eye
(
3
)
.
astype
(
int
)),
False
,
),
(
as_tensor
(
np
.
arange
(
3
*
3
)
.
reshape
((
3
,
3
))),
(
np
.
eye
(
3
)
.
astype
(
bool
)),
False
,
),
(
as_tensor
(
np
.
arange
(
3
*
3
*
2
)
.
reshape
((
3
,
3
,
2
))),
(
np
.
eye
(
3
)
.
astype
(
int
)),
False
,
),
(
as_tensor
(
np
.
arange
(
3
*
3
*
2
)
.
reshape
((
3
,
3
,
2
))),
(
np
.
eye
(
3
)
.
astype
(
bool
)),
False
,
),
(
as_tensor
(
np
.
arange
(
2
*
3
*
3
)
.
reshape
((
2
,
3
,
3
))),
(
slice
(
2
,
None
),
np
.
eye
(
3
)
.
astype
(
int
)),
False
,
),
(
as_tensor
(
np
.
arange
(
2
*
3
*
3
)
.
reshape
((
2
,
3
,
3
))),
(
slice
(
2
,
None
),
np
.
eye
(
3
)
.
astype
(
bool
)),
False
,
),
# Multiple vector indexing
(supported by our dispatcher)
# Multiple vector indexing
(
pt
.
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([
1
,
2
],
[
2
,
3
]),
False
,
),
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
(
slice
(
None
),
[
1
,
2
],
[
3
,
4
]),
False
,
),
(
as_tensor
(
np
.
arange
(
3
*
5
*
7
)
.
reshape
((
3
,
5
,
7
))),
([
1
,
2
],
[
3
,
4
],
[
5
,
6
]),
False
,
),
# Non-consecutive vector indexing
, supported by our dispatcher after rewriting
# Non-consecutive vector indexing
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([
1
,
2
],
slice
(
None
),
[
3
,
4
]),
False
,
),
# Multiple multidimensional integer indexing
(supported by our dispatcher)
# Multiple multidimensional integer indexing
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([[
1
,
2
],
[
2
,
1
]],
[[
0
,
0
],
[
0
,
0
]]),
False
,
),
(
as_tensor
(
np
.
arange
(
2
*
3
*
4
*
5
)
.
reshape
((
2
,
3
,
4
,
5
))),
(
slice
(
None
),
[[
1
,
2
],
[
2
,
1
]],
slice
(
None
),
[[
0
,
0
],
[
0
,
0
]]),
False
,
),
# Multiple multidimensional indexing with broadcasting
, only supported in obj mode
# Multiple multidimensional indexing with broadcasting
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([[
1
,
2
],
[
2
,
1
]],
[
0
,
0
]),
True
,
),
# multiple multidimensional integer indexing mixed with basic indexing
, only supported in obj mode
# multiple multidimensional integer indexing mixed with basic indexing
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([[
1
,
2
],
[
2
,
1
]],
slice
(
1
,
None
),
[[
0
,
0
],
[
0
,
0
]]),
True
,
),
],
)
@pytest.mark.filterwarnings
(
"error"
)
# Raise if we did not expect objmode to be needed
def
test_AdvancedSubtensor
(
x
,
indices
,
objmode_needed
):
def
test_AdvancedSubtensor
(
x
,
indices
):
"""Test NumPy's advanced indexing in more than one dimension."""
x_pt
=
x
.
type
()
out_pt
=
x_pt
[
indices
]
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedSubtensor
)
with
(
pytest
.
warns
(
UserWarning
,
match
=
"Numba will use object mode to run AdvancedSubtensor's perform method"
,
)
if
objmode_needed
else
contextlib
.
nullcontext
()
):
compare_numba_and_py
(
[
x_pt
],
[
out_pt
],
[
x
.
data
],
# Specialize allows running boolean indexing without falling back to object mode
# Thanks to bool_idx_to_nonzero rewrite
numba_mode
=
numba_mode
.
including
(
"specialize"
),
)
...
...
@@ -323,7 +301,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
@pytest.mark.parametrize
(
"x, y, indices, duplicate_indices,
set_requires_objmode, inc_requires_obj
mode"
,
"x, y, indices, duplicate_indices,
duplicate_indices_require_obj_
mode"
,
[
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -331,7 +309,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
]),
# Mixed basic and vector index
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -343,7 +320,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
),
# Mixed basic and broadcasted vector idx
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -351,7 +327,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
]),
# Mixed basic and vector idx
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -359,7 +334,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
0
,
[
1
,
2
,
2
,
3
]),
# Broadcasted vector index with repeated values
True
,
False
,
True
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -367,21 +341,11 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
0
,
[
1
,
2
,
2
,
3
]),
# Broadcasted vector index with repeated values
True
,
False
,
True
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
-
np
.
arange
(
1
*
4
*
5
)
.
reshape
(
1
,
4
,
5
),
(
np
.
array
([
True
,
False
,
False
])),
# Broadcasted boolean index
False
,
# It shouldn't matter what we set this to, boolean indices cannot be duplicate
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
-
np
.
arange
(
1
*
4
*
5
)
.
reshape
(
1
,
4
,
5
),
(
np
.
array
([
True
,
False
,
False
])),
# Broadcasted boolean index
True
,
# It shouldn't matter what we set this to, boolean indices cannot be duplicate
False
,
False
,
),
...
...
@@ -391,7 +355,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
np
.
eye
(
3
)
.
astype
(
bool
)),
# Boolean index
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
3
*
5
)
.
reshape
((
3
,
3
,
5
)),
...
...
@@ -402,7 +365,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
),
# Boolean index, mixed with basic index
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -410,7 +372,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([
1
,
2
],
[
2
,
3
]),
# 2 vector indices
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -418,7 +379,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
slice
(
None
),
[
1
,
2
],
[
2
,
3
]),
# 2 vector indices
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
6
)
.
reshape
((
3
,
4
,
6
)),
...
...
@@ -426,7 +386,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([
1
,
2
],
[
2
,
3
],
[
4
,
5
]),
# 3 vector indices
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -434,15 +393,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([
1
,
2
],
[
2
,
3
]),
# 2 vector indices
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
rng
.
poisson
(
size
=
(
2
,
4
)),
([
1
,
2
],
slice
(
None
),
[
3
,
4
]),
# Non-consecutive vector indices
False
,
True
,
True
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -453,8 +410,7 @@ def test_AdvancedIncSubtensor1(x, y, indices):
[
3
,
4
],
),
# Mixed double vector index and basic index
False
,
True
,
True
,
False
,
),
(
np
.
arange
(
5
),
...
...
@@ -462,7 +418,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
([[
1
,
2
],
[
2
,
3
]]),
# matrix index
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
5
)
.
reshape
((
3
,
5
)),
...
...
@@ -470,23 +425,20 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
slice
(
1
,
3
),
[[
1
,
2
],
[
2
,
3
]]),
# matrix index, mixed with basic index
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
5
)
.
reshape
((
3
,
5
)),
rng
.
poisson
(
size
=
(
1
,
2
,
2
)),
# Same as before, but Y broadcasts
(
slice
(
1
,
3
),
[[
1
,
2
],
[
2
,
3
]]),
False
,
True
,
True
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
rng
.
poisson
(
size
=
(
2
,
5
)),
([
1
,
1
],
[
2
,
2
]),
# Repeated indices
True
,
False
,
False
,
True
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
...
...
@@ -494,7 +446,6 @@ def test_AdvancedIncSubtensor1(x, y, indices):
(
slice
(
None
),
[[
1
,
2
],
[
2
,
1
]],
[[
2
,
3
],
[
0
,
0
]]),
# 2 matrix indices
False
,
False
,
False
,
),
],
)
...
...
@@ -505,8 +456,7 @@ def test_AdvancedIncSubtensor(
y
,
indices
,
duplicate_indices
,
set_requires_objmode
,
inc_requires_objmode
,
duplicate_indices_require_obj_mode
,
inplace
,
):
# Need rewrite to support certain forms of advanced indexing without object mode
...
...
@@ -518,14 +468,6 @@ def test_AdvancedIncSubtensor(
out_pt
=
set_subtensor
(
x_pt
[
indices
],
y_pt
,
inplace
=
inplace
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
with
(
pytest
.
warns
(
UserWarning
,
match
=
"Numba will use object mode to run AdvancedSetSubtensor's perform method"
,
)
if
set_requires_objmode
else
contextlib
.
nullcontext
()
):
fn
,
_
=
compare_numba_and_py
(
[
x_pt
,
y_pt
],
out_pt
,
[
x
,
y
],
numba_mode
=
mode
,
inplace
=
inplace
)
...
...
@@ -536,16 +478,32 @@ def test_AdvancedIncSubtensor(
fn
(
x
,
y
+
1
)
assert
not
np
.
all
(
x
==
x_orig
)
out_pt
=
inc_subtensor
(
x_pt
[
indices
],
y_pt
,
inplace
=
inplace
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
fn
,
_
=
compare_numba_and_py
(
[
x_pt
,
y_pt
],
out_pt
,
[
x
,
y
],
numba_mode
=
mode
,
inplace
=
inplace
)
if
inplace
:
# Test updates inplace
x_orig
=
x
.
copy
()
fn
(
x
,
y
)
assert
not
np
.
all
(
x
==
x_orig
)
if
duplicate_indices
:
# If inc_subtensor is called with `ignore_duplicates=True`, and it's not one of the cases supported by Numba
# We have to fall back to obj_mode
out_pt
=
inc_subtensor
(
x_pt
[
indices
],
y_pt
,
ignore_duplicates
=
not
duplicate_indices
,
inplace
=
inplac
e
x_pt
[
indices
],
y_pt
,
inplace
=
inplace
,
ignore_duplicates
=
Tru
e
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
with
(
pytest
.
warns
(
UserWarning
,
match
=
"Numba will use object mode to run AdvancedIncSubtensor's perform method"
,
)
if
inc_requires_obj
mode
if
duplicate_indices_require_obj_
mode
else
contextlib
.
nullcontext
()
):
fn
,
_
=
compare_numba_and_py
(
...
...
@@ -556,3 +514,22 @@ def test_AdvancedIncSubtensor(
x_orig
=
x
.
copy
()
fn
(
x
,
y
)
assert
not
np
.
all
(
x
==
x_orig
)
def
test_advanced_indexing_with_newaxis_fallback_obj_mode
():
# This should be automatically solved with https://github.com/pymc-devs/pytensor/issues/1564
# After which we can add these parametrizations to the relevant tests above
x
=
pt
.
matrix
(
"x"
)
out
=
x
[
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]]
with
pytest
.
warns
(
UserWarning
,
match
=
r"Numba will use object mode to run AdvancedSubtensor's perform method"
,
):
compare_numba_and_py
([
x
],
[
out
],
[
np
.
random
.
normal
(
size
=
(
4
,
4
))])
out
=
x
[
None
,
[
0
,
1
,
2
],
[
0
,
1
,
2
]]
.
inc
(
5
)
with
pytest
.
warns
(
UserWarning
,
match
=
r"Numba will use object mode to run AdvancedIncSubtensor's perform method"
,
):
compare_numba_and_py
([
x
],
[
out
],
[
np
.
random
.
normal
(
size
=
(
4
,
4
))])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论