Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bd281be6
提交
bd281be6
authored
11月 26, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 28, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support consecutive integer vector indexing in Numba backend
上级
b8356ff9
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
300 行增加
和
67 行删除
+300
-67
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+145
-6
subtensor.py
pytensor/tensor/subtensor.py
+25
-0
test_basic.py
tests/link/numba/test_basic.py
+12
-5
test_subtensor.py
tests/link/numba/test_subtensor.py
+118
-56
没有找到文件。
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
bd281be6
...
@@ -5,6 +5,7 @@ from pytensor.link.numba.dispatch import numba_funcify
...
@@ -5,6 +5,7 @@ from pytensor.link.numba.dispatch import numba_funcify
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
,
numba_njit
from
pytensor.link.numba.dispatch.basic
import
generate_fallback_impl
,
numba_njit
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.link.utils
import
compile_function_src
,
unique_name_generator
from
pytensor.tensor
import
TensorType
from
pytensor.tensor
import
TensorType
from
pytensor.tensor.rewriting.subtensor
import
is_full_slice
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
...
@@ -13,6 +14,7 @@ from pytensor.tensor.subtensor import (
...
@@ -13,6 +14,7 @@ from pytensor.tensor.subtensor import (
IncSubtensor
,
IncSubtensor
,
Subtensor
,
Subtensor
,
)
)
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
@numba_funcify.register
(
Subtensor
)
@numba_funcify.register
(
Subtensor
)
...
@@ -104,18 +106,73 @@ def {function_name}({", ".join(input_names)}):
...
@@ -104,18 +106,73 @@ def {function_name}({", ".join(input_names)}):
@numba_funcify.register
(
AdvancedSubtensor
)
@numba_funcify.register
(
AdvancedSubtensor
)
@numba_funcify.register
(
AdvancedIncSubtensor
)
@numba_funcify.register
(
AdvancedIncSubtensor
)
def
numba_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
def
numba_funcify_AdvancedSubtensor
(
op
,
node
,
**
kwargs
):
idxs
=
node
.
inputs
[
1
:]
if
isinstance
(
op
,
AdvancedSubtensor
)
else
node
.
inputs
[
2
:]
if
isinstance
(
op
,
AdvancedSubtensor
):
adv_idxs_dims
=
[
x
,
y
,
idxs
=
node
.
inputs
[
0
],
None
,
node
.
inputs
[
1
:]
idx
.
type
.
ndim
else
:
x
,
y
,
*
idxs
=
node
.
inputs
basic_idxs
=
[
idx
for
idx
in
idxs
for
idx
in
idxs
if
(
isinstance
(
idx
.
type
,
TensorType
)
and
idx
.
type
.
ndim
>
0
)
if
(
isinstance
(
idx
.
type
,
NoneTypeT
)
or
(
isinstance
(
idx
.
type
,
SliceType
)
and
not
is_full_slice
(
idx
))
)
]
adv_idxs
=
[
{
"axis"
:
i
,
"dtype"
:
idx
.
type
.
dtype
,
"bcast"
:
idx
.
type
.
broadcastable
,
"ndim"
:
idx
.
type
.
ndim
,
}
for
i
,
idx
in
enumerate
(
idxs
)
if
isinstance
(
idx
.
type
,
TensorType
)
]
]
# Special case for consecutive consecutive vector indices
def
broadcasted_to
(
x_bcast
:
tuple
[
bool
,
...
],
to_bcast
:
tuple
[
bool
,
...
]):
# Check that x is not broadcasted to y based on broadcastable info
if
len
(
x_bcast
)
<
len
(
to_bcast
):
return
True
for
x_bcast_dim
,
to_bcast_dim
in
zip
(
x_bcast
,
to_bcast
,
strict
=
True
):
if
x_bcast_dim
and
not
to_bcast_dim
:
return
True
return
False
# Special implementation for consecutive integer vector indices
if
(
not
basic_idxs
and
len
(
adv_idxs
)
>=
2
# Must be integer vectors
# Todo: we could allow shape=(1,) if this is the shape of x
and
all
(
(
adv_idx
[
"bcast"
]
==
(
False
,)
and
adv_idx
[
"dtype"
]
!=
"bool"
)
for
adv_idx
in
adv_idxs
)
# Must be consecutive
and
not
op
.
non_contiguous_adv_indexing
(
node
)
# y in set/inc_subtensor cannot be broadcasted
and
(
y
is
None
or
not
broadcasted_to
(
y
.
type
.
broadcastable
,
(
x
.
type
.
broadcastable
[:
adv_idxs
[
0
][
"axis"
]]
+
x
.
type
.
broadcastable
[
adv_idxs
[
-
1
][
"axis"
]
:]
),
)
)
):
return
numba_funcify_multiple_integer_vector_indexing
(
op
,
node
,
**
kwargs
)
# Other cases not natively supported by Numba (fallback to obj-mode)
if
(
if
(
# Numba does not support indexes with more than one dimension
# Numba does not support indexes with more than one dimension
any
(
idx
[
"ndim"
]
>
1
for
idx
in
adv_idxs
)
# Nor multiple vector indexes
# Nor multiple vector indexes
(
len
(
adv_idxs_dims
)
>
1
or
adv_idxs_dims
[
0
]
>
1
)
or
sum
(
idx
[
"ndim"
]
>
0
for
idx
in
adv_idxs
)
>
1
# The default
index
implementation does not handle duplicate indices correctly
# The default
PyTensor
implementation does not handle duplicate indices correctly
or
(
or
(
isinstance
(
op
,
AdvancedIncSubtensor
)
isinstance
(
op
,
AdvancedIncSubtensor
)
and
not
op
.
set_instead_of_inc
and
not
op
.
set_instead_of_inc
...
@@ -124,9 +181,91 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
...
@@ -124,9 +181,91 @@ def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
):
):
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
return
generate_fallback_impl
(
op
,
node
,
**
kwargs
)
# What's left should all be supported natively by numba
return
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
)
return
numba_funcify_default_subtensor
(
op
,
node
,
**
kwargs
)
def
numba_funcify_multiple_integer_vector_indexing
(
op
:
AdvancedSubtensor
|
AdvancedIncSubtensor
,
node
,
**
kwargs
):
# Special-case implementation for multiple consecutive vector integer indices (and set/incsubtensor)
if
isinstance
(
op
,
AdvancedSubtensor
):
y
,
idxs
=
None
,
node
.
inputs
[
1
:]
else
:
y
,
*
idxs
=
node
.
inputs
[
1
:]
first_axis
=
next
(
i
for
i
,
idx
in
enumerate
(
idxs
)
if
isinstance
(
idx
.
type
,
TensorType
)
)
try
:
after_last_axis
=
next
(
i
for
i
,
idx
in
enumerate
(
idxs
[
first_axis
:],
start
=
first_axis
)
if
not
isinstance
(
idx
.
type
,
TensorType
)
)
except
StopIteration
:
after_last_axis
=
len
(
idxs
)
if
isinstance
(
op
,
AdvancedSubtensor
):
@numba_njit
def
advanced_subtensor_multiple_vector
(
x
,
*
idxs
):
none_slices
=
idxs
[:
first_axis
]
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
idx_shape
=
vec_idxs
[
0
]
.
shape
shape_bef
=
x_shape
[:
first_axis
]
shape_aft
=
x_shape
[
after_last_axis
:]
out_shape
=
(
*
shape_bef
,
*
idx_shape
,
*
shape_aft
)
out_buffer
=
np
.
empty
(
out_shape
,
dtype
=
x
.
dtype
)
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
# noqa: B905
out_buffer
[(
*
none_slices
,
i
)]
=
x
[(
*
none_slices
,
*
scalar_idxs
)]
return
out_buffer
return
advanced_subtensor_multiple_vector
elif
op
.
set_instead_of_inc
:
inplace
=
op
.
inplace
@numba_njit
def
advanced_set_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
if
inplace
:
out
=
x
else
:
out
=
x
.
copy
()
for
outer
in
np
.
ndindex
(
x_shape
[:
first_axis
]):
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
# noqa: B905
out
[(
*
outer
,
*
scalar_idxs
)]
=
y
[(
*
outer
,
i
)]
return
out
return
advanced_set_subtensor_multiple_vector
else
:
inplace
=
op
.
inplace
@numba_njit
def
advanced_inc_subtensor_multiple_vector
(
x
,
y
,
*
idxs
):
vec_idxs
=
idxs
[
first_axis
:
after_last_axis
]
x_shape
=
x
.
shape
if
inplace
:
out
=
x
else
:
out
=
x
.
copy
()
for
outer
in
np
.
ndindex
(
x_shape
[:
first_axis
]):
for
i
,
scalar_idxs
in
enumerate
(
zip
(
*
vec_idxs
)):
# noqa: B905
out
[(
*
outer
,
*
scalar_idxs
)]
+=
y
[(
*
outer
,
i
)]
return
out
return
advanced_inc_subtensor_multiple_vector
@numba_funcify.register
(
AdvancedIncSubtensor1
)
@numba_funcify.register
(
AdvancedIncSubtensor1
)
def
numba_funcify_AdvancedIncSubtensor1
(
op
,
node
,
**
kwargs
):
def
numba_funcify_AdvancedIncSubtensor1
(
op
,
node
,
**
kwargs
):
inplace
=
op
.
inplace
inplace
=
op
.
inplace
...
...
pytensor/tensor/subtensor.py
浏览文件 @
bd281be6
...
@@ -2937,6 +2937,31 @@ class AdvancedIncSubtensor(Op):
...
@@ -2937,6 +2937,31 @@ class AdvancedIncSubtensor(Op):
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
gy
=
_sum_grad_over_bcasted_dims
(
y
,
gy
)
return
[
gx
,
gy
]
+
[
DisconnectedType
()()
for
_
in
idxs
]
return
[
gx
,
gy
]
+
[
DisconnectedType
()()
for
_
in
idxs
]
@staticmethod
def
non_contiguous_adv_indexing
(
node
:
Apply
)
->
bool
:
"""
Check if the advanced indexing is non-contiguous (i.e. interrupted by basic indexing).
This function checks if the advanced indexing is non-contiguous,
in which case the advanced index dimensions are placed on the left of the
output array, regardless of their opriginal position.
See: https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
Parameters
----------
node : Apply
The node of the AdvancedSubtensor operation.
Returns
-------
bool
True if the advanced indexing is non-contiguous, False otherwise.
"""
_
,
_
,
*
idxs
=
node
.
inputs
return
_non_contiguous_adv_indexing
(
idxs
)
advanced_inc_subtensor
=
AdvancedIncSubtensor
()
advanced_inc_subtensor
=
AdvancedIncSubtensor
()
advanced_set_subtensor
=
AdvancedIncSubtensor
(
set_instead_of_inc
=
True
)
advanced_set_subtensor
=
AdvancedIncSubtensor
(
set_instead_of_inc
=
True
)
...
...
tests/link/numba/test_basic.py
浏览文件 @
bd281be6
...
@@ -228,9 +228,11 @@ def compare_numba_and_py(
...
@@ -228,9 +228,11 @@ def compare_numba_and_py(
fgraph
:
FunctionGraph
|
tuple
[
Sequence
[
"Variable"
],
Sequence
[
"Variable"
]],
fgraph
:
FunctionGraph
|
tuple
[
Sequence
[
"Variable"
],
Sequence
[
"Variable"
]],
inputs
:
Sequence
[
"TensorLike"
],
inputs
:
Sequence
[
"TensorLike"
],
assert_fn
:
Callable
|
None
=
None
,
assert_fn
:
Callable
|
None
=
None
,
*
,
numba_mode
=
numba_mode
,
numba_mode
=
numba_mode
,
py_mode
=
py_mode
,
py_mode
=
py_mode
,
updates
=
None
,
updates
=
None
,
inplace
:
bool
=
False
,
eval_obj_mode
:
bool
=
True
,
eval_obj_mode
:
bool
=
True
,
)
->
tuple
[
Callable
,
Any
]:
)
->
tuple
[
Callable
,
Any
]:
"""Function to compare python graph output and Numba compiled output for testing equality
"""Function to compare python graph output and Numba compiled output for testing equality
...
@@ -276,7 +278,14 @@ def compare_numba_and_py(
...
@@ -276,7 +278,14 @@ def compare_numba_and_py(
pytensor_py_fn
=
function
(
pytensor_py_fn
=
function
(
fn_inputs
,
fn_outputs
,
mode
=
py_mode
,
accept_inplace
=
True
,
updates
=
updates
fn_inputs
,
fn_outputs
,
mode
=
py_mode
,
accept_inplace
=
True
,
updates
=
updates
)
)
py_res
=
pytensor_py_fn
(
*
inputs
)
test_inputs
=
(
inp
.
copy
()
for
inp
in
inputs
)
if
inplace
else
inputs
py_res
=
pytensor_py_fn
(
*
test_inputs
)
# Get some coverage (and catch errors in python mode before unreadable numba ones)
if
eval_obj_mode
:
test_inputs
=
(
inp
.
copy
()
for
inp
in
inputs
)
if
inplace
else
inputs
eval_python_only
(
fn_inputs
,
fn_outputs
,
test_inputs
,
mode
=
numba_mode
)
pytensor_numba_fn
=
function
(
pytensor_numba_fn
=
function
(
fn_inputs
,
fn_inputs
,
...
@@ -285,11 +294,9 @@ def compare_numba_and_py(
...
@@ -285,11 +294,9 @@ def compare_numba_and_py(
accept_inplace
=
True
,
accept_inplace
=
True
,
updates
=
updates
,
updates
=
updates
,
)
)
numba_res
=
pytensor_numba_fn
(
*
inputs
)
# Get some coverage
test_inputs
=
(
inp
.
copy
()
for
inp
in
inputs
)
if
inplace
else
inputs
if
eval_obj_mode
:
numba_res
=
pytensor_numba_fn
(
*
test_inputs
)
eval_python_only
(
fn_inputs
,
fn_outputs
,
inputs
,
mode
=
numba_mode
)
if
len
(
fn_outputs
)
>
1
:
if
len
(
fn_outputs
)
>
1
:
for
j
,
p
in
zip
(
numba_res
,
py_res
,
strict
=
True
):
for
j
,
p
in
zip
(
numba_res
,
py_res
,
strict
=
True
):
...
...
tests/link/numba/test_subtensor.py
浏览文件 @
bd281be6
...
@@ -85,7 +85,11 @@ def test_AdvancedSubtensor1_out_of_bounds():
...
@@ -85,7 +85,11 @@ def test_AdvancedSubtensor1_out_of_bounds():
(
np
.
array
([
True
,
False
,
False
])),
(
np
.
array
([
True
,
False
,
False
])),
False
,
False
,
),
),
(
pt
.
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([
1
,
2
],
[
2
,
3
]),
True
),
(
pt
.
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([
1
,
2
],
[
2
,
3
]),
False
,
),
# Single multidimensional indexing (supported after specialization rewrites)
# Single multidimensional indexing (supported after specialization rewrites)
(
(
as_tensor
(
np
.
arange
(
3
*
3
)
.
reshape
((
3
,
3
))),
as_tensor
(
np
.
arange
(
3
*
3
)
.
reshape
((
3
,
3
))),
...
@@ -117,17 +121,23 @@ def test_AdvancedSubtensor1_out_of_bounds():
...
@@ -117,17 +121,23 @@ def test_AdvancedSubtensor1_out_of_bounds():
(
slice
(
2
,
None
),
np
.
eye
(
3
)
.
astype
(
bool
)),
(
slice
(
2
,
None
),
np
.
eye
(
3
)
.
astype
(
bool
)),
False
,
False
,
),
),
# Multiple advanced indexing, only supported in obj mode
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
(
slice
(
None
),
[
1
,
2
],
[
3
,
4
]),
(
slice
(
None
),
[
1
,
2
],
[
3
,
4
]),
True
,
False
,
),
(
as_tensor
(
np
.
arange
(
3
*
5
*
7
)
.
reshape
((
3
,
5
,
7
))),
([
1
,
2
],
[
3
,
4
],
[
5
,
6
]),
False
,
),
),
# Non-contiguous vector indexing, only supported in obj mode
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([
1
,
2
],
slice
(
None
),
[
3
,
4
]),
([
1
,
2
],
slice
(
None
),
[
3
,
4
]),
True
,
True
,
),
),
# >1d vector indexing, only supported in obj mode
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
([[
1
,
2
],
[
2
,
1
]],
[
0
,
0
]),
([[
1
,
2
],
[
2
,
1
]],
[
0
,
0
]),
...
@@ -135,7 +145,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
...
@@ -135,7 +145,7 @@ def test_AdvancedSubtensor1_out_of_bounds():
),
),
],
],
)
)
@pytest.mark.filterwarnings
(
"error"
)
@pytest.mark.filterwarnings
(
"error"
)
# Raise if we did not expect objmode to be needed
def
test_AdvancedSubtensor
(
x
,
indices
,
objmode_needed
):
def
test_AdvancedSubtensor
(
x
,
indices
,
objmode_needed
):
"""Test NumPy's advanced indexing in more than one dimension."""
"""Test NumPy's advanced indexing in more than one dimension."""
x_pt
=
x
.
type
()
x_pt
=
x
.
type
()
...
@@ -268,94 +278,151 @@ def test_AdvancedIncSubtensor1(x, y, indices):
...
@@ -268,94 +278,151 @@ def test_AdvancedIncSubtensor1(x, y, indices):
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode"
,
"x, y, indices, duplicate_indices, set_requires_objmode, inc_requires_objmode"
,
[
[
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)
)),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
-
np
.
arange
(
3
*
5
)
.
reshape
(
3
,
5
),
-
np
.
arange
(
3
*
5
)
.
reshape
(
3
,
5
),
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
]),
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
]),
# Mixed basic and vector index
False
,
False
,
False
,
False
,
False
,
False
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
-
99
,
np
.
array
(
-
99
),
# Broadcasted value
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
],
-
1
),
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
],
-
1
,
),
# Mixed basic and broadcasted vector idx
False
,
False
,
False
,
False
,
False
,
False
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)
)),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
-
99
,
# Broadcasted value
np
.
array
(
-
99
)
,
# Broadcasted value
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
]),
(
slice
(
None
,
None
,
2
),
[
1
,
2
,
3
]),
# Mixed basic and vector idx
False
,
False
,
False
,
False
,
False
,
False
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)
)),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
-
np
.
arange
(
4
*
5
)
.
reshape
(
4
,
5
),
-
np
.
arange
(
4
*
5
)
.
reshape
(
4
,
5
),
(
0
,
[
1
,
2
,
2
,
3
]),
(
0
,
[
1
,
2
,
2
,
3
]),
# Broadcasted vector index
True
,
True
,
False
,
False
,
True
,
True
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)
)),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
[
-
99
],
# Broadcs
asted value
np
.
array
([
-
99
]),
# Broadc
asted value
(
0
,
[
1
,
2
,
2
,
3
]),
(
0
,
[
1
,
2
,
2
,
3
]),
# Broadcasted vector index
True
,
True
,
False
,
False
,
True
,
True
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)
)),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
-
np
.
arange
(
1
*
4
*
5
)
.
reshape
(
1
,
4
,
5
),
-
np
.
arange
(
1
*
4
*
5
)
.
reshape
(
1
,
4
,
5
),
(
np
.
array
([
True
,
False
,
False
])),
(
np
.
array
([
True
,
False
,
False
])),
# Broadcasted boolean index
False
,
False
,
False
,
False
,
False
,
False
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
3
)
.
reshape
((
3
,
3
)
)),
np
.
arange
(
3
*
3
)
.
reshape
((
3
,
3
)),
-
np
.
arange
(
3
),
-
np
.
arange
(
3
),
(
np
.
eye
(
3
)
.
astype
(
bool
)),
(
np
.
eye
(
3
)
.
astype
(
bool
)),
# Boolean index
False
,
False
,
True
,
True
,
True
,
True
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
as_tensor
(
rng
.
poisson
(
size
=
(
2
,
5
))),
rng
.
poisson
(
size
=
(
2
,
5
)),
([
1
,
2
],
[
2
,
3
]),
([
1
,
2
],
[
2
,
3
]),
# 2 vector indices
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
rng
.
poisson
(
size
=
(
3
,
2
)),
(
slice
(
None
),
[
1
,
2
],
[
2
,
3
]),
# 2 vector indices
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
6
)
.
reshape
((
3
,
4
,
6
)),
rng
.
poisson
(
size
=
(
2
,)),
([
1
,
2
],
[
2
,
3
],
[
4
,
5
]),
# 3 vector indices
False
,
False
,
False
,
),
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
np
.
array
(
-
99
),
# Broadcasted value
([
1
,
2
],
[
2
,
3
]),
# 2 vector indices
False
,
False
,
True
,
True
,
True
,
True
,
),
),
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)
)),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
as_tensor
(
rng
.
poisson
(
size
=
(
2
,
4
)
)),
rng
.
poisson
(
size
=
(
2
,
4
)),
([
1
,
2
],
slice
(
None
),
[
3
,
4
]),
([
1
,
2
],
slice
(
None
),
[
3
,
4
]),
# Non-contiguous vector indices
False
,
False
,
True
,
True
,
True
,
True
,
),
),
pytest
.
param
(
(
as_tensor
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
))),
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
as_tensor
(
rng
.
poisson
(
size
=
(
2
,
5
))),
rng
.
poisson
(
size
=
(
2
,
2
)),
([
1
,
1
],
[
2
,
2
]),
(
slice
(
1
,
None
),
[
1
,
2
],
[
3
,
4
],
),
# Mixed double vector index and basic index
False
,
True
,
True
,
),
(
np
.
arange
(
5
),
rng
.
poisson
(
size
=
(
2
,
2
)),
([[
1
,
2
],
[
2
,
3
]]),
# matrix indices
False
,
False
,
True
,
True
,
True
,
True
,
),
),
pytest
.
param
(
np
.
arange
(
3
*
4
*
5
)
.
reshape
((
3
,
4
,
5
)),
rng
.
poisson
(
size
=
(
2
,
5
)),
([
1
,
1
],
[
2
,
2
]),
# Repeated indices
True
,
False
,
False
,
),
],
],
)
)
@pytest.mark.filterwarnings
(
"error"
)
@pytest.mark.parametrize
(
"inplace"
,
(
False
,
True
))
@pytest.mark.filterwarnings
(
"error"
)
# Raise if we did not expect objmode to be needed
def
test_AdvancedIncSubtensor
(
def
test_AdvancedIncSubtensor
(
x
,
y
,
indices
,
duplicate_indices
,
set_requires_objmode
,
inc_requires_objmode
x
,
y
,
indices
,
duplicate_indices
,
set_requires_objmode
,
inc_requires_objmode
,
inplace
,
):
):
out_pt
=
set_subtensor
(
x
[
indices
],
y
)
x_pt
=
pt
.
as_tensor
(
x
)
.
type
(
"x"
)
y_pt
=
pt
.
as_tensor
(
y
)
.
type
(
"y"
)
out_pt
=
set_subtensor
(
x_pt
[
indices
],
y_pt
,
inplace
=
inplace
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
with
(
with
(
pytest
.
warns
(
pytest
.
warns
(
...
@@ -365,11 +432,18 @@ def test_AdvancedIncSubtensor(
...
@@ -365,11 +432,18 @@ def test_AdvancedIncSubtensor(
if
set_requires_objmode
if
set_requires_objmode
else
contextlib
.
nullcontext
()
else
contextlib
.
nullcontext
()
):
):
compare_numba_and_py
(
out_fg
,
[])
fn
,
_
=
compare_numba_and_py
(([
x_pt
,
y_pt
],
[
out_pt
]),
[
x
,
y
])
if
inplace
:
# Test updates inplace
x_orig
=
x
.
copy
()
fn
(
x
,
y
+
1
)
assert
not
np
.
all
(
x
==
x_orig
)
out_pt
=
inc_subtensor
(
x
[
indices
],
y
,
ignore_duplicates
=
not
duplicate_indices
)
out_pt
=
inc_subtensor
(
x_pt
[
indices
],
y_pt
,
ignore_duplicates
=
not
duplicate_indices
,
inplace
=
inplace
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
with
(
with
(
pytest
.
warns
(
pytest
.
warns
(
UserWarning
,
UserWarning
,
...
@@ -378,21 +452,9 @@ def test_AdvancedIncSubtensor(
...
@@ -378,21 +452,9 @@ def test_AdvancedIncSubtensor(
if
inc_requires_objmode
if
inc_requires_objmode
else
contextlib
.
nullcontext
()
else
contextlib
.
nullcontext
()
):
):
compare_numba_and_py
(
out_fg
,
[])
fn
,
_
=
compare_numba_and_py
(([
x_pt
,
y_pt
],
[
out_pt
]),
[
x
,
y
])
if
inplace
:
x_pt
=
x
.
type
()
# Test updates inplace
out_pt
=
set_subtensor
(
x_pt
[
indices
],
y
)
x_orig
=
x
.
copy
()
# Inplace isn't really implemented for `AdvancedIncSubtensor`, so we just
fn
(
x
,
y
)
# hack it on here
assert
not
np
.
all
(
x
==
x_orig
)
out_pt
.
owner
.
op
.
inplace
=
True
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
with
(
pytest
.
warns
(
UserWarning
,
match
=
"Numba will use object mode to run AdvancedSetSubtensor's perform method"
,
)
if
set_requires_objmode
else
contextlib
.
nullcontext
()
):
compare_numba_and_py
(
out_fg
,
[
x
.
data
])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论