Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
528b8d4b
提交
528b8d4b
authored
4月 06, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
4月 29, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Specialized C-impl for vector AdvancedIncSubtensor1
Also add checks for runtime broadcast
上级
4311f893
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
240 行增加
和
30 行删除
+240
-30
subtensor.py
pytensor/link/jax/dispatch/subtensor.py
+3
-0
subtensor.py
pytensor/link/numba/dispatch/subtensor.py
+4
-4
subtensor.py
pytensor/link/pytorch/dispatch/subtensor.py
+4
-0
subtensor.py
pytensor/tensor/subtensor.py
+143
-9
test_subtensor.py
tests/tensor/test_subtensor.py
+86
-17
没有找到文件。
pytensor/link/jax/dispatch/subtensor.py
浏览文件 @
528b8d4b
...
...
@@ -67,6 +67,9 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
if
len
(
indices
)
==
1
:
indices
=
indices
[
0
]
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
return
jax_fn
(
x
,
indices
,
y
)
return
incsubtensor
...
...
pytensor/link/numba/dispatch/subtensor.py
浏览文件 @
528b8d4b
...
...
@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace
=
op
.
inplace
set_instead_of_inc
=
op
.
set_instead_of_inc
x
,
vals
,
idxs
=
node
.
inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
broadcast
=
vals
.
type
.
ndim
<
x
.
type
.
ndim
or
vals
.
type
.
broadcastable
[
0
]
broadcast_with_index
=
vals
.
type
.
ndim
<
x
.
type
.
ndim
or
vals
.
type
.
broadcastable
[
0
]
# TODO: Add runtime_broadcast check
if
set_instead_of_inc
:
if
broadcast
:
if
broadcast
_with_index
:
@numba_njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
...
...
@@ -318,7 +318,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x
[
idx
]
=
val
return
x
else
:
if
broadcast
:
if
broadcast
_with_index
:
@numba_njit
(
boundscheck
=
True
)
def
advancedincsubtensor1_inplace
(
x
,
val
,
idxs
):
...
...
pytensor/link/pytorch/dispatch/subtensor.py
浏览文件 @
528b8d4b
...
...
@@ -109,6 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
def
adv_set_subtensor
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
if
not
inplace
:
x
=
x
.
clone
()
x
[
indices
]
=
y
.
type_as
(
x
)
...
...
@@ -120,6 +122,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
def
adv_inc_subtensor_no_duplicates
(
x
,
y
,
*
indices
):
check_negative_steps
(
indices
)
if
isinstance
(
op
,
AdvancedIncSubtensor1
):
op
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
indices
)
if
not
inplace
:
x
=
x
.
clone
()
x
[
indices
]
+=
y
.
type_as
(
x
)
...
...
pytensor/tensor/subtensor.py
浏览文件 @
528b8d4b
...
...
@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp):
check_input
=
False
params_type
=
ParamsType
(
inplace
=
ps
.
bool
,
set_instead_of_inc
=
ps
.
bool
)
_runtime_broadcast_error_msg
=
(
"Runtime broadcasting not allowed. "
"AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
"If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
)
def
__init__
(
self
,
inplace
=
False
,
set_instead_of_inc
=
False
):
self
.
inplace
=
bool
(
inplace
)
self
.
set_instead_of_inc
=
bool
(
set_instead_of_inc
)
...
...
@@ -2333,6 +2339,9 @@ class AdvancedIncSubtensor1(COp):
NPY_ARRAY_ENSURECOPY, NULL)"""
def
c_support_code
(
self
,
**
kwargs
):
if
numpy_version
<
"1.8.0"
or
using_numpy_2
:
return
None
types
=
[
"npy_"
+
t
for
t
in
[
...
...
@@ -2523,15 +2532,117 @@ class AdvancedIncSubtensor1(COp):
return
code
def
c_code
(
self
,
node
,
name
,
input_names
,
output_names
,
sub
):
if
numpy_version
<
"1.8.0"
or
using_numpy_2
:
raise
NotImplementedError
x
,
y
,
idx
=
input_names
out
=
output_names
[
0
]
[
out
]
=
output_names
copy_of_x
=
self
.
copy_of_x
(
x
)
params
=
sub
[
"params"
]
fail
=
sub
[
"fail"
]
x_
,
y_
,
idx_
=
node
.
inputs
y_cdtype
=
y_
.
type
.
dtype_specs
()[
1
]
idx_cdtype
=
idx_
.
type
.
dtype_specs
()[
1
]
out_cdtype
=
node
.
outputs
[
0
]
.
type
.
dtype_specs
()[
1
]
y_bcast
=
y_
.
type
.
broadcastable
!=
idx_
.
type
.
broadcastable
if
(
x_
.
type
.
ndim
==
1
and
y_
.
type
.
ndim
==
1
and
not
y_bcast
and
x_
.
type
.
dtype
not
in
complex_dtypes
and
y_
.
type
.
dtype
not
in
complex_dtypes
):
# Simple implementation for vector x, y cases
idx_may_be_neg
=
not
(
isinstance
(
idx_
,
Constant
)
and
idx_
.
data
.
min
()
>=
0
)
idx_may_be_invalid
=
AdvancedSubtensor1
.
_idx_may_be_invalid
(
x_
,
idx_
)
shape0
=
x_
.
type
.
shape
[
0
]
# This is used to make sure that when we trust the indices to be valid
# we are not fooled by a wrong static shape
# We mention x to the user in error messages but we work (and make checks) on out,
# which should be x or a copy of it
unexpected_shape0
=
(
f
"PyArray_SHAPE({out})[0] != {shape0}"
if
shape0
is
not
None
else
"0"
)
op
=
"="
if
self
.
set_instead_of_inc
else
"+="
code
=
f
"""
if ({params}->inplace)
{{
if ({x} != {out})
{{
Py_XDECREF({out});
Py_INCREF({x});
{out} = {x};
}}
}}
else
{{
Py_XDECREF({out});
{out} = {copy_of_x};
if (!{out}) {{
// Exception already set
{fail}
}}
}}
if (PyArray_NDIM({out}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) ndim should be 1, got
%
d", PyArray_NDIM({out}));
{fail}
}}
if ({unexpected_shape0}) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) shape should be {shape0}, got
%
d", PyArray_SHAPE({out})[0]);
{fail}
}}
if (PyArray_NDIM({idx}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim should be 1, got
%
d", PyArray_NDIM({idx}));
{fail}
}}
if (PyArray_NDIM({y}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: second input (y) ndim should be 1, got
%
d", PyArray_NDIM({y}));
{fail}
}}
if (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0]) {{
if ((PyArray_NDIM({y}) == 1) && (PyArray_SHAPE({y})[0] == 1)){{
PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
}} else {{
PyErr_Format(PyExc_ValueError,
"AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match:
%
d,
%
d",
PyArray_SHAPE({y})[0], PyArray_SHAPE({idx})[0]);
}}
{fail}
}}
{{
npy_intp out_shape0 = PyArray_SHAPE({out})[0];
{out_cdtype}* out_data = ({out_cdtype}*)PyArray_DATA({out});
{y_cdtype}* y_data = ({y_cdtype}*)PyArray_DATA({y});
{idx_cdtype}* idx_data = ({idx_cdtype}*)PyArray_DATA({idx});
npy_intp n = PyArray_SHAPE({idx})[0];
npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
for(int i = 0; i < n; i++){{
{idx_cdtype} idx = idx_data[i * idx_jump];
if ({int(idx_may_be_neg)}){{
if (idx < 0) {{
idx += out_shape0;
}}
}}
if ({int(idx_may_be_invalid)}){{
if ((idx < 0) || (idx >= out_shape0)) {{
PyErr_Format(PyExc_IndexError,"index
%
d out of bounds for array with shape
%
d", idx_data[i * idx_jump], out_shape0);
{fail}
}}
}}
out_data[idx * out_jump] {op} y_data[i * y_jump];
}}
}}
"""
return
code
if
numpy_version
<
"1.8.0"
or
using_numpy_2
:
raise
NotImplementedError
return
f
"""
PyObject* rval = NULL;
if ({params}->inplace)
...
...
@@ -2559,14 +2670,37 @@ class AdvancedIncSubtensor1(COp):
"""
def
c_code_cache_version
(
self
):
return
(
8
,)
return
(
9
,)
def
_check_runtime_broadcasting
(
self
,
node
:
Apply
,
x
:
np
.
ndarray
,
y
:
np
.
ndarray
,
idx
:
np
.
ndarray
)
->
None
:
if
y
.
ndim
>
0
:
y_pt_bcast
=
node
.
inputs
[
1
]
.
broadcastable
# type: ignore
if
not
y_pt_bcast
[
0
]
and
y
.
shape
[
0
]
==
1
and
y
.
shape
[
0
]
!=
idx
.
shape
[
0
]:
# Attempting to broadcast with index
raise
ValueError
(
self
.
_runtime_broadcast_error_msg
)
if
any
(
not
y_bcast
and
y_dim
==
1
and
y_dim
!=
x_dim
for
y_bcast
,
y_dim
,
x_dim
in
zip
(
reversed
(
y_pt_bcast
),
reversed
(
y
.
shape
),
reversed
(
x
.
shape
),
strict
=
False
,
)
):
# Attempting to broadcast with buffer
raise
ValueError
(
self
.
_runtime_broadcast_error_msg
)
def
perform
(
self
,
node
,
inputs
,
output_storage
):
x
,
y
,
idx
=
inputs
def
perform
(
self
,
node
,
inp
,
out_
):
x
,
y
,
idx
=
inp
(
out
,)
=
out_
if
not
self
.
inplace
:
x
=
x
.
copy
()
self
.
_check_runtime_broadcasting
(
node
,
x
,
y
,
idx
)
if
self
.
set_instead_of_inc
:
x
[
idx
]
=
y
else
:
...
...
@@ -2574,7 +2708,7 @@ class AdvancedIncSubtensor1(COp):
# many times: it does it only once.
np
.
add
.
at
(
x
,
idx
,
y
)
out
[
0
]
=
x
out
put_storage
[
0
]
[
0
]
=
x
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
x
,
y
,
ilist
=
ishapes
...
...
tests/tensor/test_subtensor.py
浏览文件 @
528b8d4b
...
...
@@ -5,6 +5,7 @@ from io import StringIO
import
numpy
as
np
import
pytest
from
numpy.testing
import
assert_array_equal
from
packaging
import
version
import
pytensor
import
pytensor.scalar
as
scal
...
...
@@ -26,7 +27,7 @@ from pytensor.tensor.blockwise import Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
exp
,
isinf
,
lt
,
switch
from
pytensor.tensor.math
import
sum
as
pt_sum
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.shape
import
specify_
broadcastable
,
specify_
shape
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor1
,
...
...
@@ -1101,9 +1102,9 @@ class TestSubtensor(utt.OptimizationTestMixin):
n
=
self
.
shared
(
data
)
for
idx
in
idxs
:
# Should stay on the cpu.
idx_
=
shared
(
np
.
asarray
(
idx
))
t
=
n
[
idx_
]
idx_np
=
np
.
asarray
(
idx
)
idx_
pt
=
shared
(
idx_np
,
shape
=
(
1
if
idx_np
.
shape
[
0
]
==
1
else
None
,
))
t
=
n
[
idx_
pt
]
gn
=
pytensor
.
grad
(
pt_sum
(
exp
(
t
)),
n
)
f
=
self
.
function
([],
[
gn
,
gn
.
shape
],
op
=
AdvancedIncSubtensor1
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
...
...
@@ -1126,13 +1127,13 @@ class TestSubtensor(utt.OptimizationTestMixin):
assert
np
.
allclose
(
gshape
,
data
.
shape
)
def
fct
(
t
):
return
pt_sum
(
t
[
idx_
])
return
pt_sum
(
t
[
idx_
pt
])
utt
.
verify_grad
(
fct
,
[
data
],
mode
=
self
.
mode
)
# Test the grad of the grad (e.i. AdvancedIncSubtensor1.grad)
def
fct2
(
t
):
return
pytensor
.
grad
(
pt_sum
(
t
[
idx_
]),
t
)
return
pytensor
.
grad
(
pt_sum
(
t
[
idx_
pt
]),
t
)
utt
.
verify_grad
(
fct2
,
[
data
],
mode
=
self
.
mode
)
...
...
@@ -1143,7 +1144,9 @@ class TestSubtensor(utt.OptimizationTestMixin):
ops
=
subtensor_ops
if
idx
is
idxs
[
0
]:
# TODO FIXME: This is a very poorly specified test.
f
=
self
.
function
([],
[
gn
.
shape
,
n
[
idx_
]
.
shape
],
op
=
ops
,
N
=
0
,
N_fast
=
0
)
f
=
self
.
function
(
[],
[
gn
.
shape
,
n
[
idx_pt
]
.
shape
],
op
=
ops
,
N
=
0
,
N_fast
=
0
)
f
()
def
test_wrong_exception_regression
(
self
):
...
...
@@ -1231,10 +1234,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
data_num_init
=
np
.
arange
(
data_size
,
dtype
=
self
.
dtype
)
data_num_init
=
data_num_init
.
reshape
(
data_shape
)
inc_shapes
=
[
data_shape
[
i
:]
for
i
in
range
(
0
,
len
(
data_shape
)
+
1
)]
# Test broadcasting of y.
inc_shapes
+=
[(
1
,)
+
inc_shapes
[
-
1
][
1
:]]
for
inc_shape
in
inc_shapes
:
inc_n_dims
=
len
(
inc_shape
)
# We copy the numeric value to be 100% sure there is no
# risk of accidentally sharing it.
data_num
=
data_num_init
.
copy
()
...
...
@@ -1263,10 +1263,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
replace
=
(
not
set_instead_of_inc
),
)
idx_num
=
idx_num
.
astype
(
"int64"
)
# Symbolic variable with increment value.
inc_var
=
TensorType
(
shape
=
(
None
,)
*
inc_n_dims
,
dtype
=
self
.
dtype
)()
# Trick for the case where `inc_shape` is the same as
# `data_shape`: what we actually want is the first
# shape element to be equal to the number of rows to
...
...
@@ -1275,6 +1272,15 @@ class TestSubtensor(utt.OptimizationTestMixin):
len
(
inc_shapes
)
==
0
or
inc_shape
[
0
]
!=
1
):
inc_shape
=
(
n_to_inc
,)
+
inc_shape
[
1
:]
# Symbolic variable with increment value.
inc_var_static_shape
=
tuple
(
1
if
dim_length
==
1
else
None
for
dim_length
in
inc_shape
)
inc_var
=
TensorType
(
shape
=
inc_var_static_shape
,
dtype
=
self
.
dtype
)()
# The param dtype is needed when inc_shape is empty.
# By default, it would return a float and rng.uniform
# with NumPy 1.10 will raise a Deprecation warning.
...
...
@@ -1341,6 +1347,31 @@ class TestSubtensor(utt.OptimizationTestMixin):
# you enable the debug code above.
assert
np
.
allclose
(
f_out
,
output_num
),
(
params
,
f_out
,
output_num
)
@pytest.mark.skipif
(
version
.
parse
(
np
.
__version__
)
<
version
.
parse
(
"2.0"
),
reason
=
"Legacy C-implementation did not check for runtime broadcast"
,
)
@pytest.mark.parametrize
(
"func"
,
(
advanced_inc_subtensor1
,
advanced_set_subtensor1
))
def
test_advanced1_inc_runtime_broadcast
(
self
,
func
):
y
=
matrix
(
"y"
,
dtype
=
"float64"
,
shape
=
(
None
,
None
))
x
=
ptb
.
zeros
((
10
,
5
))
idxs
=
np
.
repeat
(
np
.
arange
(
10
),
2
)
out
=
func
(
x
,
y
,
idxs
)
f
=
function
([
y
],
out
)
f
(
np
.
ones
((
20
,
5
)))
# Fine
with
pytest
.
raises
(
ValueError
,
match
=
"Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked"
,
):
f
(
np
.
ones
((
1
,
5
)))
with
pytest
.
raises
(
ValueError
,
match
=
"Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked"
,
):
f
(
np
.
ones
((
20
,
1
)))
def
test_adv_constant_arg
(
self
):
# Test case provided (and bug detected, gh-607) by John Salvatier
m
=
matrix
(
"m"
)
...
...
@@ -2398,7 +2429,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val
=
[
2
,
3
]
self
.
_compile_and_check
(
[
admat
,
bdmat
],
[
advanced_set_subtensor1
(
admat
,
bdmat
,
aivec_val
)],
[
advanced_set_subtensor1
(
admat
,
specify_broadcastable
(
bdmat
,
0
),
aivec_val
)
],
[
admat_val
,
[[
1
,
2
,
3
,
4
]]],
AdvancedIncSubtensor1
,
)
...
...
@@ -2425,7 +2460,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val
=
[
2
,
3
]
self
.
_compile_and_check
(
[
adtens4
,
bdtens4
],
[
advanced_set_subtensor1
(
adtens4
,
bdtens4
,
aivec_val
)],
[
advanced_set_subtensor1
(
adtens4
,
specify_broadcastable
(
bdtens4
,
0
,
1
,
2
),
aivec_val
)
],
[
adtens4_val
,
[[[[
1
,
2
,
3
,
4
,
5
]]]]],
AdvancedIncSubtensor1
,
warn
=
False
,
...
...
@@ -2476,7 +2515,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val
=
[
2
,
3
]
self
.
_compile_and_check
(
[
adtens4
,
bdtens4
],
[
advanced_set_subtensor1
(
adtens4
,
bdtens4
,
aivec_val
)],
[
advanced_set_subtensor1
(
adtens4
,
specify_broadcastable
(
bdtens4
,
1
,
2
),
aivec_val
)
],
[
adtens4_val
,
[[[[
1
,
2
,
3
,
4
,
5
]]],
[[[
6
,
7
,
8
,
9
,
10
]]]]],
AdvancedIncSubtensor1
,
warn
=
False
,
...
...
@@ -3028,3 +3071,29 @@ class TestBenchmarks:
)
fn
.
vm
.
allow_gc
=
gc
benchmark
(
fn
,
x_values
,
idxs_values
)
@pytest.mark.parametrize
(
"static_shape"
,
(
False
,
True
),
ids
=
lambda
x
:
f
"static_shape={x}"
)
@pytest.mark.parametrize
(
"gc"
,
(
False
,
True
),
ids
=
lambda
x
:
f
"gc={x}"
)
@pytest.mark.parametrize
(
"func"
,
(
inc_subtensor
,
set_subtensor
))
def
test_advanced_incsubtensor1
(
self
,
func
,
static_shape
,
gc
,
benchmark
):
x
=
vector
(
"x"
,
shape
=
(
85
if
static_shape
else
None
,))
x_values
=
np
.
zeros
((
85
,))
buffer
=
ptb
.
zeros_like
(
x
)
y_values
=
np
.
random
.
normal
(
size
=
(
85
*
11
,))
idxs_values
=
np
.
arange
(
85
)
.
repeat
(
11
)
# With static shape and constant indices we know all idxs are valid
# Reuse same buffer of zeros, to check we rather allocate twice than copy inside IncSubtensor
out1
=
func
(
buffer
[
idxs_values
],
y_values
)
out2
=
func
(
buffer
[
idxs_values
[::
-
1
]],
y_values
)
fn
=
pytensor
.
function
(
[
x
],
[
pytensor
.
Out
(
out1
,
borrow
=
True
),
pytensor
.
Out
(
out2
,
borrow
=
True
)],
on_unused_input
=
"ignore"
,
trust_input
=
True
,
)
fn
.
vm
.
allow_gc
=
gc
benchmark
(
fn
,
x_values
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论