Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9716b3f2
提交
9716b3f2
authored
6月 06, 2025
作者:
Allen Downey
提交者:
Ricardo Vieira
6月 21, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement squeeze for XTensorVariables
上级
071c4eb8
显示空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
240 行增加
和
3 行删除
+240
-3
shape.py
pytensor/xtensor/rewriting/shape.py
+29
-2
shape.py
pytensor/xtensor/shape.py
+84
-0
type.py
pytensor/xtensor/type.py
+26
-0
test_shape.py
tests/xtensor/test_shape.py
+101
-1
没有找到文件。
pytensor/xtensor/rewriting/shape.py
浏览文件 @
9716b3f2
from
pytensor.graph
import
node_rewriter
from
pytensor.tensor
import
broadcast_to
,
join
,
moveaxis
,
specify_shape
from
pytensor.tensor
import
(
broadcast_to
,
join
,
moveaxis
,
specify_shape
,
squeeze
,
)
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.rewriting.basic
import
register_lower_xtensor
from
pytensor.xtensor.shape
import
Concat
,
Stack
,
Transpose
,
UnStack
from
pytensor.xtensor.shape
import
(
Concat
,
Squeeze
,
Stack
,
Transpose
,
UnStack
,
)
@register_lower_xtensor
...
...
@@ -105,3 +117,18 @@ def lower_transpose(fgraph, node):
x_tensor_transposed
=
x_tensor
.
transpose
(
perm
)
new_out
=
xtensor_from_tensor
(
x_tensor_transposed
,
dims
=
out_dims
)
return
[
new_out
]
@register_lower_xtensor
@node_rewriter
([
Squeeze
])
def
local_squeeze_reshape
(
fgraph
,
node
):
"""Rewrite Squeeze to tensor.squeeze."""
[
x
]
=
node
.
inputs
x_tensor
=
tensor_from_xtensor
(
x
)
x_dims
=
x
.
type
.
dims
dims_to_remove
=
node
.
op
.
dims
axes_to_squeeze
=
tuple
(
x_dims
.
index
(
d
)
for
d
in
dims_to_remove
)
x_tensor_squeezed
=
squeeze
(
x_tensor
,
axis
=
axes_to_squeeze
)
new_out
=
xtensor_from_tensor
(
x_tensor_squeezed
,
dims
=
node
.
outputs
[
0
]
.
type
.
dims
)
return
[
new_out
]
pytensor/xtensor/shape.py
浏览文件 @
9716b3f2
...
...
@@ -297,3 +297,87 @@ class Concat(XOp):
def
concat
(
xtensors
,
dim
:
str
):
return
Concat
(
dim
=
dim
)(
*
xtensors
)
class
Squeeze
(
XOp
):
"""Remove specified dimensions from an XTensorVariable.
Only dimensions that are known statically to be size 1 will be removed.
Symbolic dimensions must be explicitly specified, and are assumed safe.
Parameters
----------
dim : tuple of str
The names of the dimensions to remove.
"""
__props__
=
(
"dims"
,)
def
__init__
(
self
,
dims
):
self
.
dims
=
tuple
(
sorted
(
set
(
dims
)))
def
make_node
(
self
,
x
):
x
=
as_xtensor
(
x
)
# Validate that dims exist and are size-1 if statically known
dims_to_remove
=
[]
x_dims
=
x
.
type
.
dims
x_shape
=
x
.
type
.
shape
for
d
in
self
.
dims
:
if
d
not
in
x_dims
:
raise
ValueError
(
f
"Dimension {d} not found in {x.type.dims}"
)
idx
=
x_dims
.
index
(
d
)
dim_size
=
x_shape
[
idx
]
if
dim_size
is
not
None
and
dim_size
!=
1
:
raise
ValueError
(
f
"Dimension {d} has static size {dim_size}, not 1"
)
dims_to_remove
.
append
(
idx
)
new_dims
=
tuple
(
d
for
i
,
d
in
enumerate
(
x
.
type
.
dims
)
if
i
not
in
dims_to_remove
)
new_shape
=
tuple
(
s
for
i
,
s
in
enumerate
(
x
.
type
.
shape
)
if
i
not
in
dims_to_remove
)
out
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
new_shape
,
dims
=
new_dims
,
)
return
Apply
(
self
,
[
x
],
[
out
])
def
squeeze
(
x
,
dim
=
None
,
drop
=
False
,
axis
=
None
):
"""Remove dimensions of size 1 from an XTensorVariable."""
x
=
as_xtensor
(
x
)
# drop parameter is ignored in pytensor.xtensor
if
drop
is
not
None
:
warnings
.
warn
(
"drop parameter has no effect in pytensor.xtensor"
,
UserWarning
)
# dim and axis are mutually exclusive
if
dim
is
not
None
and
axis
is
not
None
:
raise
ValueError
(
"Cannot specify both `dim` and `axis`"
)
# if axis is specified, it must be a sequence of ints
if
axis
is
not
None
:
if
not
isinstance
(
axis
,
Sequence
):
axis
=
[
axis
]
if
not
all
(
isinstance
(
a
,
int
)
for
a
in
axis
):
raise
ValueError
(
"axis must be an integer or a sequence of integers"
)
# convert axis to dims
dims
=
tuple
(
x
.
type
.
dims
[
i
]
for
i
in
axis
)
# if dim is specified, it must be a string or a sequence of strings
if
dim
is
None
:
dims
=
tuple
(
d
for
d
,
s
in
zip
(
x
.
type
.
dims
,
x
.
type
.
shape
)
if
s
==
1
)
elif
isinstance
(
dim
,
str
):
dims
=
(
dim
,)
else
:
dims
=
tuple
(
dim
)
if
not
dims
:
return
x
# no-op if nothing to squeeze
return
Squeeze
(
dims
=
dims
)(
x
)
pytensor/xtensor/type.py
浏览文件 @
9716b3f2
...
...
@@ -547,6 +547,32 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def
thin
(
self
,
indexers
:
dict
[
str
,
Any
]
|
int
|
None
=
None
,
**
indexers_kwargs
):
return
self
.
_head_tail_or_thin
(
indexers
,
indexers_kwargs
,
kind
=
"thin"
)
def
squeeze
(
self
,
dim
:
Sequence
[
str
]
|
str
|
None
=
None
,
drop
=
None
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
):
"""Remove dimensions of size 1 from an XTensorVariable.
Parameters
----------
x : XTensorVariable
The input tensor
dim : str or None or iterable of str, optional
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional
If drop=True, drop squeezed coordinates instead of making them scalar.
axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns
-------
XTensorVariable
A new tensor with the specified dimension(s) removed.
"""
return
px
.
shape
.
squeeze
(
self
,
dim
,
drop
,
axis
)
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def
clip
(
self
,
min
,
max
):
...
...
tests/xtensor/test_shape.py
浏览文件 @
9716b3f2
...
...
@@ -8,10 +8,16 @@ import re
from
itertools
import
chain
,
combinations
import
numpy
as
np
import
pytest
from
xarray
import
DataArray
from
xarray
import
concat
as
xr_concat
from
pytensor.xtensor.shape
import
concat
,
stack
,
transpose
,
unstack
from
pytensor.xtensor.shape
import
(
concat
,
stack
,
transpose
,
unstack
,
)
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
(
xr_arange_like
,
...
...
@@ -21,6 +27,9 @@ from tests.xtensor.util import (
)
pytest
.
importorskip
(
"xarray"
)
def
powerset
(
iterable
,
min_group_size
=
0
):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
...
...
@@ -256,3 +265,94 @@ def test_concat_scalar():
res
=
fn
(
x1_test
,
x2_test
)
expected_res
=
xr_concat
([
x1_test
,
x2_test
],
dim
=
"new_dim"
)
xr_assert_allclose
(
res
,
expected_res
)
def
test_squeeze
():
"""Test squeeze."""
# Single dimension
x1
=
xtensor
(
"x1"
,
dims
=
(
"city"
,
"country"
),
shape
=
(
3
,
1
))
y1
=
x1
.
squeeze
(
"country"
)
fn1
=
xr_function
([
x1
],
y1
)
x1_test
=
xr_arange_like
(
x1
)
xr_assert_allclose
(
fn1
(
x1_test
),
x1_test
.
squeeze
(
"country"
))
# Multiple dimensions and order independence
x2
=
xtensor
(
"x2"
,
dims
=
(
"a"
,
"b"
,
"c"
,
"d"
),
shape
=
(
2
,
1
,
1
,
3
))
y2a
=
x2
.
squeeze
([
"b"
,
"c"
])
y2b
=
x2
.
squeeze
([
"c"
,
"b"
])
# Test order independence
y2c
=
x2
.
squeeze
([
"b"
,
"b"
])
# Test redundant dimensions
y2d
=
x2
.
squeeze
([])
# Test empty list (no-op)
fn2a
=
xr_function
([
x2
],
y2a
)
fn2b
=
xr_function
([
x2
],
y2b
)
fn2c
=
xr_function
([
x2
],
y2c
)
fn2d
=
xr_function
([
x2
],
y2d
)
x2_test
=
xr_arange_like
(
x2
)
xr_assert_allclose
(
fn2a
(
x2_test
),
x2_test
.
squeeze
([
"b"
,
"c"
]))
xr_assert_allclose
(
fn2b
(
x2_test
),
x2_test
.
squeeze
([
"c"
,
"b"
]))
xr_assert_allclose
(
fn2c
(
x2_test
),
x2_test
.
squeeze
([
"b"
,
"b"
]))
xr_assert_allclose
(
fn2d
(
x2_test
),
x2_test
)
# Unknown shapes
x3
=
xtensor
(
"x3"
,
dims
=
(
"a"
,
"b"
,
"c"
))
# shape unknown
y3
=
x3
.
squeeze
(
"b"
)
x3_test
=
xr_arange_like
(
xtensor
(
dims
=
x3
.
dims
,
shape
=
(
2
,
1
,
3
)))
fn3
=
xr_function
([
x3
],
y3
)
xr_assert_allclose
(
fn3
(
x3_test
),
x3_test
.
squeeze
(
"b"
))
# Mixed known + unknown shapes
x4
=
xtensor
(
"x4"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
None
,
1
,
3
))
y4
=
x4
.
squeeze
(
"b"
)
x4_test
=
xr_arange_like
(
xtensor
(
dims
=
x4
.
dims
,
shape
=
(
4
,
1
,
3
)))
fn4
=
xr_function
([
x4
],
y4
)
xr_assert_allclose
(
fn4
(
x4_test
),
x4_test
.
squeeze
(
"b"
))
# Test axis parameter
x5
=
xtensor
(
"x5"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
2
,
1
,
3
))
y5
=
x5
.
squeeze
(
axis
=
1
)
# squeeze dimension at index 1 (b)
fn5
=
xr_function
([
x5
],
y5
)
x5_test
=
xr_arange_like
(
x5
)
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=
1
))
# Test axis parameter with negative index
y5
=
x5
.
squeeze
(
axis
=-
1
)
# squeeze dimension at index -2 (b)
fn5
=
xr_function
([
x5
],
y5
)
x5_test
=
xr_arange_like
(
x5
)
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=-
2
))
# Test axis parameter with sequence of ints
y6
=
x2
.
squeeze
(
axis
=
[
1
,
2
])
fn6
=
xr_function
([
x2
],
y6
)
x2_test
=
xr_arange_like
(
x2
)
xr_assert_allclose
(
fn6
(
x2_test
),
x2_test
.
squeeze
(
axis
=
[
1
,
2
]))
# Test drop parameter warning
x7
=
xtensor
(
"x7"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
1
))
with
pytest
.
warns
(
UserWarning
,
match
=
"drop parameter has no effect in pytensor.xtensor"
):
y7
=
x7
.
squeeze
(
"b"
,
drop
=
True
)
# squeeze and drop coordinate
fn7
=
xr_function
([
x7
],
y7
)
x7_test
=
xr_arange_like
(
x7
)
xr_assert_allclose
(
fn7
(
x7_test
),
x7_test
.
squeeze
(
"b"
,
drop
=
True
))
def
test_squeeze_errors
():
"""Test error cases for squeeze."""
# Non-existent dimension
x1
=
xtensor
(
"x1"
,
dims
=
(
"city"
,
"country"
),
shape
=
(
3
,
1
))
with
pytest
.
raises
(
ValueError
,
match
=
"Dimension .* not found"
):
x1
.
squeeze
(
"time"
)
# Dimension size > 1
with
pytest
.
raises
(
ValueError
,
match
=
"has static size .* not 1"
):
x1
.
squeeze
(
"city"
)
# Symbolic shape: dim is not 1 at runtime → should raise
x2
=
xtensor
(
"x2"
,
dims
=
(
"a"
,
"b"
,
"c"
))
# shape unknown
y2
=
x2
.
squeeze
(
"b"
)
x2_test
=
xr_arange_like
(
xtensor
(
dims
=
x2
.
dims
,
shape
=
(
2
,
2
,
3
)))
fn2
=
xr_function
([
x2
],
y2
)
with
pytest
.
raises
(
Exception
):
fn2
(
x2_test
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论