Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
82a57574
提交
82a57574
authored
5月 03, 2024
作者:
Ricardo Vieira
提交者:
Luciano Paz
5月 10, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement vectorize_node dispatch for some forms of AdvancedSubtensor
上级
56637af8
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
152 行增加
和
2 行删除
+152
-2
subtensor.py
pytensor/tensor/subtensor.py
+67
-2
test_subtensor.py
tests/tensor/test_subtensor.py
+85
-0
没有找到文件。
pytensor/tensor/subtensor.py
浏览文件 @
82a57574
...
...
@@ -47,6 +47,7 @@ from pytensor.tensor.type import (
zscalar
,
)
from
pytensor.tensor.type_other
import
NoneConst
,
NoneTypeT
,
SliceType
,
make_slice
from
pytensor.tensor.variable
import
TensorVariable
_logger
=
logging
.
getLogger
(
"pytensor.tensor.subtensor"
)
...
...
@@ -473,6 +474,13 @@ def group_indices(indices):
return
idx_groups
def
_non_contiguous_adv_indexing
(
indices
)
->
bool
:
"""Check if the advanced indexing is non-contiguous (i.e., split by basic indexing)."""
idx_groups
=
group_indices
(
indices
)
# This means that there are at least two groups of advanced indexing separated by basic indexing
return
len
(
idx_groups
)
>
3
or
(
len
(
idx_groups
)
==
3
and
not
idx_groups
[
0
][
0
])
def
indexed_result_shape
(
array_shape
,
indices
,
indices_are_shapes
=
False
):
"""Compute the symbolic shape resulting from `a[indices]` for `a.shape == array_shape`.
...
...
@@ -497,8 +505,7 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
remaining_dims
=
range
(
pytensor
.
tensor
.
basic
.
get_vector_length
(
array_shape
))
idx_groups
=
group_indices
(
indices
)
if
len
(
idx_groups
)
>
3
or
(
len
(
idx_groups
)
==
3
and
not
idx_groups
[
0
][
0
]):
# This means that there are at least two groups of advanced indexing separated by basic indexing
if
_non_contiguous_adv_indexing
(
indices
):
# In this case NumPy places the advanced index groups in the front of the array
# https://numpy.org/devdocs/user/basics.indexing.html#combining-advanced-and-basic-indexing
idx_groups
=
sorted
(
idx_groups
,
key
=
lambda
x
:
x
[
0
])
...
...
@@ -2682,10 +2689,68 @@ class AdvancedSubtensor(Op):
rest
)
@staticmethod
def
non_contiguous_adv_indexing
(
node
:
Apply
)
->
bool
:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-contiguous,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their opriginal position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
Parameters
----------
node : Apply
The node of the AdvancedSubtensor operation.
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
"""
_
,
*
idxs
=
node
.
inputs
return
_non_contiguous_adv_indexing
(
idxs
)
advanced_subtensor
=
AdvancedSubtensor
()
@_vectorize_node.register
(
AdvancedSubtensor
)
def
vectorize_advanced_subtensor
(
op
:
AdvancedSubtensor
,
node
,
*
batch_inputs
):
x
,
*
idxs
=
node
.
inputs
batch_x
,
*
batch_idxs
=
batch_inputs
x_is_batched
=
x
.
type
.
ndim
<
batch_x
.
type
.
ndim
idxs_are_batched
=
any
(
batch_idx
.
type
.
ndim
>
idx
.
type
.
ndim
for
batch_idx
,
idx
in
zip
(
batch_idxs
,
idxs
)
if
isinstance
(
batch_idx
,
TensorVariable
)
)
if
idxs_are_batched
or
(
x_is_batched
and
op
.
non_contiguous_adv_indexing
(
node
)):
# Fallback to Blockwise if idxs are batched or if we have non contiguous advanced indexing
# which would put the indexed results to the left of the batch dimensions!
# TODO: Not all cases must be handled by Blockwise, but the logic is complex
# Blockwise doesn't accept None or Slices types so we raise informative error here
# TODO: Implement these internally, so Blockwise is always a safe fallback
if
any
(
not
isinstance
(
idx
,
TensorVariable
)
for
idx
in
idxs
):
raise
NotImplementedError
(
"Vectorized AdvancedSubtensor with batched indexes or non-contiguous advanced indexing "
"and slices or newaxis is currently not supported."
)
else
:
return
vectorize_node_fallback
(
op
,
node
,
batch_x
,
*
batch_idxs
)
# Otherwise we just need to add None slices for every new batch dim
x_batch_ndim
=
batch_x
.
type
.
ndim
-
x
.
type
.
ndim
empty_slices
=
(
slice
(
None
),)
*
x_batch_ndim
return
op
.
make_node
(
batch_x
,
*
empty_slices
,
*
batch_idxs
)
class
AdvancedIncSubtensor
(
Op
):
"""Increments a subtensor using advanced indexing."""
...
...
tests/tensor/test_subtensor.py
浏览文件 @
82a57574
...
...
@@ -2751,3 +2751,88 @@ def test_vectorize_subtensor_without_batch_indices():
vectorize_pt
(
x_test
,
start_test
),
vectorize_np
(
x_test
,
start_test
),
)
@pytest.mark.parametrize
(
"core_idx_fn, signature, x_shape, idx_shape, uses_blockwise"
,
[
# Core case
((
lambda
x
,
idx
:
x
[:,
idx
,
:]),
"(7,5,3),(2)->(7,2,3)"
,
(
7
,
5
,
3
),
(
2
,),
False
),
# Batched x, core idx
(
(
lambda
x
,
idx
:
x
[:,
idx
,
:]),
"(7,5,3),(2)->(7,2,3)"
,
(
11
,
7
,
5
,
3
),
(
2
,),
False
,
),
(
(
lambda
x
,
idx
:
x
[
idx
,
None
]),
"(5,7,3),(2)->(2,1,7,3)"
,
(
11
,
5
,
7
,
3
),
(
2
,),
False
,
),
# (this is currently failing because PyTensor tries to vectorize the slice(None) operation,
# due to the exact same None constant being used there and in the np.newaxis)
pytest
.
param
(
(
lambda
x
,
idx
:
x
[:,
idx
,
None
]),
"(7,5,3),(2)->(7,2,1,3)"
,
(
11
,
7
,
5
,
3
),
(
2
,),
False
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
),
),
(
(
lambda
x
,
idx
:
x
[:,
idx
,
idx
,
:]),
"(7,5,5,3),(2)->(7,2,3)"
,
(
11
,
7
,
5
,
5
,
3
),
(
2
,),
False
,
),
# (not supported, because fallback Blocwise can't handle slices)
pytest
.
param
(
(
lambda
x
,
idx
:
x
[:,
idx
,
:,
idx
]),
"(7,5,3,5),(2)->(2,7,3)"
,
(
11
,
7
,
5
,
3
,
5
),
(
2
,),
True
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
),
),
# Core x, batched idx
((
lambda
x
,
idx
:
x
[
idx
]),
"(t1),(idx)->(tx)"
,
(
7
,),
(
11
,
2
),
True
),
# Batched x, batched idx
((
lambda
x
,
idx
:
x
[
idx
]),
"(t1),(idx)->(tx)"
,
(
11
,
7
),
(
11
,
2
),
True
),
# (not supported, because fallback Blocwise can't handle slices)
pytest
.
param
(
(
lambda
x
,
idx
:
x
[:,
idx
,
:]),
"(t1,t2,t3),(idx)->(t1,tx,t3)"
,
(
11
,
7
,
5
,
3
),
(
11
,
2
),
True
,
marks
=
pytest
.
mark
.
xfail
(
raises
=
NotImplementedError
),
),
],
)
def
test_vectorize_adv_subtensor
(
core_idx_fn
,
signature
,
x_shape
,
idx_shape
,
uses_blockwise
):
x
=
tensor
(
shape
=
x_shape
,
dtype
=
"float64"
)
idx
=
tensor
(
shape
=
idx_shape
,
dtype
=
"int64"
)
vectorize_pt
=
function
(
[
x
,
idx
],
vectorize
(
core_idx_fn
,
signature
=
signature
)(
x
,
idx
)
)
has_blockwise
=
any
(
isinstance
(
node
.
op
,
Blockwise
)
for
node
in
vectorize_pt
.
maker
.
fgraph
.
apply_nodes
)
assert
has_blockwise
==
uses_blockwise
x_test
=
np
.
random
.
normal
(
size
=
x
.
type
.
shape
)
.
astype
(
x
.
type
.
dtype
)
# Idx dimension should be length 5
idx_test
=
np
.
random
.
randint
(
0
,
5
,
size
=
idx
.
type
.
shape
)
vectorize_np
=
np
.
vectorize
(
core_idx_fn
,
signature
=
signature
)
np
.
testing
.
assert_allclose
(
vectorize_pt
(
x_test
,
idx_test
),
vectorize_np
(
x_test
,
idx_test
),
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论