Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9ede7f6a
提交
9ede7f6a
authored
6月 13, 2025
作者:
Allen Downey
提交者:
Ricardo Vieira
6月 21, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement expand_dims for XTensorVariables (#1449)
上级
9716b3f2
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
307 行增加
和
3 行删除
+307
-3
shape.py
pytensor/xtensor/rewriting/shape.py
+33
-1
shape.py
pytensor/xtensor/shape.py
+122
-1
type.py
pytensor/xtensor/type.py
+41
-0
test_shape.py
tests/xtensor/test_shape.py
+111
-1
没有找到文件。
pytensor/xtensor/rewriting/shape.py
浏览文件 @
9ede7f6a
from
pytensor.graph
import
node_rewriter
from
pytensor.tensor
import
(
broadcast_to
,
expand_dims
,
join
,
moveaxis
,
specify_shape
,
...
...
@@ -10,6 +11,7 @@ from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from
pytensor.xtensor.rewriting.basic
import
register_lower_xtensor
from
pytensor.xtensor.shape
import
(
Concat
,
ExpandDims
,
Squeeze
,
Stack
,
Transpose
,
...
...
@@ -121,7 +123,7 @@ def lower_transpose(fgraph, node):
@register_lower_xtensor
@node_rewriter
([
Squeeze
])
def
lo
cal_squeeze_reshap
e
(
fgraph
,
node
):
def
lo
wer_squeez
e
(
fgraph
,
node
):
"""Rewrite Squeeze to tensor.squeeze."""
[
x
]
=
node
.
inputs
x_tensor
=
tensor_from_xtensor
(
x
)
...
...
@@ -132,3 +134,33 @@ def local_squeeze_reshape(fgraph, node):
new_out
=
xtensor_from_tensor
(
x_tensor_squeezed
,
dims
=
node
.
outputs
[
0
]
.
type
.
dims
)
return
[
new_out
]
@register_lower_xtensor
@node_rewriter
([
ExpandDims
])
def
lower_expand_dims
(
fgraph
,
node
):
"""Rewrite ExpandDims using tensor operations."""
x
,
size
=
node
.
inputs
out
=
node
.
outputs
[
0
]
# Convert inputs to tensors
x_tensor
=
tensor_from_xtensor
(
x
)
size_tensor
=
tensor_from_xtensor
(
size
)
# Get the new dimension name and position
new_axis
=
0
# Always insert at front
# Use tensor operations
if
out
.
type
.
shape
[
0
]
==
1
:
# Simple case: just expand with size 1
result_tensor
=
expand_dims
(
x_tensor
,
new_axis
)
else
:
# Otherwise broadcast to the requested size
result_tensor
=
broadcast_to
(
x_tensor
,
(
size_tensor
,
*
x_tensor
.
shape
))
# Preserve static shape information
result_tensor
=
specify_shape
(
result_tensor
,
out
.
type
.
shape
)
# Convert result back to xtensor
result
=
xtensor_from_tensor
(
result_tensor
,
dims
=
out
.
type
.
dims
)
return
[
result
]
pytensor/xtensor/shape.py
浏览文件 @
9ede7f6a
import
typing
import
warnings
from
collections.abc
import
Sequence
from
collections.abc
import
Hashable
,
Sequence
from
types
import
EllipsisType
from
typing
import
Literal
import
numpy
as
np
from
pytensor.graph
import
Apply
from
pytensor.scalar
import
discrete_dtypes
,
upcast
from
pytensor.tensor
import
as_tensor
,
get_scalar_constant_value
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
...
...
@@ -381,3 +384,121 @@ def squeeze(x, dim=None, drop=False, axis=None):
return
x
# no-op if nothing to squeeze
return
Squeeze
(
dims
=
dims
)(
x
)
class
ExpandDims
(
XOp
):
"""Add a new dimension to an XTensorVariable."""
__props__
=
(
"dim"
,)
def
__init__
(
self
,
dim
):
if
not
isinstance
(
dim
,
str
):
raise
TypeError
(
f
"`dim` must be a string, got: {type(self.dim)}"
)
self
.
dim
=
dim
def
make_node
(
self
,
x
,
size
):
x
=
as_xtensor
(
x
)
if
self
.
dim
in
x
.
type
.
dims
:
raise
ValueError
(
f
"Dimension {self.dim} already exists in {x.type.dims}"
)
size
=
as_xtensor
(
size
,
dims
=
())
if
not
(
size
.
dtype
in
integer_dtypes
and
size
.
ndim
==
0
):
raise
ValueError
(
f
"size should be an integer scalar, got {size.type}"
)
try
:
static_size
=
int
(
get_scalar_constant_value
(
size
))
except
NotScalarConstantError
:
static_size
=
None
# If size is a constant, validate it
if
static_size
is
not
None
and
static_size
<
0
:
raise
ValueError
(
f
"size must be 0 or positive, got: {static_size}"
)
new_shape
=
(
static_size
,
*
x
.
type
.
shape
)
# Insert new dim at front
new_dims
=
(
self
.
dim
,
*
x
.
type
.
dims
)
out
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
new_shape
,
dims
=
new_dims
,
)
return
Apply
(
self
,
[
x
,
size
],
[
out
])
def
expand_dims
(
x
,
dim
=
None
,
create_index_for_new_dim
=
None
,
axis
=
None
,
**
dim_kwargs
):
"""Add one or more new dimensions to an XTensorVariable."""
x
=
as_xtensor
(
x
)
# Store original dimensions for axis handling
original_dims
=
x
.
type
.
dims
# Warn if create_index_for_new_dim is used (not supported)
if
create_index_for_new_dim
is
not
None
:
warnings
.
warn
(
"create_index_for_new_dim=False has no effect in pytensor.xtensor"
,
UserWarning
,
stacklevel
=
2
,
)
if
dim
is
None
:
dim
=
dim_kwargs
elif
dim_kwargs
:
raise
ValueError
(
"Cannot specify both `dim` and `**dim_kwargs`"
)
# Check that dim is Hashable or a sequence of Hashable or dict
if
not
isinstance
(
dim
,
Hashable
):
if
not
isinstance
(
dim
,
Sequence
|
dict
):
raise
TypeError
(
f
"unhashable type: {type(dim).__name__}"
)
if
not
all
(
isinstance
(
d
,
Hashable
)
for
d
in
dim
):
raise
TypeError
(
f
"unhashable type in {type(dim).__name__}"
)
# Normalize to a dimension-size mapping
if
isinstance
(
dim
,
str
):
dims_dict
=
{
dim
:
1
}
elif
isinstance
(
dim
,
Sequence
)
and
not
isinstance
(
dim
,
dict
):
dims_dict
=
{
d
:
1
for
d
in
dim
}
elif
isinstance
(
dim
,
dict
):
dims_dict
=
{}
for
name
,
val
in
dim
.
items
():
if
isinstance
(
val
,
str
):
raise
TypeError
(
f
"Dimension size cannot be a string: {val}"
)
if
isinstance
(
val
,
Sequence
|
np
.
ndarray
):
warnings
.
warn
(
"When a sequence is provided as a dimension size, only its length is used. "
"The actual values (which would be coordinates in xarray) are ignored."
,
UserWarning
,
stacklevel
=
2
,
)
dims_dict
[
name
]
=
len
(
val
)
else
:
# should be int or symbolic scalar
dims_dict
[
name
]
=
val
else
:
raise
TypeError
(
f
"Invalid type for `dim`: {type(dim)}"
)
# Insert each new dim at the front (reverse order preserves user intent)
for
name
,
size
in
reversed
(
dims_dict
.
items
()):
x
=
ExpandDims
(
dim
=
name
)(
x
,
size
)
# If axis is specified, transpose to put new dimensions in the right place
if
axis
is
not
None
:
# Wrap non-sequence axis in a list
if
not
isinstance
(
axis
,
Sequence
):
axis
=
[
axis
]
# require len(axis) == len(dims_dict)
if
len
(
axis
)
!=
len
(
dims_dict
):
raise
ValueError
(
"lengths of dim and axis should be identical."
)
# Insert new dimensions at their specified positions
target_dims
=
list
(
original_dims
)
for
name
,
pos
in
zip
(
dims_dict
,
axis
):
# Convert negative axis to positive position relative to current dims
if
pos
<
0
:
pos
=
len
(
target_dims
)
+
pos
+
1
target_dims
.
insert
(
pos
,
name
)
x
=
Transpose
(
dims
=
tuple
(
target_dims
))(
x
)
return
x
pytensor/xtensor/type.py
浏览文件 @
9ede7f6a
...
...
@@ -573,6 +573,47 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
"""
return
px
.
shape
.
squeeze
(
self
,
dim
,
drop
,
axis
)
def
expand_dims
(
self
,
dim
:
str
|
Sequence
[
str
]
|
dict
[
str
,
int
|
Sequence
]
|
None
=
None
,
create_index_for_new_dim
:
bool
=
True
,
axis
:
int
|
Sequence
[
int
]
|
None
=
None
,
**
dim_kwargs
,
):
"""Add one or more new dimensions to the tensor.
Parameters
----------
dim : str | Sequence[str] | dict[str, int | Sequence] | None
If str or sequence of str, new dimensions with size 1.
If dict, keys are dimension names and values are either:
- int: the new size
- sequence: coordinates (length determines size)
create_index_for_new_dim : bool, default: True
Currently ignored. Reserved for future coordinate support.
In xarray, when True (default), creates a coordinate index for the new dimension
with values from 0 to size-1. When False, no coordinate index is created.
axis : int | Sequence[int] | None, default: None
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
By default (None), new dimensions are inserted at the beginning (axis=0).
Symbolic axis is not supported yet.
Negative values count from the end.
**dim_kwargs : int | Sequence
Alternative to `dim` dict. Only used if `dim` is None.
Returns
-------
XTensorVariable
A tensor with additional dimensions inserted at the front.
"""
return
px
.
shape
.
expand_dims
(
self
,
dim
,
create_index_for_new_dim
=
create_index_for_new_dim
,
axis
=
axis
,
**
dim_kwargs
,
)
# ndarray methods
# https://docs.xarray.dev/en/latest/api.html#id7
def
clip
(
self
,
min
,
max
):
...
...
tests/xtensor/test_shape.py
浏览文件 @
9ede7f6a
...
...
@@ -8,10 +8,10 @@ import re
from
itertools
import
chain
,
combinations
import
numpy
as
np
import
pytest
from
xarray
import
DataArray
from
xarray
import
concat
as
xr_concat
from
pytensor.tensor
import
scalar
from
pytensor.xtensor.shape
import
(
concat
,
stack
,
...
...
@@ -356,3 +356,113 @@ def test_squeeze_errors():
fn2
=
xr_function
([
x2
],
y2
)
with
pytest
.
raises
(
Exception
):
fn2
(
x2_test
)
def
test_expand_dims
():
"""Test expand_dims."""
x
=
xtensor
(
"x"
,
dims
=
(
"city"
,
"year"
),
shape
=
(
2
,
2
))
x_test
=
xr_arange_like
(
x
)
# Implicit size 1
y
=
x
.
expand_dims
(
"country"
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
))
# Test with multiple dimensions
y
=
x
.
expand_dims
([
"country"
,
"state"
])
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
]))
# Test with a dict of name-size pairs
y
=
x
.
expand_dims
({
"country"
:
2
,
"state"
:
3
})
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
({
"country"
:
2
,
"state"
:
3
}))
# Test with kwargs (equivalent to dict)
y
=
x
.
expand_dims
(
country
=
2
,
state
=
3
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
country
=
2
,
state
=
3
))
# Test with a dict of name-coord array pairs
y
=
x
.
expand_dims
({
"country"
:
np
.
array
([
1
,
2
]),
"state"
:
np
.
array
([
3
,
4
,
5
])})
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
({
"country"
:
np
.
array
([
1
,
2
]),
"state"
:
np
.
array
([
3
,
4
,
5
])}),
)
# Symbolic size 1
size_sym_1
=
scalar
(
"size_sym_1"
,
dtype
=
"int64"
)
y
=
x
.
expand_dims
({
"country"
:
size_sym_1
})
fn
=
xr_function
([
x
,
size_sym_1
],
y
)
xr_assert_allclose
(
fn
(
x_test
,
1
),
x_test
.
expand_dims
({
"country"
:
1
}))
# Test with symbolic sizes in dict
size_sym_2
=
scalar
(
"size_sym_2"
,
dtype
=
"int64"
)
y
=
x
.
expand_dims
({
"country"
:
size_sym_1
,
"state"
:
size_sym_2
})
fn
=
xr_function
([
x
,
size_sym_1
,
size_sym_2
],
y
)
xr_assert_allclose
(
fn
(
x_test
,
2
,
3
),
x_test
.
expand_dims
({
"country"
:
2
,
"state"
:
3
}))
# Test with symbolic sizes in kwargs
y
=
x
.
expand_dims
(
country
=
size_sym_1
,
state
=
size_sym_2
)
fn
=
xr_function
([
x
,
size_sym_1
,
size_sym_2
],
y
)
xr_assert_allclose
(
fn
(
x_test
,
2
,
3
),
x_test
.
expand_dims
({
"country"
:
2
,
"state"
:
3
}))
# Test with axis parameter
y
=
x
.
expand_dims
(
"country"
,
axis
=
1
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
,
axis
=
1
))
# Test with negative axis parameter
y
=
x
.
expand_dims
(
"country"
,
axis
=-
1
)
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
(
"country"
,
axis
=-
1
))
# Add two new dims with axis parameters
y
=
x
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
1
,
2
])
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
1
,
2
])
)
# Add two dims with negative axis parameters
y
=
x
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
-
1
,
-
2
])
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
-
1
,
-
2
])
)
# Add two dims with positive and negative axis parameters
y
=
x
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
-
2
,
1
])
fn
=
xr_function
([
x
],
y
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
.
expand_dims
([
"country"
,
"state"
],
axis
=
[
-
2
,
1
])
)
def
test_expand_dims_errors
():
"""Test error handling in expand_dims."""
# Expanding existing dim
x
=
xtensor
(
"x"
,
dims
=
(
"city"
,),
shape
=
(
3
,))
y
=
x
.
expand_dims
(
"country"
)
with
pytest
.
raises
(
ValueError
,
match
=
"already exists"
):
y
.
expand_dims
(
"city"
)
# Invalid dim type
with
pytest
.
raises
(
TypeError
,
match
=
"Invalid type for `dim`"
):
x
.
expand_dims
(
123
)
# Duplicate dimension creation
y
=
x
.
expand_dims
(
"new"
)
with
pytest
.
raises
(
ValueError
,
match
=
"already exists"
):
y
.
expand_dims
(
"new"
)
# Find out what xarray does with a numpy array as dim
# x_test = xr_arange_like(x)
# x_test.expand_dims(np.array([1, 2]))
# TypeError: unhashable type: 'numpy.ndarray'
# Test with a numpy array as dim (not supported)
with
pytest
.
raises
(
TypeError
,
match
=
"unhashable type"
):
y
.
expand_dims
(
np
.
array
([
1
,
2
]))
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论