Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d8501d14
提交
d8501d14
authored
12月 07, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
12月 14, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Numba AdvancedIndexing: Complete support for integer (and mixed basic) advanced indexing
When default `ignore_updates=True` for inc_subtensor, and boolean indices were rewritten during specialize
上级
fe10f960
全部展开
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
19 行增加
和
179 行删除
+19
-179
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+0
-0
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+19
-179
test_subtensor.py
tests/link/numba/test_subtensor.py
+0
-0
没有找到文件。
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
d8501d14
差异被折叠。
点击展开。
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
d8501d14
...
@@ -83,7 +83,7 @@ from pytensor.tensor.subtensor import (
...
@@ -83,7 +83,7 @@ from pytensor.tensor.subtensor import (
inc_subtensor
,
inc_subtensor
,
indices_from_subtensor
,
indices_from_subtensor
,
)
)
from
pytensor.tensor.type
import
TensorType
,
integer_dtypes
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
@@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node):
...
@@ -1744,205 +1744,45 @@ def local_blockwise_inc_subtensor(fgraph, node):
@node_rewriter
(
tracks
=
[
AdvancedSubtensor
,
AdvancedIncSubtensor
])
@node_rewriter
(
tracks
=
[
AdvancedSubtensor
,
AdvancedIncSubtensor
])
def
ravel_multidimensional_bool_idx
(
fgraph
,
node
):
def
bool_idx_to_nonzero
(
fgraph
,
node
):
"""Convert
multidimensional boolean indexing into equivalent vector boolean index, supported by Numba
"""Convert
boolean indexing into equivalent vector boolean index, supported by our dispatch
x[eye(3, dtype=bool)] -> x.ravel()[eye(3).ravel()]
x[1:, eye(3, dtype=bool), 1:] -> x[1:, *eye(3).nonzero()]
x[eye(3, dtype=bool)].set(y) -> x.ravel()[eye(3).ravel()].set(y).reshape(x.shape)
"""
"""
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
x
,
*
idxs
=
node
.
inputs
x
,
*
idxs
=
node
.
inputs
else
:
else
:
x
,
y
,
*
idxs
=
node
.
inputs
x
,
y
,
*
idxs
=
node
.
inputs
if
any
(
bool_pos
=
{
(
i
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
type
.
dtype
in
integer_dtypes
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
)
for
idx
in
idxs
):
# Get out if there are any other advanced indexes or np.newaxis
return
None
bool_idxs
=
[
(
i
,
idx
)
for
i
,
idx
in
enumerate
(
idxs
)
for
i
,
idx
in
enumerate
(
idxs
)
if
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
dtype
==
"bool"
)
if
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
dtype
==
"bool"
)
]
}
if
len
(
bool_idxs
)
!=
1
:
# Get out if there are no or multiple boolean idxs
return
None
[(
bool_idx_pos
,
bool_idx
)]
=
bool_idxs
if
not
bool_pos
:
bool_idx_ndim
=
bool_idx
.
type
.
ndim
if
bool_idx
.
type
.
ndim
<
2
:
# No need to do anything if it's a vector or scalar, as it's already supported by Numba
return
None
return
None
x_shape
=
x
.
shape
new_idxs
=
[]
raveled_x
=
x
.
reshape
(
for
i
,
idx
in
enumerate
(
idxs
):
(
*
x_shape
[:
bool_idx_pos
],
-
1
,
*
x_shape
[
bool_idx_pos
+
bool_idx_ndim
:])
if
i
in
bool_pos
:
)
new_idxs
.
extend
(
idx
.
nonzero
())
raveled_bool_idx
=
bool_idx
.
ravel
()
new_idxs
=
list
(
idxs
)
new_idxs
[
bool_idx_pos
]
=
raveled_bool_idx
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
new_out
=
node
.
op
(
raveled_x
,
*
new_idxs
)
else
:
else
:
# The dimensions of y that correspond to the boolean indices
new_idxs
.
append
(
idx
)
# must already be raveled in the original graph, so we don't need to do anything to it
new_out
=
node
.
op
(
raveled_x
,
y
,
*
new_idxs
)
# But we must reshape the output to math the original shape
new_out
=
new_out
.
reshape
(
x_shape
)
return
[
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)]
@node_rewriter
(
tracks
=
[
AdvancedSubtensor
,
AdvancedIncSubtensor
])
def
ravel_multidimensional_int_idx
(
fgraph
,
node
):
"""Convert multidimensional integer indexing into equivalent consecutive vector integer index,
supported by Numba or by our specialized dispatchers
x[eye(3)] -> x[eye(3).ravel()].reshape((3, 3))
NOTE: This is very similar to the rewrite `local_replace_AdvancedSubtensor` except it also handles non-full slices
x[eye(3), 2:] -> x[eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
It also handles multiple integer indices, but only if they don't broadcast
x[eye(3,), 2:, eye(3)] -> x[eye(3).ravel(), eye(3).ravel(), 2:].reshape((3, 3, ...)), where ... are the remaining output shapes
Also handles AdvancedIncSubtensor, but only if the advanced indices are consecutive and neither indices nor y broadcast
x[eye(3), 2:].set(y) -> x[eye(3).ravel(), 2:].set(y.reshape(-1, y.shape[1:]))
"""
op
=
node
.
op
non_consecutive_adv_indexing
=
op
.
non_consecutive_adv_indexing
(
node
)
is_inc_subtensor
=
isinstance
(
op
,
AdvancedIncSubtensor
)
if
is_inc_subtensor
:
x
,
y
,
*
idxs
=
node
.
inputs
# Inc/SetSubtensor is harder to reason about due to y
# We get out if it's broadcasting or if the advanced indices are non-consecutive
if
non_consecutive_adv_indexing
or
(
y
.
type
.
broadcastable
!=
x
[
tuple
(
idxs
)]
.
type
.
broadcastable
):
return
None
else
:
x
,
*
idxs
=
node
.
inputs
if
any
(
(
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
type
.
dtype
==
"bool"
)
or
isinstance
(
idx
.
type
,
NoneTypeT
)
)
for
idx
in
idxs
):
# Get out if there are any other advanced indices or np.newaxis
return
None
int_idxs_and_pos
=
[
(
i
,
idx
)
for
i
,
idx
in
enumerate
(
idxs
)
if
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
dtype
in
integer_dtypes
)
]
if
not
int_idxs_and_pos
:
return
None
int_idxs_pos
,
int_idxs
=
zip
(
*
int_idxs_and_pos
,
strict
=
False
)
# strict=False because by definition it's true
first_int_idx_pos
=
int_idxs_pos
[
0
]
first_int_idx
=
int_idxs
[
0
]
first_int_idx_bcast
=
first_int_idx
.
type
.
broadcastable
if
any
(
int_idx
.
type
.
broadcastable
!=
first_int_idx_bcast
for
int_idx
in
int_idxs
):
# We don't have a view-only broadcasting operation
# Explicitly broadcasting the indices can incur a memory / copy overhead
return
None
int_idxs_ndim
=
len
(
first_int_idx_bcast
)
if
(
int_idxs_ndim
==
0
):
# This should be a basic indexing operation, rewrite elsewhere
return
None
int_idxs_need_raveling
=
int_idxs_ndim
>
1
if
not
(
int_idxs_need_raveling
or
non_consecutive_adv_indexing
):
# Numba or our dispatch natively supports consecutive vector indices, nothing needs to be done
return
None
# Reorder non-consecutive indices
if
non_consecutive_adv_indexing
:
assert
not
is_inc_subtensor
# Sanity check that we got out if this was the case
# This case works as if all the advanced indices were on the front
transposition
=
list
(
int_idxs_pos
)
+
[
i
for
i
in
range
(
len
(
idxs
))
if
i
not
in
int_idxs_pos
]
idxs
=
tuple
(
idxs
[
a
]
for
a
in
transposition
)
x
=
x
.
transpose
(
transposition
)
first_int_idx_pos
=
0
del
int_idxs_pos
# Make sure they are not wrongly used
# Ravel multidimensional indices
if
int_idxs_need_raveling
:
idxs
=
list
(
idxs
)
for
idx_pos
,
int_idx
in
enumerate
(
int_idxs
,
start
=
first_int_idx_pos
):
idxs
[
idx_pos
]
=
int_idx
.
ravel
()
# Index with reordered and/or raveled indices
new_subtensor
=
x
[
tuple
(
idxs
)]
if
is_inc_subtensor
:
y_shape
=
tuple
(
y
.
shape
)
y_raveled_shape
=
(
*
y_shape
[:
first_int_idx_pos
],
-
1
,
*
y_shape
[
first_int_idx_pos
+
int_idxs_ndim
:],
)
y_raveled
=
y
.
reshape
(
y_raveled_shape
)
new_out
=
inc_subtensor
(
new_subtensor
,
y_raveled
,
set_instead_of_inc
=
op
.
set_instead_of_inc
,
ignore_duplicates
=
op
.
ignore_duplicates
,
inplace
=
op
.
inplace
,
)
if
isinstance
(
node
.
op
,
AdvancedSubtensor
):
new_out
=
node
.
op
(
x
,
*
new_idxs
)
else
:
else
:
# Unravel advanced indexing dimensions
new_out
=
node
.
op
(
x
,
y
,
*
new_idxs
)
raveled_shape
=
tuple
(
new_subtensor
.
shape
)
unraveled_shape
=
(
*
raveled_shape
[:
first_int_idx_pos
],
*
first_int_idx
.
shape
,
*
raveled_shape
[
first_int_idx_pos
+
1
:],
)
new_out
=
new_subtensor
.
reshape
(
unraveled_shape
)
return
[
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)]
return
[
copy_stack_trace
(
node
.
outputs
[
0
],
new_out
)]
optdb
[
"specialize"
]
.
register
(
optdb
[
"specialize"
]
.
register
(
ravel_multidimensional_bool_idx
.
__name__
,
bool_idx_to_nonzero
.
__name__
,
ravel_multidimensional_bool_idx
,
bool_idx_to_nonzero
,
"numba"
,
use_db_name_as_tag
=
False
,
# Not included if only "specialize" is requested
)
optdb
[
"specialize"
]
.
register
(
ravel_multidimensional_int_idx
.
__name__
,
ravel_multidimensional_int_idx
,
"numba"
,
"numba"
,
"shape_unsafe"
,
# It can mask invalid mask sizes
use_db_name_as_tag
=
False
,
# Not included if only "specialize" is requested
use_db_name_as_tag
=
False
,
# Not included if only "specialize" is requested
)
)
...
...
tests/link/numba/test_subtensor.py
浏览文件 @
d8501d14
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论