Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3bf15cac
提交
3bf15cac
authored
5月 21, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
6月 21, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement index for XTensorVariables
上级
42936163
全部展开
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
439 行增加
和
3 行删除
+439
-3
__init__.py
pytensor/xtensor/__init__.py
+0
-1
indexing.py
pytensor/xtensor/indexing.py
+186
-0
__init__.py
pytensor/xtensor/rewriting/__init__.py
+1
-0
indexing.py
pytensor/xtensor/rewriting/indexing.py
+150
-0
type.py
pytensor/xtensor/type.py
+102
-2
test_indexing.py
tests/xtensor/test_indexing.py
+0
-0
没有找到文件。
pytensor/xtensor/__init__.py
浏览文件 @
3bf15cac
...
...
@@ -4,7 +4,6 @@ import pytensor.xtensor.rewriting
from
pytensor.xtensor
import
linalg
from
pytensor.xtensor.shape
import
concat
from
pytensor.xtensor.type
import
(
XTensorType
,
as_xtensor
,
xtensor
,
xtensor_constant
,
...
...
pytensor/xtensor/indexing.py
0 → 100644
浏览文件 @
3bf15cac
# HERE LIE DRAGONS
# Useful links to make sense of all the numpy/xarray complexity
# https://numpy.org/devdocs//user/basics.indexing.html
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
from
pytensor.scalar.basic
import
discrete_dtypes
from
pytensor.tensor.basic
import
as_tensor
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
,
make_slice
from
pytensor.xtensor.basic
import
XOp
,
xtensor_from_tensor
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
,
xtensor
def as_idx_variable(idx, indexed_dim: str):
    """Normalize one user-supplied index into a symbolic index variable.

    Returns either a symbolic slice (``SliceType`` variable) or an
    ``XTensorVariable`` of integer indices. ``indexed_dim`` is the name of
    the dimension this index applies to; unlabeled 1d integer indices are
    implicitly labeled with it, and boolean indices are validated against it
    and converted to integer positions.

    Raises
    ------
    TypeError
        If ``idx`` is None (np.newaxis) or a numeric index that is neither
        integer nor boolean.
    IndexError
        For unlabeled multi-dimensional indices, for ("dim", XTensorVariable)
        pairs, or for boolean indexers labeled with the wrong dimension.
    NotImplementedError
        For boolean indexers that are not 1d.
    """
    if idx is None or (isinstance(idx, Variable) and isinstance(idx.type, NoneTypeT)):
        raise TypeError(
            "XTensors do not support indexing with None (np.newaxis), use expand_dims instead"
        )
    if isinstance(idx, slice):
        # Wrap a raw Python slice into a symbolic MakeSlice
        idx = make_slice(idx)
    elif isinstance(idx, Variable) and isinstance(idx.type, SliceType):
        # Already a symbolic slice: nothing to normalize
        pass
    elif (
        isinstance(idx, tuple)
        and len(idx) == 2
        and (
            isinstance(idx[0], str)
            or (
                isinstance(idx[0], tuple | list)
                and all(isinstance(d, str) for d in idx[0])
            )
        )
    ):
        # Special case for ("x", array) that xarray supports
        dim, idx = idx
        if isinstance(idx, Variable) and isinstance(idx.type, XTensorType):
            raise IndexError(
                f"Giving a dimension name to an XTensorVariable indexer is not supported: {(dim, idx)}. "
                "Use .rename() instead."
            )
        if isinstance(dim, str):
            dims = (dim,)
        else:
            dims = tuple(dim)
        idx = as_xtensor(as_tensor(idx), dims=dims)
    else:
        # Must be integer / boolean indices, we already counted for None and slices
        try:
            idx = as_xtensor(idx)
        except TypeError:
            idx = as_tensor(idx)
            if idx.type.ndim > 1:
                # Same error that xarray raises
                raise IndexError(
                    "Unlabeled multi-dimensional array cannot be used for indexing"
                )
            # This is implicitly an XTensorVariable with dim matching the indexed one
            # (a scalar index gets an empty dims tuple)
            idx = xtensor_from_tensor(idx, dims=(indexed_dim,)[: idx.type.ndim])

        if idx.type.dtype == "bool":
            if idx.type.ndim != 1:
                # xarray allows `x[True]`, but I think it is a bug: https://github.com/pydata/xarray/issues/10379
                # Otherwise, it is always restricted to 1d boolean indexing arrays
                raise NotImplementedError(
                    "Only 1d boolean indexing arrays are supported"
                )
            if idx.type.dims != (indexed_dim,):
                raise IndexError(
                    "Boolean indexer should be unlabeled or on the same dimension to the indexed array. "
                    f"Indexer is on {idx.type.dims} but the target dimension is {indexed_dim}."
                )

            # Convert to nonzero indices
            idx = as_xtensor(idx.values.nonzero()[0], dims=idx.type.dims)

        elif idx.type.dtype not in discrete_dtypes:
            raise TypeError("Numerical indices must be integers or boolean")

    return idx
def get_static_slice_length(slc: Variable, dim_length: None | int) -> int | None:
    """Return the static length of applying slice ``slc`` to a dimension.

    Gives ``None`` whenever the length cannot be determined at graph-build
    time: unknown ``dim_length``, a root slice variable, or a MakeSlice
    whose start/stop/step are not all constants.
    """
    if dim_length is None:
        return None

    if isinstance(slc, Constant):
        data = slc.data
        parts = (data.start, data.stop, data.step)
    elif slc.owner is None:
        # A root slice variable: no way of knowing what we're getting
        return None
    else:
        # Built by a MakeSliceOp: usable only if all three inputs are constants
        parts = []
        for component in slc.owner.inputs:
            if not isinstance(component, Constant):
                return None
            parts.append(component.data)

    start, stop, step = parts
    return len(range(*slice(start, stop, step).indices(dim_length)))
class Index(XOp):
    """Symbolic op implementing xarray-style indexing of XTensorVariables.

    Node inputs are the indexed tensor followed by one index per leading
    dimension; each index is either a symbolic slice or an XTensorVariable
    (see `as_idx_variable`). The output type's dims/shape are inferred here.
    """

    __props__ = ()

    def make_node(self, x, *idxs):
        x = as_xtensor(x)

        if any(idx is Ellipsis for idx in idxs):
            if idxs.count(Ellipsis) > 1:
                raise IndexError("an index can only have a single ellipsis ('...')")
            # Convert intermediate Ellipsis to slice(None)
            ellipsis_loc = idxs.index(Ellipsis)
            n_implied_none_slices = x.type.ndim - (len(idxs) - 1)
            idxs = (
                *idxs[:ellipsis_loc],
                *((slice(None),) * n_implied_none_slices),
                *idxs[ellipsis_loc + 1 :],
            )

        x_ndim = x.type.ndim
        x_dims = x.type.dims
        x_shape = x.type.shape
        # Output dims/static shape are accumulated while scanning the indices
        out_dims = []
        out_shape = []

        def combine_dim_info(idx_dim, idx_dim_shape):
            # Merge static length information for `idx_dim` into out_dims/out_shape
            if idx_dim not in out_dims:
                # First information about the dimension length
                out_dims.append(idx_dim)
                out_shape.append(idx_dim_shape)
            else:
                # Dim already introduced in output by a previous index
                # Update static shape or raise if incompatible
                out_dim_pos = out_dims.index(idx_dim)
                out_dim_shape = out_shape[out_dim_pos]
                if out_dim_shape is None:
                    # We don't know the size of the dimension yet
                    out_shape[out_dim_pos] = idx_dim_shape
                elif idx_dim_shape is not None and idx_dim_shape != out_dim_shape:
                    raise IndexError(
                        f"Dimension of indexers mismatch for dim {idx_dim}"
                    )

        if len(idxs) > x_ndim:
            raise IndexError("Too many indices")

        # Normalize every explicit index; unindexed trailing dims are added back below
        idxs = [
            as_idx_variable(idx, dim) for idx, dim in zip(idxs, x_dims, strict=False)
        ]
        for i, idx in enumerate(idxs):
            if isinstance(idx.type, SliceType):
                idx_dim = x_dims[i]
                idx_dim_shape = get_static_slice_length(idx, x_shape[i])
                combine_dim_info(idx_dim, idx_dim_shape)
            else:
                if idx.type.ndim == 0:
                    # Scalar index, dimension is dropped
                    continue

                assert isinstance(idx.type, XTensorType)

                idx_dims = idx.type.dims
                for idx_dim in idx_dims:
                    idx_dim_shape = idx.type.shape[idx_dims.index(idx_dim)]
                    combine_dim_info(idx_dim, idx_dim_shape)

        # NOTE(review): this reuses `i` with its last value from the loop above;
        # presumably `idxs` is never empty when make_node is called (an empty
        # `idxs` would raise NameError here) — confirm against callers.
        for dim_i, shape_i in zip(x_dims[i + 1 :], x_shape[i + 1 :]):
            # Add back any unindexed dimensions
            if dim_i not in out_dims:
                # If the dimension was not indexed, we keep it as is
                combine_dim_info(dim_i, shape_i)

        output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
        return Apply(self, [x, *idxs], [output])
# Module-level convenience instance: call as `index(x, *idxs)` to build an Index node.
index = Index()
pytensor/xtensor/rewriting/__init__.py
浏览文件 @
3bf15cac
import
pytensor.xtensor.rewriting.basic
import
pytensor.xtensor.rewriting.indexing
import
pytensor.xtensor.rewriting.reduction
import
pytensor.xtensor.rewriting.shape
import
pytensor.xtensor.rewriting.vectorization
pytensor/xtensor/rewriting/indexing.py
0 → 100644
浏览文件 @
3bf15cac
from
itertools
import
zip_longest
from
pytensor
import
as_symbolic
from
pytensor.graph
import
Constant
,
node_rewriter
from
pytensor.tensor
import
TensorType
,
arange
,
specify_shape
from
pytensor.tensor.subtensor
import
_non_consecutive_adv_indexing
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.indexing
import
Index
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
from
pytensor.xtensor.type
import
XTensorType
def to_basic_idx(idx):
    """Convert a symbolic index into a basic (numpy-style) index.

    Returns a plain Python ``slice`` for constant or MakeSlice-built slices,
    the slice variable itself for root slice variables, or the underlying
    scalar tensor of a 0d non-boolean XTensorVariable index.

    Raises
    ------
    TypeError
        If ``idx`` cannot be used as a basic index.
    """
    if isinstance(idx.type, SliceType):
        if isinstance(idx, Constant):
            return idx.data
        elif idx.owner:
            # MakeSlice Op
            # We transform NoneConsts to regular None so that basic Subtensor can be used if possible
            return slice(
                *[
                    None if isinstance(i.type, NoneTypeT) else i
                    for i in idx.owner.inputs
                ]
            )
        else:
            return idx
    if (
        isinstance(idx.type, XTensorType)
        and idx.type.ndim == 0
        # BUG FIX: `type.dtype` is a string, so the original comparison
        # `idx.type.dtype != bool` (against the builtin type) was always True,
        # wrongly letting 0d boolean indices through as basic indices.
        and idx.type.dtype != "bool"
    ):
        return idx.values
    raise TypeError("Cannot convert idx to basic idx")
@register_lower_xtensor
@node_rewriter(tracks=[Index])
def lower_index(fgraph, node):
    """Lower XTensorVariable indexing to regular TensorVariable indexing.

    xarray-like indexing has two modes:
    1. Orthogonal indexing: Indices of different output labeled dimensions are combined to produce all combinations of indices.
    2. Vectorized indexing: Indices of the same output labeled dimension are combined point-wise like in regular numpy advanced indexing.
    An Index Op can combine both modes.

    To achieve orthogonal indexing using numpy semantics we must use multidimensional advanced indexing.
    We expand the dims of each index so they are as large as the number of output dimensions, place the indices that
    belong to the same output dimension in the same axis, and those that belong to different output dimensions in different axes.

    For instance to do an outer 2x2 indexing we can select x[arange(x.shape[0])[:, None], arange(x.shape[1])[None, :]],
    This is a generalization of `np.ix_` that allows combining some dimensions, and not others, as well as have
    indices that have more than one dimension at the start.

    In addition, xarray basic index (slices), can be vectorized with other advanced indices (if they act on the same output dimension).
    However, in numpy, basic indices are always orthogonal to advanced indices. To make them behave like vectorized indices
    we have to convert the slices to equivalent advanced indices.
    We do this by creating an `arange` tensor that matches the shape of the dimension being indexed,
    and then indexing it with the original slice. This index is then handled as a regular advanced index.

    Note: The IndexOp has only 2 types of indices: Slices and XTensorVariables. Regular array indices
    are converted to the appropriate XTensorVariable by `Index.make_node`
    """
    x, *idxs = node.inputs
    [out] = node.outputs
    # Strip the labeled wrapper; all actual indexing is done on the plain tensor
    x_tensor = tensor_from_xtensor(x)

    if all(
        (
            isinstance(idx.type, SliceType)
            or (isinstance(idx.type, XTensorType) and idx.type.ndim == 0)
        )
        for idx in idxs
    ):
        # Special case having just basic indexing
        x_tensor_indexed = x_tensor[tuple(to_basic_idx(idx) for idx in idxs)]
    else:
        # General case, we have to align the indices positionally to achieve vectorized or orthogonal indexing
        # May need to convert basic indexing to advanced indexing if it acts on a dimension that is also indexed by an advanced index
        x_dims = x.type.dims
        x_shape = tuple(x.shape)
        out_ndim = out.type.ndim
        out_dims = out.type.dims
        aligned_idxs = []
        basic_idx_axis = []
        # zip_longest adds the implicit slice(None)
        for i, (idx, x_dim) in enumerate(
            zip_longest(idxs, x_dims, fillvalue=as_symbolic(slice(None)))
        ):
            if isinstance(idx.type, SliceType):
                if not any(
                    (
                        isinstance(other_idx.type, XTensorType)
                        and x_dim in other_idx.dims
                    )
                    for j, other_idx in enumerate(idxs)
                    if j != i
                ):
                    # We can use basic indexing directly if no other index acts on this dimension
                    # This is an optimization that avoids creating an unnecessary arange tensor
                    # and facilitates the use of the specialized AdvancedSubtensor1 when possible
                    aligned_idxs.append(idx)
                    basic_idx_axis.append(out_dims.index(x_dim))
                else:
                    # Otherwise we need to convert the basic index into an equivalent advanced indexing
                    # And align it so it interacts correctly with the other advanced indices
                    adv_idx_equivalent = arange(x_shape[i])[to_basic_idx(idx)]
                    ds_order = ["x"] * out_ndim
                    ds_order[out_dims.index(x_dim)] = 0
                    aligned_idxs.append(adv_idx_equivalent.dimshuffle(ds_order))
            else:
                assert isinstance(idx.type, XTensorType)
                if idx.type.ndim == 0:
                    # Scalar index, we can use it directly
                    aligned_idxs.append(idx.values)
                else:
                    # Vector index, we need to align the indexing dimensions with the base_dims
                    ds_order = ["x"] * out_ndim
                    for j, idx_dim in enumerate(idx.dims):
                        ds_order[out_dims.index(idx_dim)] = j
                    aligned_idxs.append(idx.values.dimshuffle(ds_order))

        # Squeeze indexing dimensions that were not used because we kept basic indexing slices
        if basic_idx_axis:
            aligned_idxs = [
                idx.squeeze(axis=basic_idx_axis)
                if (isinstance(idx.type, TensorType) and idx.type.ndim > 0)
                else idx
                for idx in aligned_idxs
            ]

        x_tensor_indexed = x_tensor[tuple(aligned_idxs)]

        if basic_idx_axis and _non_consecutive_adv_indexing(aligned_idxs):
            # Numpy moves advanced indexing dimensions to the front when they are not consecutive
            # We need to transpose them back to the expected output order
            x_tensor_indexed_basic_dims = [out_dims[axis] for axis in basic_idx_axis]
            x_tensor_indexed_dims = [
                dim for dim in out_dims if dim not in x_tensor_indexed_basic_dims
            ] + x_tensor_indexed_basic_dims
            transpose_order = [x_tensor_indexed_dims.index(dim) for dim in out_dims]
            x_tensor_indexed = x_tensor_indexed.transpose(transpose_order)

    # Add lost shape information
    x_tensor_indexed = specify_shape(x_tensor_indexed, out.type.shape)
    # Re-wrap the result with the output's dimension labels
    new_out = xtensor_from_tensor(x_tensor_indexed, dims=out.type.dims)
    return [new_out]
pytensor/xtensor/type.py
浏览文件 @
3bf15cac
import
typing
import
warnings
from
types
import
EllipsisType
from
pytensor.compile
import
(
...
...
@@ -24,7 +25,7 @@ except ModuleNotFoundError:
XARRAY_AVAILABLE
=
False
from
collections.abc
import
Sequence
from
typing
import
Literal
,
TypeVar
from
typing
import
Any
,
Literal
,
TypeVar
import
numpy
as
np
...
...
@@ -421,7 +422,106 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
raise
NotImplementedError
(
"sel not implemented for XTensorVariable"
)
def __getitem__(self, idx):
    """Index the variable: a dict is labeled indexing (delegates to ``isel``),
    anything else is positional indexing along the leading dimensions."""
    # BUG FIX: removed a leftover `raise NotImplementedError("Indexing not yet
    # implemnented")` placeholder that made the implementation below unreachable.
    if isinstance(idx, dict):
        return self.isel(idx)

    if not isinstance(idx, tuple):
        idx = (idx,)

    return px.indexing.index(self, *idx)
def isel(
    self,
    indexers: dict[str, Any] | None = None,
    drop: bool = False,  # Unused by PyTensor
    missing_dims: Literal["raise", "warn", "ignore"] = "raise",
    **indexers_kwargs,
):
    """Index the variable along the named dimensions, like xarray's ``isel``.

    Parameters
    ----------
    indexers
        Mapping from dimension name to an integer/slice/array index.
        Mutually exclusive with ``indexers_kwargs``.
    drop
        Accepted for xarray API compatibility; unused by PyTensor.
    missing_dims
        What to do when an indexer names a dimension the variable does not
        have: "raise" (default), "warn", or "ignore".
    """
    if indexers_kwargs:
        if indexers is not None:
            raise ValueError("Cannot pass both indexers and indexers_kwargs to isel")
        indexers = indexers_kwargs

    if not indexers:
        # No-op
        return self

    if missing_dims not in {"raise", "warn", "ignore"}:
        raise ValueError(
            f"Unrecognized options {missing_dims} for missing_dims argument"
        )

    # Sort indices and pass them to index
    dims = self.type.dims
    indices = [slice(None)] * self.type.ndim
    for key, idx in indexers.items():
        if idx is Ellipsis:
            # Xarray raises a less informative error, suggesting indices must be integer
            # But slices are also fine
            raise TypeError("Ellipsis (...) is an invalid labeled index")
        try:
            indices[dims.index(key)] = idx
        except ValueError:
            # BUG FIX: `dims.index(key)` raises ValueError (not IndexError) when
            # the dimension name is missing, so the original `except IndexError`
            # never fired and the `missing_dims` handling below was dead code.
            if missing_dims == "raise":
                raise ValueError(
                    f"Dimension {key} does not exist. Expected one of {dims}"
                )
            elif missing_dims == "warn":
                warnings.warn(
                    f"Dimension {key} does not exist. Expected one of {dims}",
                    UserWarning,
                )
            # "ignore": silently skip the unknown dimension

    return px.indexing.index(self, *indices)
def _head_tail_or_thin(
    self,
    indexers: dict[str, Any] | int | None,
    indexers_kwargs: dict[str, Any],
    *,
    kind: Literal["head", "tail", "thin"],
):
    """Shared implementation behind ``head``, ``tail`` and ``thin``.

    Normalizes ``indexers`` into a per-dimension mapping, builds the matching
    slice for each dimension according to ``kind``, and delegates to ``isel``.
    """
    if indexers_kwargs:
        if indexers is not None:
            raise ValueError("Cannot pass both indexers and indexers_kwargs to head")
        indexers = indexers_kwargs

    if indexers is None:
        if kind == "thin":
            raise TypeError(
                "thin() indexers must be either dict-like or a single integer"
            )
        # Default to 5 for head and tail
        indexers = dict.fromkeys(self.type.dims, 5)
    elif not isinstance(indexers, dict):
        # A single integer applies to every dimension
        indexers = dict.fromkeys(self.type.dims, indexers)

    if kind == "head":
        indices = {dim: slice(None, count) for dim, count in indexers.items()}
    elif kind == "tail":
        sizes = self.sizes
        # Can't use slice(-value, None), in case value is zero
        indices = {
            dim: slice(sizes[dim] - count, None) for dim, count in indexers.items()
        }
    elif kind == "thin":
        indices = {dim: slice(None, None, count) for dim, count in indexers.items()}

    return self.isel(indices)
def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
    """Return the first entries along each given dimension (like xarray's ``head``);
    with no indexers, takes the first 5 entries of every dimension."""
    return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head")
def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
    """Return the last entries along each given dimension (like xarray's ``tail``);
    with no indexers, takes the last 5 entries of every dimension."""
    return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail")
def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
    """Subsample every ``indexers[dim]``-th element along each given dimension
    (like xarray's ``thin``); unlike head/tail, indexers are required."""
    return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
...
...
tests/xtensor/test_indexing.py
0 → 100644
浏览文件 @
3bf15cac
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论