Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
33d04c36
提交
33d04c36
authored
6月 19, 2025
作者:
Allen Downey
提交者:
Ricardo Vieira
6月 21, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement dot for XTensorVariables (#1475)
上级
9ede7f6a
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
330 行增加
和
2 行删除
+330
-2
__init__.py
pytensor/xtensor/__init__.py
+1
-0
math.py
pytensor/xtensor/math.py
+113
-2
__init__.py
pytensor/xtensor/rewriting/__init__.py
+1
-0
math.py
pytensor/xtensor/rewriting/math.py
+47
-0
type.py
pytensor/xtensor/type.py
+4
-0
test_math.py
tests/xtensor/test_math.py
+164
-0
没有找到文件。
pytensor/xtensor/__init__.py
浏览文件 @
33d04c36
...
@@ -2,6 +2,7 @@ import warnings
...
@@ -2,6 +2,7 @@ import warnings
import
pytensor.xtensor.rewriting
import
pytensor.xtensor.rewriting
from
pytensor.xtensor
import
linalg
from
pytensor.xtensor
import
linalg
from
pytensor.xtensor.math
import
dot
from
pytensor.xtensor.shape
import
concat
from
pytensor.xtensor.shape
import
concat
from
pytensor.xtensor.type
import
(
from
pytensor.xtensor.type
import
(
as_xtensor
,
as_xtensor
,
...
...
pytensor/xtensor/math.py
浏览文件 @
33d04c36
import
sys
import
sys
from
collections.abc
import
Iterable
from
types
import
EllipsisType
import
numpy
as
np
import
numpy
as
np
import
pytensor.scalar
as
ps
import
pytensor.scalar
as
ps
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.graph.basic
import
Apply
from
pytensor.scalar
import
ScalarOp
from
pytensor.scalar
import
ScalarOp
from
pytensor.scalar.basic
import
_cast_mapping
from
pytensor.scalar.basic
import
_cast_mapping
,
upcast
from
pytensor.xtensor.basic
import
as_xtensor
from
pytensor.xtensor.basic
import
XOp
,
as_xtensor
from
pytensor.xtensor.type
import
xtensor
from
pytensor.xtensor.vectorization
import
XElemwise
from
pytensor.xtensor.vectorization
import
XElemwise
...
@@ -139,3 +143,110 @@ def cast(x, dtype):
...
@@ -139,3 +143,110 @@ def cast(x, dtype):
def
softmax
(
x
,
dim
=
None
):
def
softmax
(
x
,
dim
=
None
):
exp_x
=
exp
(
x
)
exp_x
=
exp
(
x
)
return
exp_x
/
exp_x
.
sum
(
dim
=
dim
)
return
exp_x
/
exp_x
.
sum
(
dim
=
dim
)
class
XDot
(
XOp
):
"""Matrix multiplication between two XTensorVariables.
This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.
Parameters
----------
dims : tuple of str
The dimensions to contract over. If None, will contract over all matching dimensions.
"""
__props__
=
(
"dims"
,)
def
__init__
(
self
,
dims
:
Iterable
[
str
]):
self
.
dims
=
dims
super
()
.
__init__
()
def
make_node
(
self
,
x
,
y
):
x
=
as_xtensor
(
x
)
y
=
as_xtensor
(
y
)
x_shape_dict
=
dict
(
zip
(
x
.
type
.
dims
,
x
.
type
.
shape
))
y_shape_dict
=
dict
(
zip
(
y
.
type
.
dims
,
y
.
type
.
shape
))
# Check for dimension size mismatches (concrete only)
for
dim
in
self
.
dims
:
x_shape
=
x_shape_dict
.
get
(
dim
,
None
)
y_shape
=
y_shape_dict
.
get
(
dim
,
None
)
if
(
isinstance
(
x_shape
,
int
)
and
isinstance
(
y_shape
,
int
)
and
x_shape
!=
y_shape
):
raise
ValueError
(
f
"Size of dim '{dim}' does not match"
)
# Determine output dimensions
shape_dict
=
{
**
x_shape_dict
,
**
y_shape_dict
}
out_dims
=
tuple
(
d
for
d
in
shape_dict
if
d
not
in
self
.
dims
)
# Determine output shape
out_shape
=
tuple
(
shape_dict
[
d
]
for
d
in
out_dims
)
# Determine output dtype
out_dtype
=
upcast
(
x
.
type
.
dtype
,
y
.
type
.
dtype
)
out
=
xtensor
(
dtype
=
out_dtype
,
shape
=
out_shape
,
dims
=
out_dims
)
return
Apply
(
self
,
[
x
,
y
],
[
out
])
def
dot
(
x
,
y
,
dim
:
str
|
Iterable
[
str
]
|
EllipsisType
|
None
=
None
):
"""Matrix multiplication between two XTensorVariables.
This operation performs matrix multiplication between two tensors, automatically
aligning and contracting dimensions. The behavior matches xarray's dot operation.
Parameters
----------
x : XTensorVariable
First input tensor
y : XTensorVariable
Second input tensor
dim : str, Iterable[Hashable], EllipsisType, or None, optional
The dimensions to contract over. If None, will contract over all matching dimensions.
If Ellipsis (...), will contract over all dimensions.
Returns
-------
XTensorVariable
The result of the matrix multiplication.
Examples
--------
>>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3))
>>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4))
>>> z = dot(x, y) # Result has dimensions ("a", "c")
>>> z = dot(x, y, dim=...) # Contract over all dimensions
"""
x
=
as_xtensor
(
x
)
y
=
as_xtensor
(
y
)
x_dims
=
set
(
x
.
type
.
dims
)
y_dims
=
set
(
y
.
type
.
dims
)
intersection
=
x_dims
&
y_dims
union
=
x_dims
|
y_dims
# Canonicalize dims
if
dim
is
None
:
dim_set
=
intersection
elif
dim
is
...
:
dim_set
=
union
elif
isinstance
(
dim
,
str
):
dim_set
=
{
dim
}
elif
isinstance
(
dim
,
Iterable
):
dim_set
=
set
(
dim
)
# Validate provided dims
# Check if any dimension is not found in either input
for
d
in
dim_set
:
if
d
not
in
union
:
raise
ValueError
(
f
"Dimension {d} not found in either input"
)
result
=
XDot
(
dims
=
tuple
(
dim_set
))(
x
,
y
)
return
result
pytensor/xtensor/rewriting/__init__.py
浏览文件 @
33d04c36
import
pytensor.xtensor.rewriting.basic
import
pytensor.xtensor.rewriting.basic
import
pytensor.xtensor.rewriting.indexing
import
pytensor.xtensor.rewriting.indexing
import
pytensor.xtensor.rewriting.math
import
pytensor.xtensor.rewriting.reduction
import
pytensor.xtensor.rewriting.reduction
import
pytensor.xtensor.rewriting.shape
import
pytensor.xtensor.rewriting.shape
import
pytensor.xtensor.rewriting.vectorization
import
pytensor.xtensor.rewriting.vectorization
pytensor/xtensor/rewriting/math.py
0 → 100644
浏览文件 @
33d04c36
from
string
import
ascii_lowercase
from
pytensor.graph
import
node_rewriter
from
pytensor.tensor
import
einsum
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.math
import
XDot
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
@register_lower_xtensor
@node_rewriter
(
tracks
=
[
XDot
])
def
lower_dot
(
fgraph
,
node
):
"""Rewrite XDot to tensor.dot.
This rewrite converts an XDot operation to a tensor-based dot operation,
handling dimension alignment and contraction.
"""
[
x
,
y
]
=
node
.
inputs
[
out
]
=
node
.
outputs
# Convert inputs to tensors
x_tensor
=
tensor_from_xtensor
(
x
)
y_tensor
=
tensor_from_xtensor
(
y
)
# Collect all dimension names across inputs and output
all_dims
=
list
(
dict
.
fromkeys
(
x
.
type
.
dims
+
y
.
type
.
dims
+
out
.
type
.
dims
)
)
# preserve order
if
len
(
all_dims
)
>
len
(
ascii_lowercase
):
raise
ValueError
(
"Too many dimensions to map to einsum subscripts"
)
dim_to_char
=
dict
(
zip
(
all_dims
,
ascii_lowercase
))
# Build einsum string
x_subs
=
""
.
join
(
dim_to_char
[
d
]
for
d
in
x
.
type
.
dims
)
y_subs
=
""
.
join
(
dim_to_char
[
d
]
for
d
in
y
.
type
.
dims
)
out_subs
=
""
.
join
(
dim_to_char
[
d
]
for
d
in
out
.
type
.
dims
)
einsum_str
=
f
"{x_subs},{y_subs}->{out_subs}"
# Perform the einsum operation
out_tensor
=
einsum
(
einsum_str
,
x_tensor
,
y_tensor
)
# Reshape to match the output shape
out_tensor
=
specify_shape
(
out_tensor
,
out
.
type
.
shape
)
return
[
xtensor_from_tensor
(
out_tensor
,
out
.
type
.
dims
)]
pytensor/xtensor/type.py
浏览文件 @
33d04c36
...
@@ -726,6 +726,10 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -726,6 +726,10 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def
unstack
(
self
,
dim
,
**
dims
):
def
unstack
(
self
,
dim
,
**
dims
):
return
px
.
shape
.
unstack
(
self
,
dim
,
**
dims
)
return
px
.
shape
.
unstack
(
self
,
dim
,
**
dims
)
def
dot
(
self
,
other
,
dim
=
None
):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return
px
.
math
.
dot
(
self
,
other
,
dim
=
dim
)
class
XTensorConstantSignature
(
TensorConstantSignature
):
class
XTensorConstantSignature
(
TensorConstantSignature
):
pass
pass
...
...
tests/xtensor/test_math.py
浏览文件 @
33d04c36
...
@@ -150,3 +150,167 @@ def test_cast():
...
@@ -150,3 +150,167 @@ def test_cast():
yc64
=
x
.
astype
(
"complex64"
)
yc64
=
x
.
astype
(
"complex64"
)
with
pytest
.
raises
(
TypeError
,
match
=
"Casting from complex to real is ambiguous"
):
with
pytest
.
raises
(
TypeError
,
match
=
"Casting from complex to real is ambiguous"
):
yc64
.
astype
(
"float64"
)
yc64
.
astype
(
"float64"
)
def
test_dot
():
"""Test basic dot product operations."""
# Test matrix-vector dot product (with multiple-letter dim names)
x
=
xtensor
(
"x"
,
dims
=
(
"aa"
,
"bb"
),
shape
=
(
2
,
3
))
y
=
xtensor
(
"y"
,
dims
=
(
"bb"
,),
shape
=
(
3
,))
z
=
x
.
dot
(
y
)
fn
=
xr_function
([
x
,
y
],
z
)
x_test
=
DataArray
(
np
.
ones
((
2
,
3
)),
dims
=
(
"aa"
,
"bb"
))
y_test
=
DataArray
(
np
.
ones
(
3
),
dims
=
(
"bb"
,))
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Test matrix-vector dot product with ellipsis
z
=
x
.
dot
(
y
,
dim
=...
)
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
,
dim
=...
)
xr_assert_allclose
(
z_test
,
expected
)
# Test matrix-matrix dot product
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
),
shape
=
(
3
,
4
))
z
=
x
.
dot
(
y
)
fn
=
xr_function
([
x
,
y
],
z
)
x_test
=
DataArray
(
np
.
add
.
outer
(
np
.
arange
(
2.0
),
np
.
arange
(
3.0
)),
dims
=
(
"a"
,
"b"
))
y_test
=
DataArray
(
np
.
add
.
outer
(
np
.
arange
(
3.0
),
np
.
arange
(
4.0
)),
dims
=
(
"b"
,
"c"
))
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Test matrix-matrix dot product with string dim
z
=
x
.
dot
(
y
,
dim
=
"b"
)
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
,
dim
=
"b"
)
xr_assert_allclose
(
z_test
,
expected
)
# Test matrix-matrix dot product with list of dims
z
=
x
.
dot
(
y
,
dim
=
[
"b"
])
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
,
dim
=
[
"b"
])
xr_assert_allclose
(
z_test
,
expected
)
# Test matrix-matrix dot product with ellipsis
z
=
x
.
dot
(
y
,
dim
=...
)
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
,
dim
=...
)
xr_assert_allclose
(
z_test
,
expected
)
# Test a case where there are two dimensions to sum over
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
2
,
3
,
4
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
,
"d"
),
shape
=
(
3
,
4
,
5
))
z
=
x
.
dot
(
y
)
fn
=
xr_function
([
x
,
y
],
z
)
x_test
=
DataArray
(
np
.
arange
(
24.0
)
.
reshape
(
2
,
3
,
4
),
dims
=
(
"a"
,
"b"
,
"c"
))
y_test
=
DataArray
(
np
.
arange
(
60.0
)
.
reshape
(
3
,
4
,
5
),
dims
=
(
"b"
,
"c"
,
"d"
))
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Same but with explicit dimensions
z
=
x
.
dot
(
y
,
dim
=
[
"b"
,
"c"
])
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
,
dim
=
[
"b"
,
"c"
])
xr_assert_allclose
(
z_test
,
expected
)
# Same but with ellipses
z
=
x
.
dot
(
y
,
dim
=...
)
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
,
dim
=...
)
xr_assert_allclose
(
z_test
,
expected
)
# Dot product with sum
x_test
=
DataArray
(
np
.
arange
(
24.0
)
.
reshape
(
2
,
3
,
4
),
dims
=
(
"a"
,
"b"
,
"c"
))
y_test
=
DataArray
(
np
.
arange
(
60.0
)
.
reshape
(
3
,
4
,
5
),
dims
=
(
"b"
,
"c"
,
"d"
))
expected
=
x_test
.
dot
(
y_test
,
dim
=
(
"a"
,
"b"
,
"c"
))
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
2
,
3
,
4
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
,
"d"
),
shape
=
(
3
,
4
,
5
))
z
=
x
.
dot
(
y
,
dim
=
(
"a"
,
"b"
,
"c"
))
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Dot product with sum in the middle
x_test
=
DataArray
(
np
.
arange
(
120.0
)
.
reshape
(
2
,
3
,
4
,
5
),
dims
=
(
"a"
,
"b"
,
"c"
,
"d"
))
y_test
=
DataArray
(
np
.
arange
(
360.0
)
.
reshape
(
3
,
4
,
5
,
6
),
dims
=
(
"b"
,
"c"
,
"d"
,
"e"
))
expected
=
x_test
.
dot
(
y_test
,
dim
=
(
"b"
,
"d"
))
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
,
"c"
,
"d"
),
shape
=
(
2
,
3
,
4
,
5
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
,
"d"
,
"e"
),
shape
=
(
3
,
4
,
5
,
6
))
z
=
x
.
dot
(
y
,
dim
=
(
"b"
,
"d"
))
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Same but with first two dims
expected
=
x_test
.
dot
(
y_test
,
dim
=
[
"a"
,
"b"
])
z
=
x
.
dot
(
y
,
dim
=
[
"a"
,
"b"
])
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Same but with last two
expected
=
x_test
.
dot
(
y_test
,
dim
=
[
"d"
,
"e"
])
z
=
x
.
dot
(
y
,
dim
=
[
"d"
,
"e"
])
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Same but with every other dim
expected
=
x_test
.
dot
(
y_test
,
dim
=
[
"a"
,
"c"
,
"e"
])
z
=
x
.
dot
(
y
,
dim
=
[
"a"
,
"c"
,
"e"
])
fn
=
xr_function
([
x
,
y
],
z
)
z_test
=
fn
(
x_test
,
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
# Test symbolic shapes
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
None
,
3
))
# First dimension is symbolic
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
),
shape
=
(
3
,
None
))
# Second dimension is symbolic
z
=
x
.
dot
(
y
)
fn
=
xr_function
([
x
,
y
],
z
)
x_test
=
DataArray
(
np
.
ones
((
2
,
3
)),
dims
=
(
"a"
,
"b"
))
y_test
=
DataArray
(
np
.
ones
((
3
,
4
)),
dims
=
(
"b"
,
"c"
))
z_test
=
fn
(
x_test
,
y_test
)
expected
=
x_test
.
dot
(
y_test
)
xr_assert_allclose
(
z_test
,
expected
)
def
test_dot_errors
():
# No matching dimensions
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
),
shape
=
(
3
,
4
))
with
pytest
.
raises
(
ValueError
,
match
=
"Dimension e not found in either input"
):
x
.
dot
(
y
,
dim
=
"e"
)
# Concrete dimension size mismatches
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
),
shape
=
(
4
,
5
))
with
pytest
.
raises
(
ValueError
,
match
=
"Size of dim 'b' does not match"
,
):
x
.
dot
(
y
)
# Symbolic dimension size mismatches
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
None
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
),
shape
=
(
None
,
5
))
z
=
x
.
dot
(
y
)
fn
=
xr_function
([
x
,
y
],
z
)
x_test
=
DataArray
(
np
.
ones
((
2
,
3
)),
dims
=
(
"a"
,
"b"
))
y_test
=
DataArray
(
np
.
ones
((
4
,
5
)),
dims
=
(
"b"
,
"c"
))
# Doesn't fail until the rewrite
with
pytest
.
raises
(
ValueError
,
match
=
"not aligned"
):
fn
(
x_test
,
y_test
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论