testgroup / pytensor / Commits

Commit 0fd160b6
Authored Jul 31, 2025 by Ricardo Vieira
Committed by Ricardo Vieira, Aug 01, 2025
Implement static_shape inference for AdvancedSubtensor
Parent: f7cc0f07

Showing 4 changed files with 222 additions and 67 deletions
pytensor/tensor/subtensor.py      +125 −65
pytensor/tensor/type_other.py     +4 −1
pytensor/tensor/variable.py       +3 −1
tests/tensor/test_subtensor.py    +90 −0

pytensor/tensor/subtensor.py
@@ -2,7 +2,7 @@ import logging
 import sys
 import warnings
 from collections.abc import Callable, Iterable, Sequence
-from itertools import chain, groupby
+from itertools import chain, groupby, zip_longest
 from typing import cast, overload

 import numpy as np
@@ -39,7 +39,7 @@ from pytensor.tensor.basic import (
 from pytensor.tensor.blockwise import vectorize_node_fallback
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
-from pytensor.tensor.math import clip
+from pytensor.tensor.math import add, clip
 from pytensor.tensor.shape import Reshape, Shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
@@ -63,6 +63,7 @@ from pytensor.tensor.type import (
 from pytensor.tensor.type_other import (
     MakeSlice,
     NoneConst,
+    NoneSliceConst,
     NoneTypeT,
     SliceConstant,
     SliceType,
@@ -844,6 +845,24 @@ def as_nontensor_scalar(a: Variable) -> ps.ScalarVariable:
     return ps.as_scalar(a)


+def slice_static_length(slc, dim_length):
+    if dim_length is None:
+        # TODO: Some cases must be zero by definition, we could handle those
+        return None
+
+    entries = [None, None, None]
+    for i, entry in enumerate((slc.start, slc.stop, slc.step)):
+        if entry is None:
+            continue
+        try:
+            entries[i] = get_scalar_constant_value(entry)
+        except NotScalarConstantError:
+            return None
+    return len(range(*slice(*entries).indices(dim_length)))
+
+
 class Subtensor(COp):
     """Basic NumPy indexing operator."""
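The new helper leans on Python's own slice arithmetic: `slice.indices(dim_length)` resolves `None` and negative bounds against the known length, and the length of the resulting `range` is the static output length. A minimal standalone sketch of that arithmetic, with plain Python ints standing in for the extracted scalar constants:

```python
# Minimal sketch of the slice-length arithmetic used by slice_static_length.
def static_slice_length(start, stop, step, dim_length):
    # slice.indices clips and normalizes None / negative bounds to the dimension
    return len(range(*slice(start, stop, step).indices(dim_length)))

assert static_slice_length(None, 5, None, 7) == 5    # x[:5] on a length-7 dim
assert static_slice_length(None, None, -1, 6) == 6   # x[::-1] keeps the length
assert static_slice_length(1, None, 2, 7) == 3       # x[1::2] -> elements 1, 3, 5
```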
@@ -886,50 +905,15 @@ class Subtensor(COp):
             )

         padded = [
-            *get_idx_list((None, *inputs), self.idx_list),
+            *indices_from_subtensor(inputs, self.idx_list),
             *[slice(None, None, None)] * (x.type.ndim - len(idx_list)),
         ]

-        out_shape = []
-
-        def extract_const(value):
-            if value is None:
-                return value, True
-            try:
-                value = get_scalar_constant_value(value)
-                return value, True
-            except NotScalarConstantError:
-                return value, False
-
-        for the_slice, length in zip(padded, x.type.shape, strict=True):
-            if not isinstance(the_slice, slice):
-                continue
-
-            if length is None:
-                out_shape.append(None)
-                continue
-
-            start = the_slice.start
-            stop = the_slice.stop
-            step = the_slice.step
-
-            is_slice_const = True
-
-            start, is_const = extract_const(start)
-            is_slice_const = is_slice_const and is_const
-
-            stop, is_const = extract_const(stop)
-            is_slice_const = is_slice_const and is_const
-
-            step, is_const = extract_const(step)
-            is_slice_const = is_slice_const and is_const
-
-            if not is_slice_const:
-                out_shape.append(None)
-                continue
-
-            slice_length = len(range(*slice(start, stop, step).indices(length)))
-            out_shape.append(slice_length)
+        out_shape = [
+            slice_static_length(slc, length)
+            for slc, length in zip(padded, x.type.shape, strict=True)
+            if isinstance(slc, slice)
+        ]

         return Apply(
             self,
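For context, the rewritten `make_node` pads the index list with full slices for any trailing dimensions and keeps one static length per slice entry, while integer indices drop their dimension. A rough illustration of that comprehension with plain Python objects rather than PyTensor variables:

```python
# Rough illustration of the out_shape comprehension; plain Python objects only.
padded = [slice(1, 5), 2, slice(None, None, -1)]  # roughly x[1:5, 2, ::-1]
x_shape = (10, 7, 6)

out_shape = [
    len(range(*slc.indices(length)))   # what slice_static_length returns for constant slices
    for slc, length in zip(padded, x_shape, strict=True)
    if isinstance(slc, slice)          # integer indices drop their dimension entirely
]
assert out_shape == [4, 6]
```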
@@ -2826,36 +2810,112 @@ class AdvancedSubtensor(Op):

     __props__ = ()

-    def make_node(self, x, *index):
+    def make_node(self, x, *indices):
         x = as_tensor_variable(x)
-        index = tuple(map(as_index_variable, index))
+        indices = tuple(map(as_index_variable, indices))

-        # We create a fake symbolic shape tuple and identify the broadcast
-        # dimensions from the shape result of this entire subtensor operation.
-        with config.change_flags(compute_test_value="off"):
-            fake_shape = tuple(
-                tensor(dtype="int64", shape=()) if s != 1 else 1 for s in x.type.shape
-            )
-
-            fake_index = tuple(
-                chain.from_iterable(
-                    pytensor.tensor.basic.nonzero(idx)
-                    if getattr(idx, "ndim", 0) > 0
-                    and getattr(idx, "dtype", None) == "bool"
-                    else (idx,)
-                    for idx in index
-                )
-            )
-
-            out_shape = tuple(
-                i.value if isinstance(i, Constant) else None
-                for i in indexed_result_shape(fake_shape, fake_index)
-            )
+        explicit_indices = []
+        new_axes = []
+        for idx in indices:
+            if isinstance(idx.type, TensorType) and idx.dtype == "bool":
+                if idx.type.ndim == 0:
+                    raise NotImplementedError(
+                        "Indexing with scalar booleans not supported"
+                    )
+
+                # Check static shape aligned
+                axis = len(explicit_indices) - len(new_axes)
+                indexed_shape = x.type.shape[axis : axis + idx.type.ndim]
+                for j, (indexed_length, indexer_length) in enumerate(
+                    zip(indexed_shape, idx.type.shape)
+                ):
+                    if (
+                        indexed_length is not None
+                        and indexer_length is not None
+                        and indexed_length != indexer_length
+                    ):
+                        raise IndexError(
+                            f"boolean index did not match indexed tensor along axis {axis + j};"
+                            f"size of axis is {indexed_length} but size of corresponding boolean axis is {indexer_length}"
+                        )
+
+                # Convert boolean indices to integer with nonzero, to reason about static shape next
+                if isinstance(idx, Constant):
+                    nonzero_indices = [tensor_constant(i) for i in idx.data.nonzero()]
+                else:
+                    # Note: Sometimes we could infer a shape error by reasoning about the largest possible size of nonzero
+                    # and seeing that other integer indices cannot possible match it
+                    nonzero_indices = idx.nonzero()
+                explicit_indices.extend(nonzero_indices)
+            else:
+                if isinstance(idx.type, NoneTypeT):
+                    new_axes.append(len(explicit_indices))
+                explicit_indices.append(idx)
+
+        if (len(explicit_indices) - len(new_axes)) > x.type.ndim:
+            raise IndexError(
+                f"too many indices for array: tensor is {x.type.ndim}-dimensional, but {len(explicit_indices) - len(new_axes)} were indexed"
+            )
+
+        # Perform basic and advanced indexing shape inference separately
+        basic_group_shape = []
+        advanced_indices = []
+        adv_group_axis = None
+        last_adv_group_axis = None
+        expanded_x_shape = tuple(
+            np.insert(np.array(x.type.shape, dtype=object), new_axes, 1)
+        )
+        for i, (idx, dim_length) in enumerate(
+            zip_longest(explicit_indices, expanded_x_shape, fillvalue=NoneSliceConst)
+        ):
+            if isinstance(idx.type, NoneTypeT):
+                basic_group_shape.append(1)  # New-axis
+            elif isinstance(idx.type, SliceType):
+                if isinstance(idx, Constant):
+                    basic_group_shape.append(slice_static_length(idx.data, dim_length))
+                elif idx.owner is not None and isinstance(idx.owner.op, MakeSlice):
+                    basic_group_shape.append(
+                        slice_static_length(slice(*idx.owner.inputs), dim_length)
+                    )
+                else:
+                    # Symbolic root slice (owner is None), or slice operation we don't understand
+                    basic_group_shape.append(None)
+            else:  # TensorType
+                # Keep track of advanced group axis
+                if adv_group_axis is None:
+                    # First time we see an advanced index
+                    adv_group_axis, last_adv_group_axis = i, i
+                elif last_adv_group_axis == (i - 1):
+                    # Another advanced indexing aligned with the first group
+                    last_adv_group_axis = i
+                else:
+                    # Non-consecutive advanced index, all advanced index views get moved to the front
+                    adv_group_axis = 0
+                advanced_indices.append(idx)
+
+        if advanced_indices:
+            try:
+                # Use variadic add to infer static shape of advanced integer indices
+                advanced_group_static_shape = add(*advanced_indices).type.shape
+            except ValueError:
+                # It fails when static shapes are inconsistent
+                static_shapes = [idx.type.shape for idx in advanced_indices]
+                raise IndexError(
+                    f"shape mismatch: indexing tensors could not be broadcast together with shapes {static_shapes}"
+                )
+            # Combine advanced and basic views
+            indexed_shape = [
+                *basic_group_shape[:adv_group_axis],
+                *advanced_group_static_shape,
+                *basic_group_shape[adv_group_axis:],
+            ]
+        else:
+            # This could have been a basic subtensor!
+            indexed_shape = basic_group_shape

         return Apply(
             self,
-            (x, *index),
-            [tensor(dtype=x.type.dtype, shape=out_shape)],
+            [x, *indices],
+            [tensor(dtype=x.type.dtype, shape=tuple(indexed_shape))],
         )

     def R_op(self, inputs, eval_points):
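The `add(*advanced_indices).type.shape` call is a compact way to reuse elemwise broadcasting for static-shape inference over all integer indices at once. A small sketch of the same idea, using index shapes taken from the new tests, `(3, None)` and `(10,)`:

```python
# Sketch: broadcasting the static shapes of integer indices via a variadic
# elemwise add, as the new make_node does for the advanced index group.
from pytensor.tensor import tensor
from pytensor.tensor.math import add

idx_a = tensor("idx_a", shape=(3, None), dtype="int64")
idx_b = tensor("idx_b", shape=(10,), dtype="int64")

# (3, None) broadcast with (10,) gives the advanced group's static shape (3, 10)
assert add(idx_a, idx_b).type.shape == (3, 10)
```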

pytensor/tensor/type_other.py
@@ -114,6 +114,9 @@ def as_symbolic_slice(x, **kwargs):
     return SliceConstant(slicetype, x)


+NoneSliceConst = Constant(slicetype, slice(None), name="slice(None)")
+
+
 class NoneTypeT(Generic):
     """
     Inherit from Generic to have c code working.
@@ -137,4 +140,4 @@ def as_symbolic_None(x, **kwargs):
     return NoneConst


-__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst"]
+__all__ = ["make_slice", "slicetype", "none_type_t", "NoneConst", "NoneSliceConst"]
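`NoneSliceConst` gives the subtensor code a reusable symbolic constant for `slice(None)`; in the new `make_node` it is the `zip_longest` fillvalue, so trailing dimensions without an explicit index are treated as full slices. A toy illustration of that padding, with plain Python values in place of the symbolic constant:

```python
from itertools import zip_longest

explicit_indices = ["idx0", "idx1"]   # only the first two axes are indexed
x_shape = (4, 5, 6)

# Trailing axes get paired with a slice(None) stand-in, mirroring the
# fillvalue=NoneSliceConst padding in AdvancedSubtensor.make_node.
padded = list(zip_longest(explicit_indices, x_shape, fillvalue=slice(None)))
assert padded == [("idx0", 4), ("idx1", 5), (slice(None), 6)]
```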

pytensor/tensor/variable.py
@@ -506,7 +506,9 @@ class _tensor_py_operators:
         # Check if the number of dimensions isn't too large.
         if self.ndim < index_dim_count:
-            raise IndexError("too many indices for array")
+            raise IndexError(
+                f"too many indices for tensor: tensor is {self.ndim}-dimensional, but {index_dim_count} were indexed"
+            )

         # Convert an Ellipsis if provided into an appropriate number of
         # slice(None).
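The reworded message now mirrors NumPy's phrasing and reports the actual dimension counts. A hypothetical trigger, sketched rather than copied from the commit's tests:

```python
# Sketch of how the reworded error surfaces when over-indexing a matrix.
import pytensor.tensor as pt

x = pt.matrix("x")
try:
    x[0, 0, 0]  # three indices on a 2-dimensional tensor
except IndexError as exc:
    # Expected to read: "too many indices for tensor: tensor is 2-dimensional,
    # but 3 were indexed"
    print(exc)
```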

tests/tensor/test_subtensor.py
 import logging
+import re
 import sys
 from io import StringIO
@@ -1847,6 +1848,95 @@ class TestAdvancedSubtensor:
         self.ix2 = lmatrix()
         self.ixr = lrow()

+    def test_static_shape(self):
+        x = tensor("x", shape=(None, None))
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(10,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+
+        assert x[idx1].type.shape == (10, None)
+        assert x[:, idx1].type.shape == (None, 10)
+        assert x[idx2, :5].type.shape == (3, None, None)
+        assert specify_shape(x, (None, 7))[idx2, :5].type.shape == (3, None, 5)
+        assert specify_shape(x, (None, 3))[idx2, :5].type.shape == (3, None, 3)
+        assert x[idx1, idx2].type.shape == (3, 10)
+        assert x[idx2, idx1].type.shape == (3, 10)
+        assert x[None, idx1, idx2].type.shape == (1, 3, 10)
+        assert x[idx1, None, idx2].type.shape == (3, 10, 1)
+        assert x[idx1, idx2, None].type.shape == (3, 10, 1)
+        assert y[idx1, idx2, ::-1].type.shape == (3, 10, 6)
+        assert y[idx1, ::-1, idx2].type.shape == (3, 10, 5)
+        assert y[::-1, idx1, idx2].type.shape == (4, 3, 10)
+        assert y[::-1, idx1, None, idx2].type.shape == (3, 10, 4, 1)
+
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(10,), (9,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            x[idx1, idx1[1:]]
+
+    def test_static_shape_boolean(self):
+        y = tensor("y", shape=(4, 5, 6))
+        idx1 = tensor("idx1", shape=(4,), dtype=int)
+        idx2 = tensor("idx2", shape=(3, None), dtype=int)
+        bool_idx1 = tensor("bool_idx1", shape=(4,), dtype=bool)
+        bool_idx2 = tensor(
+            "bool_idx2",
+            shape=(None, 5),
+            dtype=bool,
+        )
+
+        assert y[bool_idx1].type.shape == (None, 5, 6)
+        assert y[bool_idx1, :, None:-4:-1].type.shape == (None, 5, 3)
+        assert y[bool_idx1, idx2].type.shape == (3, None, 6)
+        assert y[bool_idx1, idx1, :].type.shape == (4, 6)
+        assert y[bool_idx1, :, idx1].type.shape == (4, 5)
+        assert y[bool_idx1, idx1, idx2].type.shape == (3, 4)
+        assert y[None, bool_idx1, None, idx2, None, idx1].type.shape == (3, 4, 1, 1, 1)
+
+        assert y[bool_idx2, :].type.shape == (None, 6)
+        assert y[bool_idx2, idx1].type.shape == (4,)
+        assert y[bool_idx2, idx2].type.shape == (3, None)
+
+        msg = re.escape(
+            "too many indices for tensor: tensor is 3-dimensional, but 4 were indexed"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx2, bool_idx2]
+
+        # Case that could conceivably be detected as index error at definition time
+        bad_idx = ptb.concatenate([idx1, idx1])
+        assert y[bool_idx1, bad_idx].type.shape == (8, 6)
+
+    def test_static_shape_constant_boolean(self):
+        y = tensor("y", shape=(None, None, None))
+        idx1 = tensor("idx1", shape=(3,), dtype=int)
+        idx2 = tensor("idx2", shape=(4, None), dtype=int)
+        bool_idx1 = constant(np.array([True, False, True, True]), name="bool_idx1")
+        bool_idx2 = constant(
+            np.array([[True, False, True, True], [True, False, False, True]]),
+            name="bool_idx2",
+        )
+
+        assert y[bool_idx1].type.shape == (3, None, None)
+        assert y[bool_idx1, :, idx1].type.shape == (3, None)
+        assert y[bool_idx1, :, idx2].type.shape == (4, 3, None)
+        assert y[bool_idx2].type.shape == (5, None)
+        assert y[bool_idx1, idx2].type.shape == (4, 3, None)
+
+        bad_idx = ptb.concatenate([idx1, idx1])
+        msg = re.escape(
+            "shape mismatch: indexing tensors could not be broadcast together with shapes [(3,), (6,)]"
+        )
+        with pytest.raises(IndexError, match=msg):
+            y[bool_idx1, bad_idx]
+
     @pytest.mark.parametrize(
         "inplace",
         [