Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc4b77a0
提交
cc4b77a0
authored
1月 31, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
2月 05, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove useless arguments from non xarray API
Make sure we don't issue unexpected warnings
上级
8da7cd76
隐藏空白字符变更
内嵌
并排
正在显示
10 个修改的文件
包含
37 行增加
和
56 行删除
+37
-56
pyproject.toml
pyproject.toml
+1
-0
shape.py
pytensor/xtensor/shape.py
+4
-39
type.py
pytensor/xtensor/type.py
+13
-5
test_indexing.py
tests/xtensor/test_indexing.py
+2
-0
test_linalg.py
tests/xtensor/test_linalg.py
+1
-0
test_math.py
tests/xtensor/test_math.py
+1
-0
test_random.py
tests/xtensor/test_random.py
+6
-1
test_reduction.py
tests/xtensor/test_reduction.py
+1
-0
test_shape.py
tests/xtensor/test_shape.py
+7
-11
test_type.py
tests/xtensor/test_type.py
+1
-0
没有找到文件。
pyproject.toml
浏览文件 @
cc4b77a0
...
@@ -163,6 +163,7 @@ lines-after-imports = 2
...
@@ -163,6 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py"
=
["E402"]
"tests/link/numba/**/test_*.py"
=
["E402"]
"tests/link/pytorch/**/test_*.py"
=
["E402"]
"tests/link/pytorch/**/test_*.py"
=
["E402"]
"tests/link/mlx/**/test_*.py"
=
["E402"]
"tests/link/mlx/**/test_*.py"
=
["E402"]
"tests/xtensor/**/test_*.py"
=
["E402"]
...
...
pytensor/xtensor/shape.py
浏览文件 @
cc4b77a0
import
typing
import
typing
import
warnings
import
warnings
from
collections.abc
import
Hashable
,
Sequence
from
collections.abc
import
Sequence
from
types
import
EllipsisType
from
types
import
EllipsisType
from
typing
import
Literal
from
typing
import
Literal
...
@@ -384,28 +384,10 @@ class Squeeze(XOp):
...
@@ -384,28 +384,10 @@ class Squeeze(XOp):
return
Apply
(
self
,
[
x
],
[
out
])
return
Apply
(
self
,
[
x
],
[
out
])
def
squeeze
(
x
,
dim
=
None
,
drop
=
False
,
axis
=
None
):
def
squeeze
(
x
,
dim
:
str
|
Sequence
[
str
]
|
None
=
None
):
"""Remove dimensions of size 1 from an XTensorVariable."""
"""Remove dimensions of size 1 from an XTensorVariable."""
x
=
as_xtensor
(
x
)
x
=
as_xtensor
(
x
)
# drop parameter is ignored in pytensor.xtensor
if
drop
is
not
None
:
warnings
.
warn
(
"drop parameter has no effect in pytensor.xtensor"
,
UserWarning
)
# dim and axis are mutually exclusive
if
dim
is
not
None
and
axis
is
not
None
:
raise
ValueError
(
"Cannot specify both `dim` and `axis`"
)
# if axis is specified, it must be a sequence of ints
if
axis
is
not
None
:
if
not
isinstance
(
axis
,
Sequence
):
axis
=
[
axis
]
if
not
all
(
isinstance
(
a
,
int
)
for
a
in
axis
):
raise
ValueError
(
"axis must be an integer or a sequence of integers"
)
# convert axis to dims
dims
=
tuple
(
x
.
type
.
dims
[
i
]
for
i
in
axis
)
# if dim is specified, it must be a string or a sequence of strings
# if dim is specified, it must be a string or a sequence of strings
if
dim
is
None
:
if
dim
is
None
:
dims
=
tuple
(
d
for
d
,
s
in
zip
(
x
.
type
.
dims
,
x
.
type
.
shape
)
if
s
==
1
)
dims
=
tuple
(
d
for
d
,
s
in
zip
(
x
.
type
.
dims
,
x
.
type
.
shape
)
if
s
==
1
)
...
@@ -461,33 +443,18 @@ class ExpandDims(XOp):
...
@@ -461,33 +443,18 @@ class ExpandDims(XOp):
return
Apply
(
self
,
[
x
,
size
],
[
out
])
return
Apply
(
self
,
[
x
,
size
],
[
out
])
def
expand_dims
(
x
,
dim
=
None
,
create_index_for_new_dim
=
None
,
axis
=
None
,
**
dim_kwargs
):
def
expand_dims
(
x
,
dim
=
None
,
axis
=
None
,
**
dim_kwargs
):
"""Add one or more new dimensions to an XTensorVariable."""
"""Add one or more new dimensions to an XTensorVariable."""
x
=
as_xtensor
(
x
)
x
=
as_xtensor
(
x
)
# Store original dimensions for axis handling
# Store original dimensions for axis handling
original_dims
=
x
.
type
.
dims
original_dims
=
x
.
type
.
dims
# Warn if create_index_for_new_dim is used (not supported)
if
create_index_for_new_dim
is
not
None
:
warnings
.
warn
(
"create_index_for_new_dim=False has no effect in pytensor.xtensor"
,
UserWarning
,
stacklevel
=
2
,
)
if
dim
is
None
:
if
dim
is
None
:
dim
=
dim_kwargs
dim
=
dim_kwargs
elif
dim_kwargs
:
elif
dim_kwargs
:
raise
ValueError
(
"Cannot specify both `dim` and `**dim_kwargs`"
)
raise
ValueError
(
"Cannot specify both `dim` and `**dim_kwargs`"
)
# Check that dim is Hashable or a sequence of Hashable or dict
if
not
isinstance
(
dim
,
Hashable
):
if
not
isinstance
(
dim
,
Sequence
|
dict
):
raise
TypeError
(
f
"unhashable type: {type(dim).__name__}"
)
if
not
all
(
isinstance
(
d
,
Hashable
)
for
d
in
dim
):
raise
TypeError
(
f
"unhashable type in {type(dim).__name__}"
)
# Normalize to a dimension-size mapping
# Normalize to a dimension-size mapping
if
isinstance
(
dim
,
str
):
if
isinstance
(
dim
,
str
):
dims_dict
=
{
dim
:
1
}
dims_dict
=
{
dim
:
1
}
...
@@ -496,9 +463,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
...
@@ -496,9 +463,7 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
elif
isinstance
(
dim
,
dict
):
elif
isinstance
(
dim
,
dict
):
dims_dict
=
{}
dims_dict
=
{}
for
name
,
val
in
dim
.
items
():
for
name
,
val
in
dim
.
items
():
if
isinstance
(
val
,
str
):
if
isinstance
(
val
,
list
|
tuple
|
np
.
ndarray
):
raise
TypeError
(
f
"Dimension size cannot be a string: {val}"
)
if
isinstance
(
val
,
Sequence
|
np
.
ndarray
):
warnings
.
warn
(
warnings
.
warn
(
"When a sequence is provided as a dimension size, only its length is used. "
"When a sequence is provided as a dimension size, only its length is used. "
"The actual values (which would be coordinates in xarray) are ignored."
,
"The actual values (which would be coordinates in xarray) are ignored."
,
...
...
pytensor/xtensor/type.py
浏览文件 @
cc4b77a0
...
@@ -687,7 +687,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -687,7 +687,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
The name(s) of the dimension(s) to remove. If None, all dimensions of size 1
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
(known statically) will be removed. Dimensions with unknown static shape will be retained, even if they have size 1 at runtime.
drop : bool, optional
drop : bool, optional
I
f drop=True, drop squeezed coordinates instead of making them scala
r.
I
gnored by PyTenso
r.
axis : int or iterable of int, optional
axis : int or iterable of int, optional
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
The axis(es) to remove. If None, all dimensions of size 1 will be removed.
Returns
Returns
...
@@ -695,12 +695,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -695,12 +695,21 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
XTensorVariable
XTensorVariable
A new tensor with the specified dimension(s) removed.
A new tensor with the specified dimension(s) removed.
"""
"""
return
px
.
shape
.
squeeze
(
self
,
dim
,
drop
,
axis
)
if
axis
is
not
None
:
if
dim
is
not
None
:
raise
ValueError
(
"Cannot specify both `dim` and `axis`"
)
if
not
isinstance
(
axis
,
Sequence
):
axis
=
(
axis
,)
dim
=
tuple
(
self
.
type
.
dims
[
i
]
for
i
in
axis
)
return
px
.
shape
.
squeeze
(
self
,
dim
)
def
expand_dims
(
def
expand_dims
(
self
,
self
,
dim
:
str
|
Sequence
[
str
]
|
dict
[
str
,
int
|
Sequence
]
|
None
=
None
,
dim
:
str
|
Sequence
[
str
]
|
dict
[
str
,
int
|
Sequence
]
|
None
=
None
,
create_index_for_new_dim
:
bool
=
Tru
e
,
create_index_for_new_dim
:
bool
|
None
=
Non
e
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
**
dim_kwargs
,
**
dim_kwargs
,
):
):
...
@@ -714,7 +723,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -714,7 +723,7 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
- int: the new size
- int: the new size
- sequence: coordinates (length determines size)
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool,
default: True
create_index_for_new_dim : bool,
optional
Ignored by PyTensor
Ignored by PyTensor
axis : int | Sequence[int] | None, default: None
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
...
@@ -730,7 +739,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
...
@@ -730,7 +739,6 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
return
px
.
shape
.
expand_dims
(
return
px
.
shape
.
expand_dims
(
self
,
self
,
dim
,
dim
,
create_index_for_new_dim
=
create_index_for_new_dim
,
axis
=
axis
,
axis
=
axis
,
**
dim_kwargs
,
**
dim_kwargs
,
)
)
...
...
tests/xtensor/test_indexing.py
浏览文件 @
cc4b77a0
...
@@ -2,6 +2,8 @@ import pytest
...
@@ -2,6 +2,8 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
re
import
re
...
...
tests/xtensor/test_linalg.py
浏览文件 @
cc4b77a0
...
@@ -3,6 +3,7 @@ import pytest
...
@@ -3,6 +3,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray_einstats"
)
pytest
.
importorskip
(
"xarray_einstats"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
numpy
as
np
import
numpy
as
np
from
xarray
import
DataArray
from
xarray
import
DataArray
...
...
tests/xtensor/test_math.py
浏览文件 @
cc4b77a0
...
@@ -2,6 +2,7 @@ import pytest
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
inspect
import
inspect
...
...
tests/xtensor/test_random.py
浏览文件 @
cc4b77a0
import
pytest
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
inspect
import
inspect
import
re
import
re
from
copy
import
deepcopy
from
copy
import
deepcopy
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytensor.tensor.random
as
ptr
import
pytensor.tensor.random
as
ptr
import
pytensor.xtensor.random
as
pxr
import
pytensor.xtensor.random
as
pxr
...
...
tests/xtensor/test_reduction.py
浏览文件 @
cc4b77a0
...
@@ -2,6 +2,7 @@ import pytest
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
from
pytensor.xtensor.type
import
xtensor
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
xr_arange_like
,
xr_assert_allclose
,
xr_function
from
tests.xtensor.util
import
xr_arange_like
,
xr_assert_allclose
,
xr_function
...
...
tests/xtensor/test_shape.py
浏览文件 @
cc4b77a0
...
@@ -2,6 +2,7 @@ import pytest
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
re
import
re
from
itertools
import
chain
,
combinations
from
itertools
import
chain
,
combinations
...
@@ -33,9 +34,6 @@ from tests.xtensor.util import (
...
@@ -33,9 +34,6 @@ from tests.xtensor.util import (
)
)
pytest
.
importorskip
(
"xarray"
)
def
powerset
(
iterable
,
min_group_size
=
0
):
def
powerset
(
iterable
,
min_group_size
=
0
):
"Subsequences of the iterable from shortest to longest."
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
...
@@ -322,7 +320,7 @@ def test_squeeze():
...
@@ -322,7 +320,7 @@ def test_squeeze():
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=
1
))
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=
1
))
# Test axis parameter with negative index
# Test axis parameter with negative index
y5
=
x5
.
squeeze
(
axis
=-
1
)
# squeeze dimension at index -2 (b)
y5
=
x5
.
squeeze
(
axis
=-
2
)
# squeeze dimension at index -2 (b)
fn5
=
xr_function
([
x5
],
y5
)
fn5
=
xr_function
([
x5
],
y5
)
x5_test
=
xr_arange_like
(
x5
)
x5_test
=
xr_arange_like
(
x5
)
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=-
2
))
xr_assert_allclose
(
fn5
(
x5_test
),
x5_test
.
squeeze
(
axis
=-
2
))
...
@@ -333,12 +331,9 @@ def test_squeeze():
...
@@ -333,12 +331,9 @@ def test_squeeze():
x2_test
=
xr_arange_like
(
x2
)
x2_test
=
xr_arange_like
(
x2
)
xr_assert_allclose
(
fn6
(
x2_test
),
x2_test
.
squeeze
(
axis
=
[
1
,
2
]))
xr_assert_allclose
(
fn6
(
x2_test
),
x2_test
.
squeeze
(
axis
=
[
1
,
2
]))
# Test drop parameter
warning
# Test drop parameter
ignored, but accepted
x7
=
xtensor
(
"x7"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
1
))
x7
=
xtensor
(
"x7"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
1
))
with
pytest
.
warns
(
y7
=
x7
.
squeeze
(
"b"
,
drop
=
True
)
UserWarning
,
match
=
"drop parameter has no effect in pytensor.xtensor"
):
y7
=
x7
.
squeeze
(
"b"
,
drop
=
True
)
# squeeze and drop coordinate
fn7
=
xr_function
([
x7
],
y7
)
fn7
=
xr_function
([
x7
],
y7
)
x7_test
=
xr_arange_like
(
x7
)
x7_test
=
xr_arange_like
(
x7
)
xr_assert_allclose
(
fn7
(
x7_test
),
x7_test
.
squeeze
(
"b"
,
drop
=
True
))
xr_assert_allclose
(
fn7
(
x7_test
),
x7_test
.
squeeze
(
"b"
,
drop
=
True
))
...
@@ -391,7 +386,8 @@ def test_expand_dims():
...
@@ -391,7 +386,8 @@ def test_expand_dims():
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
country
=
2
,
state
=
3
))
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
country
=
2
,
state
=
3
))
# Test with a dict of name-coord array pairs
# Test with a dict of name-coord array pairs
y
=
x
.
expand_dims
({
"country"
:
np
.
array
([
1
,
2
]),
"state"
:
np
.
array
([
3
,
4
,
5
])})
with
pytest
.
warns
(
UserWarning
,
match
=
"only its length is used"
):
y
=
x
.
expand_dims
({
"country"
:
np
.
array
([
1
,
2
]),
"state"
:
np
.
array
([
3
,
4
,
5
])})
fn
=
xr_function
([
x
],
y
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
xr_assert_allclose
(
fn
(
x_test
),
fn
(
x_test
),
...
@@ -471,7 +467,7 @@ def test_expand_dims_errors():
...
@@ -471,7 +467,7 @@ def test_expand_dims_errors():
# TypeError: unhashable type: 'numpy.ndarray'
# TypeError: unhashable type: 'numpy.ndarray'
# Test with a numpy array as dim (not supported)
# Test with a numpy array as dim (not supported)
with
pytest
.
raises
(
TypeError
,
match
=
"
unhashable type
"
):
with
pytest
.
raises
(
TypeError
,
match
=
"
Invalid type for `dim`
"
):
y
.
expand_dims
(
np
.
array
([
1
,
2
]))
y
.
expand_dims
(
np
.
array
([
1
,
2
]))
...
...
tests/xtensor/test_type.py
浏览文件 @
cc4b77a0
...
@@ -2,6 +2,7 @@ import pytest
...
@@ -2,6 +2,7 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
pytest
.
importorskip
(
"xarray"
)
pytestmark
=
pytest
.
mark
.
filterwarnings
(
"error"
)
import
numpy
as
np
import
numpy
as
np
from
xarray
import
DataArray
from
xarray
import
DataArray
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论