Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
133ec80e
提交
133ec80e
authored
5月 28, 2025
作者:
Allen Downey
提交者:
Ricardo Vieira
6月 21, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement transpose for XTensorVariables
上级
30b50fda
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
240 行增加
和
9 行删除
+240
-9
shape.py
pytensor/xtensor/rewriting/shape.py
+17
-1
shape.py
pytensor/xtensor/shape.py
+90
-0
type.py
pytensor/xtensor/type.py
+46
-1
test_shape.py
tests/xtensor/test_shape.py
+84
-1
test_type.py
tests/xtensor/test_type.py
+3
-6
没有找到文件。
pytensor/xtensor/rewriting/shape.py
浏览文件 @
133ec80e
...
@@ -2,7 +2,7 @@ from pytensor.graph import node_rewriter
...
@@ -2,7 +2,7 @@ from pytensor.graph import node_rewriter
from
pytensor.tensor
import
broadcast_to
,
join
,
moveaxis
from
pytensor.tensor
import
broadcast_to
,
join
,
moveaxis
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.rewriting.basic
import
register_lower_xtensor
from
pytensor.xtensor.rewriting.basic
import
register_lower_xtensor
from
pytensor.xtensor.shape
import
Concat
,
Stack
from
pytensor.xtensor.shape
import
Concat
,
Stack
,
Transpose
@register_lower_xtensor
@register_lower_xtensor
...
@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
...
@@ -70,3 +70,19 @@ def lower_concat(fgraph, node):
joined_tensor
=
join
(
concat_axis
,
*
bcast_tensor_inputs
)
joined_tensor
=
join
(
concat_axis
,
*
bcast_tensor_inputs
)
new_out
=
xtensor_from_tensor
(
joined_tensor
,
dims
=
out_dims
)
new_out
=
xtensor_from_tensor
(
joined_tensor
,
dims
=
out_dims
)
return
[
new_out
]
return
[
new_out
]
@register_lower_xtensor
@node_rewriter
(
tracks
=
[
Transpose
])
def
lower_transpose
(
fgraph
,
node
):
[
x
]
=
node
.
inputs
# Use the final dimensions that were already computed in make_node
out_dims
=
node
.
outputs
[
0
]
.
type
.
dims
in_dims
=
x
.
type
.
dims
# Compute the permutation based on the final dimensions
perm
=
tuple
(
in_dims
.
index
(
d
)
for
d
in
out_dims
)
x_tensor
=
tensor_from_xtensor
(
x
)
x_tensor_transposed
=
x_tensor
.
transpose
(
perm
)
new_out
=
xtensor_from_tensor
(
x_tensor_transposed
,
dims
=
out_dims
)
return
[
new_out
]
pytensor/xtensor/shape.py
浏览文件 @
133ec80e
import
typing
import
warnings
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
types
import
EllipsisType
from
typing
import
Literal
from
pytensor.graph
import
Apply
from
pytensor.graph
import
Apply
from
pytensor.scalar
import
upcast
from
pytensor.scalar
import
upcast
...
@@ -72,6 +76,92 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
...
@@ -72,6 +76,92 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
return
y
return
y
class
Transpose
(
XOp
):
__props__
=
(
"dims"
,)
def
__init__
(
self
,
dims
:
Sequence
[
str
],
):
super
()
.
__init__
()
self
.
dims
=
tuple
(
dims
)
def
make_node
(
self
,
x
):
x
=
as_xtensor
(
x
)
transpose_dims
=
self
.
dims
x_shape
=
x
.
type
.
shape
x_dims
=
x
.
type
.
dims
if
set
(
transpose_dims
)
!=
set
(
x_dims
):
raise
ValueError
(
f
"{transpose_dims} must be a permuted list of {x_dims}"
)
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
tuple
(
x_shape
[
x_dims
.
index
(
d
)]
for
d
in
transpose_dims
),
dims
=
transpose_dims
,
)
return
Apply
(
self
,
[
x
],
[
output
])
def
transpose
(
x
,
*
dims
:
str
|
EllipsisType
,
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
):
"""Transpose dimensions of the tensor.
Parameters
----------
x : XTensorVariable
Input tensor to transpose.
*dims : str
Dimensions to transpose to. Can include ellipsis (...) to represent
remaining dimensions in their original order.
missing_dims : {"raise", "warn", "ignore"}, optional
How to handle dimensions that don't exist in the input tensor:
- "raise": Raise an error if any dimensions don't exist (default)
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist
Returns
-------
XTensorVariable
Transposed tensor with reordered dimensions.
Raises
------
ValueError
If any dimension in dims doesn't exist in the input tensor and missing_dims is "raise".
"""
# Validate dimensions
x
=
as_xtensor
(
x
)
x_dims
=
x
.
type
.
dims
invalid_dims
=
set
(
dims
)
-
{
...
,
*
x_dims
}
if
invalid_dims
:
if
missing_dims
!=
"ignore"
:
msg
=
f
"Dimensions {invalid_dims} do not exist. Expected one or more of: {x_dims}"
if
missing_dims
==
"raise"
:
raise
ValueError
(
msg
)
else
:
warnings
.
warn
(
msg
)
# Handle missing dimensions if not raising
dims
=
tuple
(
d
for
d
in
dims
if
d
in
x_dims
or
d
is
...
)
if
dims
==
()
or
dims
==
(
...
,):
dims
=
tuple
(
reversed
(
x_dims
))
elif
...
in
dims
:
if
dims
.
count
(
...
)
>
1
:
raise
ValueError
(
"Ellipsis (...) can only appear once in the dimensions"
)
# Handle ellipsis expansion
ellipsis_idx
=
dims
.
index
(
...
)
pre
=
dims
[:
ellipsis_idx
]
post
=
dims
[
ellipsis_idx
+
1
:]
middle
=
[
d
for
d
in
x_dims
if
d
not
in
pre
+
post
]
dims
=
(
*
pre
,
*
middle
,
*
post
)
return
Transpose
(
typing
.
cast
(
tuple
[
str
],
dims
))(
x
)
class
Concat
(
XOp
):
class
Concat
(
XOp
):
__props__
=
(
"dim"
,)
__props__
=
(
"dim"
,)
...
...
pytensor/xtensor/type.py
浏览文件 @
133ec80e
import
typing
import
typing
from
types
import
EllipsisType
from
pytensor.compile
import
(
from
pytensor.compile
import
(
DeepCopyOp
,
DeepCopyOp
,
...
@@ -23,7 +24,7 @@ except ModuleNotFoundError:
...
@@ -23,7 +24,7 @@ except ModuleNotFoundError:
XARRAY_AVAILABLE
=
False
XARRAY_AVAILABLE
=
False
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
typing
import
TypeVar
from
typing
import
Literal
,
TypeVar
import
numpy
as
np
import
numpy
as
np
...
@@ -438,6 +439,19 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -438,6 +439,19 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def
real
(
self
):
def
real
(
self
):
return
px
.
math
.
real
(
self
)
return
px
.
math
.
real
(
self
)
@property
def
T
(
self
):
"""Return the full transpose of the tensor.
This is equivalent to calling transpose() with no arguments.
Returns
-------
XTensorVariable
Fully transposed tensor.
"""
return
self
.
transpose
()
# Aggregation
# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
# https://docs.xarray.dev/en/latest/api.html#id6
def
all
(
self
,
dim
=
None
):
def
all
(
self
,
dim
=
None
):
...
@@ -475,6 +489,37 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -475,6 +489,37 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
# Reshaping and reorganizing
# Reshaping and reorganizing
# https://docs.xarray.dev/en/latest/api.html#id8
# https://docs.xarray.dev/en/latest/api.html#id8
def
transpose
(
self
,
*
dims
:
str
|
EllipsisType
,
missing_dims
:
Literal
[
"raise"
,
"warn"
,
"ignore"
]
=
"raise"
,
):
"""Transpose dimensions of the tensor.
Parameters
----------
*dims : str | Ellipsis
Dimensions to transpose. If empty, performs a full transpose.
Can use ellipsis (...) to represent remaining dimensions.
missing_dims : {"raise", "warn", "ignore"}, default="raise"
How to handle dimensions that don't exist in the tensor:
- "raise": Raise an error if any dimensions don't exist
- "warn": Warn if any dimensions don't exist
- "ignore": Silently ignore any dimensions that don't exist
Returns
-------
XTensorVariable
Transposed tensor with reordered dimensions.
Raises
------
ValueError
If missing_dims="raise" and any dimensions don't exist.
If multiple ellipsis are provided.
"""
return
px
.
shape
.
transpose
(
self
,
*
dims
,
missing_dims
=
missing_dims
)
def
stack
(
self
,
dim
,
**
dims
):
def
stack
(
self
,
dim
,
**
dims
):
return
px
.
shape
.
stack
(
self
,
dim
,
**
dims
)
return
px
.
shape
.
stack
(
self
,
dim
,
**
dims
)
...
...
tests/xtensor/test_shape.py
浏览文件 @
133ec80e
...
@@ -4,12 +4,13 @@ import pytest
...
@@ -4,12 +4,13 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray"
)
import
re
from
itertools
import
chain
,
combinations
from
itertools
import
chain
,
combinations
import
numpy
as
np
import
numpy
as
np
from
xarray
import
concat
as
xr_concat
from
xarray
import
concat
as
xr_concat
from
pytensor.xtensor.shape
import
concat
,
stack
from
pytensor.xtensor.shape
import
concat
,
stack
,
transpose
from
pytensor.xtensor.type
import
xtensor
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
(
from
tests.xtensor.util
import
(
xr_arange_like
,
xr_arange_like
,
...
@@ -28,6 +29,88 @@ def powerset(iterable, min_group_size=0):
...
@@ -28,6 +29,88 @@ def powerset(iterable, min_group_size=0):
)
)
def
test_transpose
():
a
,
b
,
c
,
d
,
e
=
"abcde"
x
=
xtensor
(
"x"
,
dims
=
(
a
,
b
,
c
,
d
,
e
),
shape
=
(
2
,
3
,
5
,
7
,
11
))
permutations
=
[
(
a
,
b
,
c
,
d
,
e
),
# identity
(
e
,
d
,
c
,
b
,
a
),
# full tranpose
(),
# eqivalent to full transpose
(
a
,
b
,
c
,
e
,
d
),
# swap last two dims
(
...
,
d
,
c
),
# equivalent to (a, b, e, d, c)
(
b
,
a
,
...
,
e
,
d
),
# equivalent to (b, a, c, d, e)
(
c
,
a
,
...
),
# equivalent to (c, a, b, d, e)
]
outs
=
[
transpose
(
x
,
*
perm
)
for
perm
in
permutations
]
fn
=
xr_function
([
x
],
outs
)
x_test
=
xr_arange_like
(
x
)
res
=
fn
(
x_test
)
expected_res
=
[
x_test
.
transpose
(
*
perm
)
for
perm
in
permutations
]
for
outs_i
,
res_i
,
expected_res_i
in
zip
(
outs
,
res
,
expected_res
):
xr_assert_allclose
(
res_i
,
expected_res_i
)
def
test_xtensor_variable_transpose
():
"""Test the transpose() method of XTensorVariable."""
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
2
,
3
,
4
))
# Test basic transpose
out
=
x
.
transpose
()
fn
=
xr_function
([
x
],
out
)
x_test
=
xr_arange_like
(
x
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
transpose
())
# Test transpose with specific dimensions
out
=
x
.
transpose
(
"c"
,
"a"
,
"b"
)
fn
=
xr_function
([
x
],
out
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
transpose
(
"c"
,
"a"
,
"b"
))
# Test transpose with ellipsis
out
=
x
.
transpose
(
"c"
,
...
)
fn
=
xr_function
([
x
],
out
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
transpose
(
"c"
,
...
))
# Test error cases
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"Dimensions {'d'} do not exist. Expected one or more of: ('a', 'b', 'c')"
),
):
x
.
transpose
(
"d"
)
with
pytest
.
raises
(
ValueError
,
match
=
re
.
escape
(
"Ellipsis (...) can only appear once in the dimensions"
),
):
x
.
transpose
(
"a"
,
...
,
"b"
,
...
)
# Test missing_dims parameter
# Test ignore
out
=
x
.
transpose
(
"c"
,
...
,
"d"
,
missing_dims
=
"ignore"
)
fn
=
xr_function
([
x
],
out
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
transpose
(
"c"
,
...
))
# Test warn
with
pytest
.
warns
(
UserWarning
,
match
=
"Dimensions {'d'} do not exist"
):
out
=
x
.
transpose
(
"c"
,
...
,
"d"
,
missing_dims
=
"warn"
)
fn
=
xr_function
([
x
],
out
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
transpose
(
"c"
,
...
))
def
test_xtensor_variable_T
():
"""Test the T property of XTensorVariable."""
# Test T property with 3D tensor
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
2
,
3
,
4
))
out
=
x
.
T
fn
=
xr_function
([
x
],
out
)
x_test
=
xr_arange_like
(
x
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
T
)
def
test_stack
():
def
test_stack
():
dims
=
(
"a"
,
"b"
,
"c"
,
"d"
)
dims
=
(
"a"
,
"b"
,
"c"
,
"d"
)
x
=
xtensor
(
"x"
,
dims
=
dims
,
shape
=
(
2
,
3
,
5
,
7
))
x
=
xtensor
(
"x"
,
dims
=
dims
,
shape
=
(
2
,
3
,
5
,
7
))
...
...
tests/xtensor/test_type.py
浏览文件 @
133ec80e
...
@@ -33,15 +33,12 @@ def test_xtensortype_filter_variable():
...
@@ -33,15 +33,12 @@ def test_xtensortype_filter_variable():
assert
x
.
type
.
filter_variable
(
y1
)
is
y1
assert
x
.
type
.
filter_variable
(
y1
)
is
y1
y2
=
xtensor
(
"y2"
,
dims
=
(
"b"
,
"a"
),
shape
=
(
3
,
2
))
y2
=
xtensor
(
"y2"
,
dims
=
(
"b"
,
"a"
),
shape
=
(
3
,
2
))
expected_y2
=
as_xtensor
(
y2
.
values
.
transpose
(),
dims
=
(
"a"
,
"b"
)
)
expected_y2
=
y2
.
transpose
(
)
assert
equal_computations
([
x
.
type
.
filter_variable
(
y2
)],
[
expected_y2
])
assert
equal_computations
([
x
.
type
.
filter_variable
(
y2
)],
[
expected_y2
])
y3
=
xtensor
(
"y3"
,
dims
=
(
"b"
,
"a"
),
shape
=
(
3
,
None
))
y3
=
xtensor
(
"y3"
,
dims
=
(
"b"
,
"a"
),
shape
=
(
3
,
None
))
expected_y3
=
as_xtensor
(
expected_y3
=
as_xtensor
(
specify_shape
(
specify_shape
(
y3
.
transpose
()
.
values
,
(
2
,
3
)),
dims
=
(
"a"
,
"b"
)
as_xtensor
(
y3
.
values
.
transpose
(),
dims
=
(
"a"
,
"b"
))
.
values
,
(
2
,
3
)
),
dims
=
(
"a"
,
"b"
),
)
)
assert
equal_computations
([
x
.
type
.
filter_variable
(
y3
)],
[
expected_y3
])
assert
equal_computations
([
x
.
type
.
filter_variable
(
y3
)],
[
expected_y3
])
...
@@ -116,7 +113,7 @@ def test_minimum_compile():
...
@@ -116,7 +113,7 @@ def test_minimum_compile():
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.mode
import
Mode
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
))
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
))
y
=
as_xtensor
(
x
.
values
.
transpose
(),
dims
=
(
"b"
,
"a"
)
)
y
=
x
.
transpose
(
)
minimum_mode
=
Mode
(
linker
=
"py"
,
optimizer
=
"minimum_compile"
)
minimum_mode
=
Mode
(
linker
=
"py"
,
optimizer
=
"minimum_compile"
)
result
=
y
.
eval
({
"x"
:
np
.
ones
((
2
,
3
))},
mode
=
minimum_mode
)
result
=
y
.
eval
({
"x"
:
np
.
ones
((
2
,
3
))},
mode
=
minimum_mode
)
np
.
testing
.
assert_array_equal
(
result
,
np
.
ones
((
3
,
2
)))
np
.
testing
.
assert_array_equal
(
result
,
np
.
ones
((
3
,
2
)))
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论