Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc4b77a0
提交
cc4b77a0
authored
1月 31, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
2月 05, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove useless arguments from non xarray API
Make sure we don't issue unexpected warnings
上级
8da7cd76
显示空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
36 行增加
和
55 行删除
+36
-55
pyproject.toml
pyproject.toml
+1
-0
shape.py
pytensor/xtensor/shape.py
+4
-39
type.py
pytensor/xtensor/type.py
+13
-5
test_indexing.py
tests/xtensor/test_indexing.py
+2
-0
test_linalg.py
tests/xtensor/test_linalg.py
+1
-0
test_math.py
tests/xtensor/test_math.py
+1
-0
test_random.py
tests/xtensor/test_random.py
+6
-1
test_reduction.py
tests/xtensor/test_reduction.py
+1
-0
test_shape.py
tests/xtensor/test_shape.py
+6
-10
test_type.py
tests/xtensor/test_type.py
+1
-0
没有找到文件。
pyproject.toml
浏览文件 @
cc4b77a0
...
...
@@ -163,6 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py"
=
["E402"]
"tests/link/pytorch/**/test_*.py"
=
["E402"]
"tests/link/mlx/**/test_*.py"
=
["E402"]
"tests/xtensor/**/test_*.py"
=
["E402"]
...
...
pytensor/xtensor/shape.py
浏览文件 @
cc4b77a0
import
typing
import
warnings
from
collections.abc
import
Hashable
,
Sequence
from
collections.abc
import
Sequence
from
types
import
EllipsisType
from
typing
import
Literal
...
...
@@ -384,28 +384,10 @@ class Squeeze(XOp):
return
Apply
(
self
,
[
x
],
[
out
])
def
squeeze
(
x
,
dim
=
None
,
drop
=
False
,
axis
=
None
):
def
squeeze
(
x
,
dim
:
str
|
Sequence
[
str
]
|
None
=
None
):
"""Remove dimensions of size 1 from an XTensorVariable."""
x
=
as_xtensor
(
x
)
# drop parameter is ignored in pytensor.xtensor
if
drop
is
not
None
:
warnings
.
warn
(
"drop parameter has no effect in pytensor.xtensor"
,
UserWarning
)
# dim and axis are mutually exclusive
if
dim
is
not
None
and
axis
is
not
None
:
raise
ValueError
(
"Cannot specify both `dim` and `axis`"
)
# if axis is specified, it must be a sequence of ints
if
axis
is
not
None
:
if
not
isinstance
(
axis
,
Sequence
):
axis
=
[
axis
]
if
not
all
(
isinstance
(
a
,
int
)
for
a
in
axis
):
raise
ValueError
(
"axis must be an integer or a sequence of integers"
)
# convert axis to dims
dims
=
tuple
(
x
.
type
.
dims
[
i
]
for
i
in
axis
)
# if dim is specified, it must be a string or a sequence of strings
if
dim
is
None
:
dims
=
tuple
(
d
for
d
,
s
in
zip
(
x
.
type
.
dims
,
x
.
type
.
shape
)
if
s
==
1
)
...
...
@@ -461,33 +443,18 @@ class ExpandDims(XOp):
return
Apply
(
self
,
[
x
,
size
],
[
out
])
def
expand_dims
(
x
,
dim
=
None
,
create_index_for_new_dim
=
None
,
axis
=
None
,
**
dim_kwargs
):
def
expand_dims
(
x
,
dim
=
None
,
axis
=
None
,
**
dim_kwargs
):
"""Add one or more new dimensions to an XTensorVariable."""
x
=
as_xtensor
(
x
)
# Store original dimensions for axis handling
original_dims
=
x
.
type
.
dims
# Warn if create_index_for_new_dim is used (not supported)
if
create_index_for_new_dim
is
not
None
:
warnings
.
warn
(
"create_index_for_new_dim=False has no effect in pytensor.xtensor"
,
UserWarning
,
stacklevel
=
2
,
)
if
dim
is
None
:
dim
=
dim_kwargs
elif
dim_kwargs
:
raise
ValueError
(
"Cannot specify both `dim` and `**dim_kwargs`"
)
# Check that dim is Hashable or a sequence of Hashable or dict
if
not
isinstance
(
dim
,
Hashable
):
if
not
isinstance
(
dim
,
Sequence
|
dict
):
raise
TypeError
(
f
"unhashable type: {type(dim).__name__}"
)
if
not
all
(
isinstance
(
d
,
Hashable
)
for
d
in
dim
):
raise
TypeError
(
f
"unhashable type in {type(dim).__name__}"
)
# Normalize to a dimension-size mapping
if
isinstance
(
dim
,
str
):
dims_dict
=
{
dim
:
1
}
...
...
@@ -496,9 +463,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
elif
isinstance
(
dim
,
dict
):
dims_dict
=
{}
for
name
,
val
in
dim
.
items
():
if
isinstance
(
val
,
str
):
raise
TypeError
(
f
"Dimension size cannot be a string: {val}"
)
if
isinstance
(
val
,
Sequence
|
np
.
ndarray
):
if
isinstance
(
val
,
list
|
tuple
|
np
.
ndarray
):
warnings
.
warn
(
"When a sequence is provided as a dimension size, only its length is used. "
"The actual values (which would be coordinates in xarray) are ignored."
,
...
...
pytensor/xtensor/type.py
浏览文件 @
cc4b77a0
...
...
@@ -687,7 +687,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional
I
f drop=True, drop squeezed coordinates instead of making them scala
r.
I
gnored by PyTenso
r.
axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns
...
...
@@ -695,12 +695,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
XTensorVariable
A new tensor with the specified dimension(s) removed.
"""
return
px
.
shape
.
squeeze
(
self
,
dim
,
drop
,
axis
)
if
axis
is
not
None
:
if
dim
is
not
None
:
raise
ValueError
(
"Cannot specify both `dim` and `axis`"
)
if
not
isinstance
(
axis
,
Sequence
):
axis
=
(
axis
,)
dim
=
tuple
(
self
.
type
.
dims
[
i
]
for
i
in
axis
)
return
px
.
shape
.
squeeze
(
self
,
dim
)
def
expand_dims
(
self
,
dim
:
str
|
Sequence
[
str
]
|
dict
[
str
,
int
|
Sequence
]
|
None
=
None
,
create_index_for_new_dim
:
bool
=
Tru
e
,
create_index_for_new_dim
:
bool
|
None
=
Non
e
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
**
dim_kwargs
,
):
...
...
@@ -714,7 +723,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool,
default: True
create_index_for_new_dim : bool,
optional
Ignored by PyTensor
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
...
...
@@ -730,7 +739,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
shape
.
expand_dims
(
self
,
dim
,
create_index_for_new_dim
=
create_index_for_new_dim
,
axis
=
axis
,
**
dim_kwargs
,
)
...
...
tests/xtensor/test_indexing.py
浏览文件 @
cc4b77a0
...
...
@@ -2,6 +2,8 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
re
...
...
tests/xtensor/test_linalg.py
浏览文件 @
cc4b77a0
...
...
@@ -3,6 +3,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray_einstats"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
numpy
as
np
from
xarray
import
DataArray
...
...
tests/xtensor/test_math.py
浏览文件 @
cc4b77a0
...
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
inspect
...
...
tests/xtensor/test_random.py
浏览文件 @
cc4b77a0
import
pytest
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
inspect
import
re
from
copy
import
deepcopy
import
numpy
as
np
import
pytest
import
pytensor.tensor.random
as
ptr
import
pytensor.xtensor.random
as
pxr
...
...
tests/xtensor/test_reduction.py
浏览文件 @
cc4b77a0
...
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
xr_arange_like
,
xr_assert_allclose
,
xr_function
...
...
tests/xtensor/test_shape.py
浏览文件 @
cc4b77a0
...
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
re
from
itertools
import
chain
,
combinations
...
...
@@ -33,9 +34,6 @@ from tests.xtensor.util import (
)
pytest
.
importorskip
(
"xarray"
)
def
powerset
(
iterable
,
min_group_size
=
0
):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
...
...
@@ -322,7 +320,7 @@ def test_squeeze():
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=
1
))
# Test axis parameter with negative index
y5
=
x5
.
squeeze
(
axis
=-
1
)
# squeeze dimension at index -2 (b)
y5
=
x5
.
squeeze
(
axis
=-
2
)
# squeeze dimension at index -2 (b)
fn5
=
xr_function
([
x5
],
y5
)
x5_test
=
xr_arange_like
(
x5
)
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=-
2
))
...
...
@@ -333,12 +331,9 @@ def test_squeeze():
x2_test
=
xr_arange_like
(
x2
)
xr_assert_allclose
(
fn6
(
x2_test
),
x2_test
.
squeeze
(
axis
=
[
1
,
2
]))
# Test drop parameter
warning
# Test drop parameter
ignored, but accepted
x7
=
xtensor
(
"x7"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
1
))
with
pytest
.
warns
(
UserWarning
,
match
=
"drop parameter has no effect in pytensor.xtensor"
):
y7
=
x7
.
squeeze
(
"b"
,
drop
=
True
)
# squeeze and drop coordinate
y7
=
x7
.
squeeze
(
"b"
,
drop
=
True
)
fn7
=
xr_function
([
x7
],
y7
)
x7_test
=
xr_arange_like
(
x7
)
xr_assert_allclose
(
fn7
(
x7_test
),
x7_test
.
squeeze
(
"b"
,
drop
=
True
))
...
...
@@ -391,6 +386,7 @@ def test_expand_dims():
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
country
=
2
,
state
=
3
))
# Test with a dict of name-coord array pairs
with
pytest
.
warns
(
UserWarning
,
match
=
"only its length is used"
):
y
=
x
.
expand_dims
({
"country"
:
np
.
array
([
1
,
2
]),
"state"
:
np
.
array
([
3
,
4
,
5
])})
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
...
...
@@ -471,7 +467,7 @@ def test_expand_dims_errors():
# TypeError: unhashable type: 'numpy.ndarray'
# Test with a numpy array as dim (not supported)
with
pytest
.
raises
(
TypeError
,
match
=
"
unhashable type
"
):
with
pytest
.
raises
(
TypeError
,
match
=
"
Invalid type for `dim`
"
):
y
.
expand_dims
(
np
.
array
([
1
,
2
]))
...
...
tests/xtensor/test_type.py
浏览文件 @
cc4b77a0
...
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
numpy
as
np
from
xarray
import
DataArray
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论