Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e9219159
提交
e9219159
authored
6月 29, 2025
作者:
Allen Downey
提交者:
Ricardo Vieira
7月 02, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement broadcast for XTensorVariables
Co-authored-by:
Ricardo
<
ricardo.vieira1994@gmail.com
>
上级
e1ce1c35
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
311 行增加
和
4 行删除
+311
-4
__init__.py
pytensor/xtensor/__init__.py
+1
-1
shape.py
pytensor/xtensor/rewriting/shape.py
+60
-0
shape.py
pytensor/xtensor/shape.py
+62
-1
type.py
pytensor/xtensor/type.py
+9
-0
vectorization.py
pytensor/xtensor/vectorization.py
+12
-2
test_shape.py
tests/xtensor/test_shape.py
+167
-0
没有找到文件。
pytensor/xtensor/__init__.py
浏览文件 @
e9219159
...
...
@@ -3,7 +3,7 @@ import warnings
import
pytensor.xtensor.rewriting
from
pytensor.xtensor
import
linalg
,
random
from
pytensor.xtensor.math
import
dot
from
pytensor.xtensor.shape
import
concat
from
pytensor.xtensor.shape
import
broadcast
,
concat
from
pytensor.xtensor.type
import
(
as_xtensor
,
xtensor
,
...
...
pytensor/xtensor/rewriting/shape.py
浏览文件 @
e9219159
import
pytensor.tensor
as
pt
from
pytensor.graph
import
node_rewriter
from
pytensor.tensor
import
(
broadcast_to
,
...
...
@@ -11,6 +12,7 @@ from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from
pytensor.xtensor.rewriting.basic
import
register_lower_xtensor
from
pytensor.xtensor.rewriting.utils
import
lower_aligned
from
pytensor.xtensor.shape
import
(
Broadcast
,
Concat
,
ExpandDims
,
Squeeze
,
...
...
@@ -157,3 +159,61 @@ def lower_expand_dims(fgraph, node):
# Convert result back to xtensor
result
=
xtensor_from_tensor
(
result_tensor
,
dims
=
out
.
type
.
dims
)
return
[
result
]
@register_lower_xtensor
@node_rewriter(tracks=[Broadcast])
def lower_broadcast(fgraph, node):
    """Rewrite XBroadcast using tensor operations.

    Each xtensor input is first aligned to the dim order of its matching
    output, then broadcast at the tensor level. When no dims are excluded
    this is a plain ``pt.broadcast_arrays``; otherwise each input is
    broadcast only over the shared (non-excluded) leading dims.

    Parameters
    ----------
    fgraph : FunctionGraph
        The graph being rewritten (required by the rewriter protocol; unused here).
    node : Apply
        An apply node of a ``Broadcast`` op.

    Returns
    -------
    list
        One lowered xtensor output per input, with dims taken from the
        corresponding output type.
    """
    excluded_dims = node.op.exclude

    tensor_inputs = [
        lower_aligned(inp, out.type.dims)
        for inp, out in zip(node.inputs, node.outputs, strict=True)
    ]

    if not excluded_dims:
        # Simple case: All dimensions are broadcasted
        tensor_outputs = pt.broadcast_arrays(*tensor_inputs)
    else:
        # Complex case: Some dimensions are excluded from broadcasting
        # Pick the first dimension_length for each dim
        broadcast_dims = {
            d: None for d in node.outputs[0].type.dims if d not in excluded_dims
        }
        for xtensor_inp in node.inputs:
            for dim, dim_length in xtensor_inp.sizes.items():
                if dim in broadcast_dims and broadcast_dims[dim] is None:
                    # If the dimension is not excluded, set its shape
                    broadcast_dims[dim] = dim_length
        assert not any(value is None for value in broadcast_dims.values()), (
            "All dimensions must have a length"
        )

        # Create zeros with the broadcast dimensions, to then broadcast each input against
        # PyTensor will rewrite into using only the shapes of the zeros tensor
        broadcast_dims = pt.zeros(
            tuple(broadcast_dims.values()),
            dtype=node.outputs[0].type.dtype,
        )
        n_broadcast_dims = broadcast_dims.ndim

        tensor_outputs = []
        for tensor_inp, xtensor_out in zip(tensor_inputs, node.outputs, strict=True):
            n_excluded_dims = tensor_inp.type.ndim - n_broadcast_dims
            # Excluded dimensions are on the right side of the output tensor so we padright the broadcast_dims
            # second is equivalent to `np.broadcast_arrays(x, y)[1]` in PyTensor
            tensor_outputs.append(
                pt.second(
                    pt.shape_padright(broadcast_dims, n_excluded_dims),
                    tensor_inp,
                )
            )

    # strict=True for consistency with the zips above: outputs and inputs
    # are created pairwise by Broadcast.make_node, so lengths must match.
    new_outs = [
        xtensor_from_tensor(out_tensor, dims=out.type.dims)
        for out_tensor, out in zip(tensor_outputs, node.outputs, strict=True)
    ]
    return new_outs
pytensor/xtensor/shape.py
浏览文件 @
e9219159
...
...
@@ -13,7 +13,8 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.utils
import
get_static_shape_from_size_variables
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.type
import
XTensorVariable
,
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
combine_dims_and_shape
class
Stack
(
XOp
):
...
...
@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
x
=
Transpose
(
dims
=
tuple
(
target_dims
))(
x
)
return
x
class Broadcast(XOp):
    """Broadcast multiple XTensorVariables against each other.

    One output is produced per input. Every output carries the shared
    broadcast dims first, followed by that input's excluded dims (in the
    order they appear in the ``exclude`` argument), and all outputs are
    upcast to a common dtype.
    """

    __props__ = ("exclude",)

    def __init__(self, exclude: Sequence[str] = ()):
        # Stored as a tuple so the op is hashable/comparable via __props__.
        self.exclude = tuple(exclude)

    def make_node(self, *inputs):
        """Build an Apply with one broadcast output per xtensor input.

        Raises
        ------
        ValueError
            Propagated from ``combine_dims_and_shape`` when non-excluded
            dims have conflicting static shapes.
        """
        inputs = [as_xtensor(x) for x in inputs]
        exclude = self.exclude

        # Shared dims/shapes across all inputs, ignoring excluded dims.
        dims_and_shape = combine_dims_and_shape(inputs, exclude=exclude)
        broadcast_dims = tuple(dims_and_shape.keys())
        broadcast_shape = tuple(dims_and_shape.values())
        dtype = upcast(*[x.type.dtype for x in inputs])

        outputs = []
        for x in inputs:
            x_dims = x.type.dims
            x_shape = x.type.shape
            # The output has excluded dimensions in the order they appear in the op argument
            excluded_dims = tuple(d for d in exclude if d in x_dims)
            excluded_shape = tuple(x_shape[x_dims.index(d)] for d in excluded_dims)
            output = xtensor(
                dtype=dtype,
                shape=broadcast_shape + excluded_shape,
                dims=broadcast_dims + excluded_dims,
            )
            outputs.append(output)

        return Apply(self, inputs, outputs)
def broadcast(
    *args, exclude: str | Sequence[str] | None = None
) -> tuple[XTensorVariable, ...]:
    """Broadcast any number of XTensorVariables against each other.

    Parameters
    ----------
    *args : XTensorVariable
        The tensors to broadcast against each other.
    exclude : str or Sequence[str] or None, optional
        Dimension name(s) to exclude from broadcasting; a single string is
        treated as a one-element sequence. Presumably mirrors xarray's
        ``broadcast(..., exclude=...)`` semantics — TODO confirm against
        the xarray docs.

    Returns
    -------
    tuple of XTensorVariable
        One broadcast output per input (always a tuple, matching xarray).

    Raises
    ------
    TypeError
        If ``exclude`` is neither None, a str, nor a Sequence.
    """
    if not args:
        return ()

    if exclude is None:
        exclude = ()
    elif isinstance(exclude, str):
        exclude = (exclude,)
    elif not isinstance(exclude, Sequence):
        raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")

    # xarray broadcast always returns a tuple, even if there's only one tensor
    return tuple(Broadcast(exclude=exclude)(*args, return_list=True))  # type: ignore
pytensor/xtensor/type.py
浏览文件 @
e9219159
...
...
@@ -736,6 +736,15 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return
px
.
math
.
dot
(
self
,
other
,
dim
=
dim
)
    def broadcast(self, *others, exclude=None):
        """Broadcast this tensor against other XTensorVariables.

        Thin wrapper around ``px.shape.broadcast`` with ``self`` as the
        first argument; returns a tuple of broadcast variables.
        """
        return px.shape.broadcast(self, *others, exclude=exclude)
    def broadcast_like(self, other, exclude=None):
        """Broadcast this tensor against another XTensorVariable.

        ``other`` is passed first so the shared broadcast dims follow its
        ordering; only the broadcast version of ``self`` is returned.
        """
        _, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
        return self_bcast
class
XTensorConstantSignature
(
TensorConstantSignature
):
pass
...
...
pytensor/xtensor/vectorization.py
浏览文件 @
e9219159
from
collections.abc
import
Sequence
from
itertools
import
chain
import
numpy
as
np
...
...
@@ -13,13 +14,22 @@ from pytensor.tensor.utils import (
get_static_shape_from_size_variables
,
)
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.type
import
XTensorVariable
,
as_xtensor
,
xtensor
def
combine_dims_and_shape
(
inputs
):
def
combine_dims_and_shape
(
inputs
:
Sequence
[
XTensorVariable
],
exclude
:
Sequence
[
str
]
|
None
=
None
)
->
dict
[
str
,
int
|
None
]:
"""Combine information of static dimensions and shapes from multiple xtensor inputs.
Exclude
"""
exclude_set
:
set
[
str
]
=
set
()
if
exclude
is
None
else
set
(
exclude
)
dims_and_shape
:
dict
[
str
,
int
|
None
]
=
{}
for
inp
in
inputs
:
for
dim
,
dim_length
in
zip
(
inp
.
type
.
dims
,
inp
.
type
.
shape
):
if
dim
in
exclude_set
:
continue
if
dim
not
in
dims_and_shape
:
dims_and_shape
[
dim
]
=
dim_length
elif
dim_length
is
not
None
:
...
...
tests/xtensor/test_shape.py
浏览文件 @
e9219159
...
...
@@ -9,10 +9,12 @@ from itertools import chain, combinations
import
numpy
as
np
from
xarray
import
DataArray
from
xarray
import
broadcast
as
xr_broadcast
from
xarray
import
concat
as
xr_concat
from
pytensor.tensor
import
scalar
from
pytensor.xtensor.shape
import
(
broadcast
,
concat
,
stack
,
unstack
,
...
...
@@ -466,3 +468,168 @@ def test_expand_dims_errors():
# Test with a numpy array as dim (not supported)
with
pytest
.
raises
(
TypeError
,
match
=
"unhashable type"
):
y
.
expand_dims
(
np
.
array
([
1
,
2
]))
class TestBroadcast:
    """Tests for xtensor broadcast / broadcast_like against xarray behavior."""

    @pytest.mark.parametrize(
        "exclude",
        [
            None,
            [],
            ["b"],
            ["b", "d"],
            ["a", "d"],
            ["b", "c", "d"],
            ["a", "b", "c", "d"],
        ],
    )
    def test_compatible_excluded_shapes(self, exclude):
        # Create test data
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)

        # Test with excluded dims
        x2_expected, y2_expected, z2_expected = xr_broadcast(
            x_test, y_test, z_test, exclude=exclude
        )
        x2, y2, z2 = broadcast(x, y, z, exclude=exclude)
        fn = xr_function([x, y, z], [x2, y2, z2])
        x2_result, y2_result, z2_result = fn(x_test, y_test, z_test)

        xr_assert_allclose(x2_result, x2_expected)
        xr_assert_allclose(y2_result, y2_expected)
        xr_assert_allclose(z2_result, z2_expected)

    def test_incompatible_excluded_shapes(self):
        # Test that excluded dims are allowed to be different sizes
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
        out = broadcast(x, y, z, exclude=["d"])

        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)
        fn = xr_function([x, y, z], out)
        results = fn(x_test, y_test, z_test)
        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=["d"])
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    @pytest.mark.parametrize("exclude", [[], ["b"], ["b", "c"], ["a", "b", "d"]])
    def test_runtime_shapes(self, exclude):
        # Static shapes partially unknown; concrete shapes supplied at runtime.
        x = xtensor("x", dims=("a", "b"), shape=(None, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, None))
        z = xtensor("z", dims=("b", "d"), shape=(None, None))
        out = broadcast(x, y, z, exclude=exclude)

        x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4)))
        y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6)))
        z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6)))
        fn = xr_function([x, y, z], out)
        results = fn(x_test, y_test, z_test)
        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=exclude)
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

        # Test invalid shape raises an error
        # Note: We might decide not to raise an error in the lowered graphs for performance reasons
        if "d" not in exclude:
            z_test_bad = xr_arange_like(xtensor(dims=z.dims, shape=(4, 7)))
            with pytest.raises(Exception):
                fn(x_test, y_test, z_test_bad)

    def test_broadcast_excluded_dims_in_different_order(self):
        """Test broadcasting excluded dims are aligned with user input."""
        x = xtensor("x", dims=("a", "c", "b"), shape=(3, 4, 5))
        y = xtensor("y", dims=("a", "b", "c"), shape=(3, 5, 4))
        out = (out_x, out_y) = broadcast(x, y, exclude=["c", "b"])
        # Excluded dims follow the order given in `exclude`, not input order.
        assert out_x.type.dims == ("a", "c", "b")
        assert out_y.type.dims == ("a", "c", "b")

        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        fn = xr_function([x, y], out)
        results = fn(x_test, y_test)
        expected_results = xr_broadcast(x_test, y_test, exclude=["c", "b"])
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    def test_broadcast_errors(self):
        """Test error handling in broadcast."""
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
        with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
            broadcast(x, y, z, exclude=1)

        # Test with conflicting shapes
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
        with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
            broadcast(x, y, z)

    def test_broadcast_no_input(self):
        # Both return an empty tuple when called with no tensors.
        assert broadcast() == xr_broadcast()
        assert broadcast(exclude=("a",)) == xr_broadcast(exclude=("a",))

    def test_broadcast_single_input(self):
        """Test broadcasting a single input."""
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        # Broadcast with a single input can still imply a transpose via the exclude parameter
        outs = [
            *broadcast(x),
            *broadcast(x, exclude=("a", "b")),
            *broadcast(x, exclude=("b", "a")),
            *broadcast(x, exclude=("b",)),
        ]
        fn = xr_function([x], outs)
        x_test = xr_arange_like(x)
        results = fn(x_test)
        expected_results = [
            *xr_broadcast(x_test),
            *xr_broadcast(x_test, exclude=("a", "b")),
            *xr_broadcast(x_test, exclude=("b", "a")),
            *xr_broadcast(x_test, exclude=("b",)),
        ]
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    @pytest.mark.parametrize("exclude", [None, ["b"], ["b", "c"]])
    def test_broadcast_like(self, exclude):
        """Test broadcast_like method"""
        # Create test data
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))

        # Order matters so we test both orders
        outs = [
            x.broadcast_like(y, exclude=exclude),
            y.broadcast_like(x, exclude=exclude),
            y.broadcast_like(z, exclude=exclude),
            z.broadcast_like(y, exclude=exclude),
        ]

        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)
        fn = xr_function([x, y, z], outs)
        results = fn(x_test, y_test, z_test)
        expected_results = [
            x_test.broadcast_like(y_test, exclude=exclude),
            y_test.broadcast_like(x_test, exclude=exclude),
            y_test.broadcast_like(z_test, exclude=exclude),
            z_test.broadcast_like(y_test, exclude=exclude),
        ]
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论