Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
5024d54e
提交
5024d54e
authored
7月 11, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 14, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add docstrings to more XTensorVariable methods
Also remove broadcast which is not a method in Xarray
上级
fdb40877
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
152 行增加
和
23 行删除
+152
-23
type.py
pytensor/xtensor/type.py
+152
-23
没有找到文件。
pytensor/xtensor/type.py
浏览文件 @
5024d54e
...
...
@@ -366,6 +366,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#id1
@property
def
values
(
self
)
->
TensorVariable
:
"""Convert to a TensorVariable with the same data."""
return
typing
.
cast
(
TensorVariable
,
px
.
basic
.
tensor_from_xtensor
(
self
))
# Can't provide property data because that's already taken by Constants!
...
...
@@ -373,14 +374,17 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
@property
def
coords
(
self
):
"""Not implemented."""
raise
NotImplementedError
(
"coords not implemented for XTensorVariable"
)
@property
def
dims
(
self
)
->
tuple
[
str
,
...
]:
"""The names of the dimensions of the variable."""
return
self
.
type
.
dims
@property
def
sizes
(
self
)
->
dict
[
str
,
TensorVariable
]:
"""The sizes of the dimensions of the variable."""
return
dict
(
zip
(
self
.
dims
,
self
.
shape
))
@property
...
...
@@ -392,18 +396,22 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
@property
def
ndim
(
self
)
->
int
:
"""The number of dimensions of the variable."""
return
self
.
type
.
ndim
@property
def
shape
(
self
)
->
tuple
[
TensorVariable
,
...
]:
"""The shape of the variable."""
return
tuple
(
px
.
basic
.
tensor_from_xtensor
(
self
)
.
shape
)
# type: ignore
@property
def
size
(
self
)
->
TensorVariable
:
"""The total number of elements in the variable."""
return
typing
.
cast
(
TensorVariable
,
variadic_mul
(
*
self
.
shape
))
@property
def
dtype
(
self
):
def
dtype
(
self
)
->
str
:
"""The data type of the variable."""
return
self
.
type
.
dtype
@property
...
...
@@ -414,6 +422,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# DataArray contents
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
def
rename
(
self
,
new_name_or_name_dict
=
None
,
**
names
):
"""Rename the variable or its dimension(s)."""
if
isinstance
(
new_name_or_name_dict
,
str
):
new_name
=
new_name_or_name_dict
name_dict
=
None
...
...
@@ -425,31 +434,41 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
new_out
def
copy
(
self
,
name
:
str
|
None
=
None
):
"""Create a copy of the variable.
This is just an identity operation, as XTensorVariables are immutable.
"""
out
=
px
.
math
.
identity
(
self
)
out
.
name
=
name
return
out
def
astype
(
self
,
dtype
):
"""Convert the variable to a different data type."""
return
px
.
math
.
cast
(
self
,
dtype
)
def
item
(
self
):
"""Not implemented."""
raise
NotImplementedError
(
"item not implemented for XTensorVariable"
)
# Indexing
# https://docs.xarray.dev/en/latest/api.html#id2
def
__setitem__
(
self
,
idx
,
value
):
"""Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
raise
TypeError
(
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
)
@property
def
loc
(
self
):
"""Not implemented."""
raise
NotImplementedError
(
"loc not implemented for XTensorVariable"
)
def
sel
(
self
,
*
args
,
**
kwargs
):
"""Not implemented."""
raise
NotImplementedError
(
"sel not implemented for XTensorVariable"
)
def
__getitem__
(
self
,
idx
):
"""Index the variable positionally."""
if
isinstance
(
idx
,
dict
):
return
self
.
isel
(
idx
)
...
...
@@ -465,6 +484,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
**
indexers_kwargs
,
):
"""Index the variable along the specified dimension(s)."""
if
indexers_kwargs
:
if
indexers
is
not
None
:
raise
ValueError
(
...
...
@@ -505,6 +525,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
indexing
.
index
(
self
,
*
indices
)
def
set
(
self
,
value
):
"""Return a copy of the variable indexed by self with the indexed values set to y.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].set(1)
print(out.eval())
.. testoutput::
[[1 0]
[0 1]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).set(-1)
print(out.eval())
.. testoutput::
[[-1 0]
[ 0 -1]]
"""
if
not
(
self
.
owner
is
not
None
and
isinstance
(
self
.
owner
.
op
,
px
.
indexing
.
Index
)
):
...
...
@@ -516,6 +578,48 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
indexing
.
index_assignment
(
x
,
value
,
*
idxs
)
def
inc
(
self
,
value
):
"""Return a copy of the variable indexed by self with the indexed values incremented by value.
The original variable is not modified.
Raises
------
ValueError
If self is not the result of an index operation
Examples
--------
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x[:, idx].inc(1)
print(out.eval())
.. testoutput::
[[2 1]
[1 2]]
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
idx = ptx.as_xtensor([0, 1], dims=("a",))
out = x.isel({"b": idx}).inc(-1)
print(out.eval())
.. testoutput::
[[0 1]
[1 0]]
"""
if
not
(
self
.
owner
is
not
None
and
isinstance
(
self
.
owner
.
op
,
px
.
indexing
.
Index
)
):
...
...
@@ -579,7 +683,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
drop
=
None
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
):
"""Remove dimensions of size 1
from an XTensorVariable
.
"""Remove dimensions of size 1.
Parameters
----------
...
...
@@ -606,24 +710,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
**
dim_kwargs
,
):
"""Add one or more new dimensions to the
tensor
.
"""Add one or more new dimensions to the
variable
.
Parameters
----------
dim : str | Sequence[str] | dict[str, int | Sequence] | None
If str or sequence of str, new dimensions with size 1.
If dict, keys are dimension names and values are either:
- int: the new size
- sequence: coordinates (length determines size)
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True
Currently ignored. Reserved for future coordinate support.
In xarray, when True (default), creates a coordinate index for the new dimension
with values from 0 to size-1. When False, no coordinate index is created.
Ignored by PyTensor
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
By default (None), new dimensions are inserted at the beginning (axis=0).
Symbolic axis is not supported yet.
Negative values count from the end.
**dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None.
...
...
@@ -643,65 +744,75 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def
clip
(
self
,
min
,
max
):
"""Clip the values of the variable to a specified range."""
return
px
.
math
.
clip
(
self
,
min
,
max
)
def
conj
(
self
):
"""Return the complex conjugate of the variable."""
return
px
.
math
.
conj
(
self
)
@property
def
imag
(
self
):
"""Return the imaginary part of the variable."""
return
px
.
math
.
imag
(
self
)
@property
def
real
(
self
):
"""Return the real part of the variable."""
return
px
.
math
.
real
(
self
)
@property
def
T
(
self
):
"""Return the full transpose of the
tensor
.
"""Return the full transpose of the
variable
.
This is equivalent to calling transpose() with no arguments.
Returns
-------
XTensorVariable
Fully transposed tensor.
"""
return
self
.
transpose
()
# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
def
all
(
self
,
dim
=
None
):
"""Reduce the variable by applying `all` along some dimension(s)."""
return
px
.
reduction
.
all
(
self
,
dim
)
def
any
(
self
,
dim
=
None
):
"""Reduce the variable by applying `any` along some dimension(s)."""
return
px
.
reduction
.
any
(
self
,
dim
)
def
max
(
self
,
dim
=
None
):
"""Compute the maximum along the given dimension(s)."""
return
px
.
reduction
.
max
(
self
,
dim
)
def
min
(
self
,
dim
=
None
):
"""Compute the minimum along the given dimension(s)."""
return
px
.
reduction
.
min
(
self
,
dim
)
def
mean
(
self
,
dim
=
None
):
"""Compute the mean along the given dimension(s)."""
return
px
.
reduction
.
mean
(
self
,
dim
)
def
prod
(
self
,
dim
=
None
):
"""Compute the product along the given dimension(s)."""
return
px
.
reduction
.
prod
(
self
,
dim
)
def
sum
(
self
,
dim
=
None
):
"""Compute the sum along the given dimension(s)."""
return
px
.
reduction
.
sum
(
self
,
dim
)
def
std
(
self
,
dim
=
None
,
ddof
=
0
):
"""Compute the standard deviation along the given dimension(s)."""
return
px
.
reduction
.
std
(
self
,
dim
,
ddof
=
ddof
)
def
var
(
self
,
dim
=
None
,
ddof
=
0
):
"""Compute the variance along the given dimension(s)."""
return
px
.
reduction
.
var
(
self
,
dim
,
ddof
=
ddof
)
def
cumsum
(
self
,
dim
=
None
):
"""Compute the cumulative sum along the given dimension(s)."""
return
px
.
reduction
.
cumsum
(
self
,
dim
)
def
cumprod
(
self
,
dim
=
None
):
"""Compute the cumulative product along the given dimension(s)."""
return
px
.
reduction
.
cumprod
(
self
,
dim
)
def
diff
(
self
,
dim
,
n
=
1
):
...
...
@@ -720,7 +831,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
*
dim
:
str
|
EllipsisType
,
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
):
"""Transpose
dimensions of the tensor
.
"""Transpose
the dimensions of the variable
.
Parameters
----------
...
...
@@ -729,6 +840,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise"
How to handle dimensions that don't exist in the tensor:
- "raise": Raise an error if any dimensions don't exist
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist
...
...
@@ -747,21 +859,38 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
shape
.
transpose
(
self
,
*
dim
,
missing_dims
=
missing_dims
)
def
stack
(
self
,
dim
,
**
dims
):
"""Stack existing dimensions into a single new dimension."""
return
px
.
shape
.
stack
(
self
,
dim
,
**
dims
)
def
unstack
(
self
,
dim
,
**
dims
):
"""Unstack a dimension into multiple dimensions of a given size.
Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
.. testcode::
import pytensor.xtensor as ptx
x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
print(unstacked_cumsum.eval())
.. testoutput::
[[ 1 3]
[ 6 10]]
"""
return
px
.
shape
.
unstack
(
self
,
dim
,
**
dims
)
def
dot
(
self
,
other
,
dim
=
None
):
"""
Matrix multiplication with another XTensorVariable, contracting over matching or specified dims
."""
"""
Generalized dot product with another XTensorVariable
."""
return
px
.
math
.
dot
(
self
,
other
,
dim
=
dim
)
def
broadcast
(
self
,
*
others
,
exclude
=
None
):
"""Broadcast this tensor against other XTensorVariables."""
return
px
.
shape
.
broadcast
(
self
,
*
others
,
exclude
=
exclude
)
def
broadcast_like
(
self
,
other
,
exclude
=
None
):
"""Broadcast
this tensor
against another XTensorVariable."""
"""Broadcast against another XTensorVariable."""
_
,
self_bcast
=
px
.
shape
.
broadcast
(
other
,
self
,
exclude
=
exclude
)
return
self_bcast
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论