Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5024d54e
提交
5024d54e
authored
7月 11, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 14, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add docstrings to more XTensorVariable methods
Also remove broadcast which is not a method in Xarray
上级
fdb40877
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
152 行增加
和
23 行删除
+152
-23
type.py
pytensor/xtensor/type.py
+152
-23
没有找到文件。
pytensor/xtensor/type.py
浏览文件 @
5024d54e
...
@@ -366,6 +366,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -366,6 +366,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#id1
# https://docs.xarray.dev/en/latest/api.html#id1
@property
@property
def
values
(
self
)
->
TensorVariable
:
def
values
(
self
)
->
TensorVariable
:
"""Convert to a TensorVariable with the same data."""
return
typing
.
cast
(
TensorVariable
,
px
.
basic
.
tensor_from_xtensor
(
self
))
return
typing
.
cast
(
TensorVariable
,
px
.
basic
.
tensor_from_xtensor
(
self
))
# Can't provide property data because that's already taken by Constants!
# Can't provide property data because that's already taken by Constants!
...
@@ -373,14 +374,17 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -373,14 +374,17 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
@property
@property
def
coords
(
self
):
def
coords
(
self
):
"""Not implemented."""
raise
NotImplementedError
(
"coords not implemented for XTensorVariable"
)
raise
NotImplementedError
(
"coords not implemented for XTensorVariable"
)
@property
@property
def
dims
(
self
)
->
tuple
[
str
,
...
]:
def
dims
(
self
)
->
tuple
[
str
,
...
]:
"""The names of the dimensions of the variable."""
return
self
.
type
.
dims
return
self
.
type
.
dims
@property
@property
def
sizes
(
self
)
->
dict
[
str
,
TensorVariable
]:
def
sizes
(
self
)
->
dict
[
str
,
TensorVariable
]:
"""The sizes of the dimensions of the variable."""
return
dict
(
zip
(
self
.
dims
,
self
.
shape
))
return
dict
(
zip
(
self
.
dims
,
self
.
shape
))
@property
@property
...
@@ -392,18 +396,22 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -392,18 +396,22 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
@property
@property
def
ndim
(
self
)
->
int
:
def
ndim
(
self
)
->
int
:
"""The number of dimensions of the variable."""
return
self
.
type
.
ndim
return
self
.
type
.
ndim
@property
@property
def
shape
(
self
)
->
tuple
[
TensorVariable
,
...
]:
def
shape
(
self
)
->
tuple
[
TensorVariable
,
...
]:
"""The shape of the variable."""
return
tuple
(
px
.
basic
.
tensor_from_xtensor
(
self
)
.
shape
)
# type: ignore
return
tuple
(
px
.
basic
.
tensor_from_xtensor
(
self
)
.
shape
)
# type: ignore
@property
@property
def
size
(
self
)
->
TensorVariable
:
def
size
(
self
)
->
TensorVariable
:
"""The total number of elements in the variable."""
return
typing
.
cast
(
TensorVariable
,
variadic_mul
(
*
self
.
shape
))
return
typing
.
cast
(
TensorVariable
,
variadic_mul
(
*
self
.
shape
))
@property
@property
def
dtype
(
self
):
def
dtype
(
self
)
->
str
:
"""The data type of the variable."""
return
self
.
type
.
dtype
return
self
.
type
.
dtype
@property
@property
...
@@ -414,6 +422,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -414,6 +422,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# DataArray contents
# DataArray contents
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
def
rename
(
self
,
new_name_or_name_dict
=
None
,
**
names
):
def
rename
(
self
,
new_name_or_name_dict
=
None
,
**
names
):
"""Rename the variable or its dimension(s)."""
if
isinstance
(
new_name_or_name_dict
,
str
):
if
isinstance
(
new_name_or_name_dict
,
str
):
new_name
=
new_name_or_name_dict
new_name
=
new_name_or_name_dict
name_dict
=
None
name_dict
=
None
...
@@ -425,31 +434,41 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -425,31 +434,41 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
new_out
return
new_out
def
copy
(
self
,
name
:
str
|
None
=
None
):
def
copy
(
self
,
name
:
str
|
None
=
None
):
"""Create a copy of the variable.
This is just an identity operation, as XTensorVariables are immutable.
"""
out
=
px
.
math
.
identity
(
self
)
out
=
px
.
math
.
identity
(
self
)
out
.
name
=
name
out
.
name
=
name
return
out
return
out
def
astype
(
self
,
dtype
):
def
astype
(
self
,
dtype
):
"""Convert the variable to a different data type."""
return
px
.
math
.
cast
(
self
,
dtype
)
return
px
.
math
.
cast
(
self
,
dtype
)
def
item
(
self
):
def
item
(
self
):
"""Not implemented."""
raise
NotImplementedError
(
"item not implemented for XTensorVariable"
)
raise
NotImplementedError
(
"item not implemented for XTensorVariable"
)
# Indexing
# Indexing
# https://docs.xarray.dev/en/latest/api.html#id2
# https://docs.xarray.dev/en/latest/api.html#id2
def
__setitem__
(
self
,
idx
,
value
):
def
__setitem__
(
self
,
idx
,
value
):
"""Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
raise
TypeError
(
raise
TypeError
(
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
)
)
@property
@property
def
loc
(
self
):
def
loc
(
self
):
"""Not implemented."""
raise
NotImplementedError
(
"loc not implemented for XTensorVariable"
)
raise
NotImplementedError
(
"loc not implemented for XTensorVariable"
)
def
sel
(
self
,
*
args
,
**
kwargs
):
def
sel
(
self
,
*
args
,
**
kwargs
):
"""Not implemented."""
raise
NotImplementedError
(
"sel not implemented for XTensorVariable"
)
raise
NotImplementedError
(
"sel not implemented for XTensorVariable"
)
def
__getitem__
(
self
,
idx
):
def
__getitem__
(
self
,
idx
):
"""Index the variable positionally."""
if
isinstance
(
idx
,
dict
):
if
isinstance
(
idx
,
dict
):
return
self
.
isel
(
idx
)
return
self
.
isel
(
idx
)
...
@@ -465,6 +484,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -465,6 +484,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
**
indexers_kwargs
,
**
indexers_kwargs
,
):
):
"""Index the variable along the specified dimension(s)."""
if
indexers_kwargs
:
if
indexers_kwargs
:
if
indexers
is
not
None
:
if
indexers
is
not
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -505,6 +525,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -505,6 +525,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
indexing
.
index
(
self
,
*
indices
)
return
px
.
indexing
.
index
(
self
,
*
indices
)
def
set
(
self
,
value
):
def
set
(
self
,
value
):
"""Return a copy of the variable indexed by self with the indexed values set to y.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].set(1)
print(out.eval())
.. testoutput::
[[1 0]
[0 1]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).set(-1)
print(out.eval())
.. testoutput::
[[-1 0]
[ 0 -1]]
"""
if
not
(
if
not
(
self
.
owner
is
not
None
and
isinstance
(
self
.
owner
.
op
,
px
.
indexing
.
Index
)
self
.
owner
is
not
None
and
isinstance
(
self
.
owner
.
op
,
px
.
indexing
.
Index
)
):
):
...
@@ -516,6 +578,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -516,6 +578,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
indexing
.
index_assignment
(
x
,
value
,
*
idxs
)
return
px
.
indexing
.
index_assignment
(
x
,
value
,
*
idxs
)
def
inc
(
self
,
value
):
def
inc
(
self
,
value
):
"""Return a copy of the variable indexed by self with the indexed values incremented by value.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].inc(1)
print(out.eval())
.. testoutput::
[[2 1]
[1 2]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).inc(-1)
print(out.eval())
.. testoutput::
[[0 1]
[1 0]]
"""
if
not
(
if
not
(
self
.
owner
is
not
None
and
isinstance
(
self
.
owner
.
op
,
px
.
indexing
.
Index
)
self
.
owner
is
not
None
and
isinstance
(
self
.
owner
.
op
,
px
.
indexing
.
Index
)
):
):
...
@@ -579,7 +683,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -579,7 +683,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
drop
=
None
,
drop
=
None
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
):
):
"""Remove dimensions of size 1
from an XTensorVariable
.
"""Remove dimensions of size 1.
Parameters
Parameters
----------
----------
...
@@ -606,24 +710,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -606,24 +710,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
**
dim_kwargs
,
**
dim_kwargs
,
):
):
"""Add one or more new dimensions to the
tensor
.
"""Add one or more new dimensions to the
variable
.
Parameters
Parameters
----------
----------
dim : str | Sequence[str] | dict[str, int | Sequence] | None
dim : str | Sequence[str] | dict[str, int | Sequence] | None
If str or sequence of str, new dimensions with size 1.
If str or sequence of str, new dimensions with size 1.
If dict, keys are dimension names and values are either:
If dict, keys are dimension names and values are either:
- int: the new size
- sequence: coordinates (length determines size)
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True
create_index_for_new_dim : bool, default: True
Currently ignored. Reserved for future coordinate support.
Ignored by PyTensor
In xarray, when True (default), creates a coordinate index for the new dimension
with values from 0 to size-1. When False, no coordinate index is created.
axis : int | Sequence[int] | None, default: None
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
By default (None), new dimensions are inserted at the beginning (axis=0).
By default (None), new dimensions are inserted at the beginning (axis=0).
Symbolic axis is not supported yet.
Negative values count from the end.
**dim_kwargs : int | Sequence
**dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None.
Alternative to `dim` dict. Only used if `dim` is None.
...
@@ -643,65 +744,75 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -643,65 +744,75 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# ndarray methods
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
# https://docs.xarray.dev/en/latest/api.html#id7
def
clip
(
self
,
min
,
max
):
def
clip
(
self
,
min
,
max
):
"""Clip the values of the variable to a specified range."""
return
px
.
math
.
clip
(
self
,
min
,
max
)
return
px
.
math
.
clip
(
self
,
min
,
max
)
def
conj
(
self
):
def
conj
(
self
):
"""Return the complex conjugate of the variable."""
return
px
.
math
.
conj
(
self
)
return
px
.
math
.
conj
(
self
)
@property
@property
def
imag
(
self
):
def
imag
(
self
):
"""Return the imaginary part of the variable."""
return
px
.
math
.
imag
(
self
)
return
px
.
math
.
imag
(
self
)
@property
@property
def
real
(
self
):
def
real
(
self
):
"""Return the real part of the variable."""
return
px
.
math
.
real
(
self
)
return
px
.
math
.
real
(
self
)
@property
@property
def
T
(
self
):
def
T
(
self
):
"""Return the full transpose of the
tensor
.
"""Return the full transpose of the
variable
.
This is equivalent to calling transpose() with no arguments.
This is equivalent to calling transpose() with no arguments.
Returns
-------
XTensorVariable
Fully transposed tensor.
"""
"""
return
self
.
transpose
()
return
self
.
transpose
()
# Aggregation
# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
# https://docs.xarray.dev/en/latest/api.html#id6
def
all
(
self
,
dim
=
None
):
def
all
(
self
,
dim
=
None
):
"""Reduce the variable by applying `all` along some dimension(s)."""
return
px
.
reduction
.
all
(
self
,
dim
)
return
px
.
reduction
.
all
(
self
,
dim
)
def
any
(
self
,
dim
=
None
):
def
any
(
self
,
dim
=
None
):
"""Reduce the variable by applying `any` along some dimension(s)."""
return
px
.
reduction
.
any
(
self
,
dim
)
return
px
.
reduction
.
any
(
self
,
dim
)
def
max
(
self
,
dim
=
None
):
def
max
(
self
,
dim
=
None
):
"""Compute the maximum along the given dimension(s)."""
return
px
.
reduction
.
max
(
self
,
dim
)
return
px
.
reduction
.
max
(
self
,
dim
)
def
min
(
self
,
dim
=
None
):
def
min
(
self
,
dim
=
None
):
"""Compute the minimum along the given dimension(s)."""
return
px
.
reduction
.
min
(
self
,
dim
)
return
px
.
reduction
.
min
(
self
,
dim
)
def
mean
(
self
,
dim
=
None
):
def
mean
(
self
,
dim
=
None
):
"""Compute the mean along the given dimension(s)."""
return
px
.
reduction
.
mean
(
self
,
dim
)
return
px
.
reduction
.
mean
(
self
,
dim
)
def
prod
(
self
,
dim
=
None
):
def
prod
(
self
,
dim
=
None
):
"""Compute the product along the given dimension(s)."""
return
px
.
reduction
.
prod
(
self
,
dim
)
return
px
.
reduction
.
prod
(
self
,
dim
)
def
sum
(
self
,
dim
=
None
):
def
sum
(
self
,
dim
=
None
):
"""Compute the sum along the given dimension(s)."""
return
px
.
reduction
.
sum
(
self
,
dim
)
return
px
.
reduction
.
sum
(
self
,
dim
)
def
std
(
self
,
dim
=
None
,
ddof
=
0
):
def
std
(
self
,
dim
=
None
,
ddof
=
0
):
"""Compute the standard deviation along the given dimension(s)."""
return
px
.
reduction
.
std
(
self
,
dim
,
ddof
=
ddof
)
return
px
.
reduction
.
std
(
self
,
dim
,
ddof
=
ddof
)
def
var
(
self
,
dim
=
None
,
ddof
=
0
):
def
var
(
self
,
dim
=
None
,
ddof
=
0
):
"""Compute the variance along the given dimension(s)."""
return
px
.
reduction
.
var
(
self
,
dim
,
ddof
=
ddof
)
return
px
.
reduction
.
var
(
self
,
dim
,
ddof
=
ddof
)
def
cumsum
(
self
,
dim
=
None
):
def
cumsum
(
self
,
dim
=
None
):
"""Compute the cumulative sum along the given dimension(s)."""
return
px
.
reduction
.
cumsum
(
self
,
dim
)
return
px
.
reduction
.
cumsum
(
self
,
dim
)
def
cumprod
(
self
,
dim
=
None
):
def
cumprod
(
self
,
dim
=
None
):
"""Compute the cumulative product along the given dimension(s)."""
return
px
.
reduction
.
cumprod
(
self
,
dim
)
return
px
.
reduction
.
cumprod
(
self
,
dim
)
def
diff
(
self
,
dim
,
n
=
1
):
def
diff
(
self
,
dim
,
n
=
1
):
...
@@ -720,7 +831,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -720,7 +831,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
*
dim
:
str
|
EllipsisType
,
*
dim
:
str
|
EllipsisType
,
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
):
):
"""Transpose
dimensions of the tensor
.
"""Transpose
the dimensions of the variable
.
Parameters
Parameters
----------
----------
...
@@ -729,6 +840,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -729,6 +840,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
Can use ellipsis (...) to represent remaining dimensions.
Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise"
missing_dims : {"raise", "warn", "ignore"}, default="raise"
How to handle dimensions that don't exist in the tensor:
How to handle dimensions that don't exist in the tensor:
- "raise": Raise an error if any dimensions don't exist
- "raise": Raise an error if any dimensions don't exist
- "warn": Warn if any dimensions don't exist
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist
- "ignore": Silently ignore any dimensions that don't exist
...
@@ -747,21 +859,38 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -747,21 +859,38 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
shape
.
transpose
(
self
,
*
dim
,
missing_dims
=
missing_dims
)
return
px
.
shape
.
transpose
(
self
,
*
dim
,
missing_dims
=
missing_dims
)
def
stack
(
self
,
dim
,
**
dims
):
def
stack
(
self
,
dim
,
**
dims
):
"""Stack existing dimensions into a single new dimension."""
return
px
.
shape
.
stack
(
self
,
dim
,
**
dims
)
return
px
.
shape
.
stack
(
self
,
dim
,
**
dims
)
def
unstack
(
self
,
dim
,
**
dims
):
def
unstack
(
self
,
dim
,
**
dims
):
"""Unstack a dimension into multiple dimensions of a given size.
Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
print(unstacked_cumsum.eval())
.. testoutput::
[[ 1 3]
[ 6 10]]
"""
return
px
.
shape
.
unstack
(
self
,
dim
,
**
dims
)
return
px
.
shape
.
unstack
(
self
,
dim
,
**
dims
)
def
dot
(
self
,
other
,
dim
=
None
):
def
dot
(
self
,
other
,
dim
=
None
):
"""
Matrix multiplication with another XTensorVariable, contracting over matching or specified dims
."""
"""
Generalized dot product with another XTensorVariable
."""
return
px
.
math
.
dot
(
self
,
other
,
dim
=
dim
)
return
px
.
math
.
dot
(
self
,
other
,
dim
=
dim
)
def
broadcast
(
self
,
*
others
,
exclude
=
None
):
"""Broadcast this tensor against other XTensorVariables."""
return
px
.
shape
.
broadcast
(
self
,
*
others
,
exclude
=
exclude
)
def
broadcast_like
(
self
,
other
,
exclude
=
None
):
def
broadcast_like
(
self
,
other
,
exclude
=
None
):
"""Broadcast
this tensor
against another XTensorVariable."""
"""Broadcast against another XTensorVariable."""
_
,
self_bcast
=
px
.
shape
.
broadcast
(
other
,
self
,
exclude
=
exclude
)
_
,
self_bcast
=
px
.
shape
.
broadcast
(
other
,
self
,
exclude
=
exclude
)
return
self_bcast
return
self_bcast
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论