Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
87470065
提交
87470065
authored
2月 07, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
2月 19, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement dim-aware vectorize_graph
上级
9d99267c
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
265 行增加
和
36 行删除
+265
-36
replace.py
pytensor/graph/replace.py
+7
-0
extra_ops.py
pytensor/tensor/extra_ops.py
+2
-2
basic.py
pytensor/xtensor/basic.py
+19
-11
indexing.py
pytensor/xtensor/indexing.py
+2
-2
reduction.py
pytensor/xtensor/reduction.py
+2
-2
shape.py
pytensor/xtensor/shape.py
+7
-7
vectorization.py
pytensor/xtensor/vectorization.py
+0
-0
test_basic.py
tests/xtensor/test_basic.py
+20
-8
test_shape.py
tests/xtensor/test_shape.py
+3
-3
test_vectorization.py
tests/xtensor/test_vectorization.py
+202
-0
util.py
tests/xtensor/util.py
+1
-1
没有找到文件。
pytensor/graph/replace.py
浏览文件 @
87470065
...
...
@@ -283,6 +283,13 @@ def vectorize_graph(
# [array([-10., -11.]), array([10., 11.])]
"""
# TODO: Move this to tensor.vectorize, and make this helper type agnostic.
#
# This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
# The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
#
# xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
# as it is by design unaware of xtensors and their semantics.
if
isinstance
(
outputs
,
Sequence
):
seq_outputs
=
outputs
else
:
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
87470065
import
warnings
from
collections.abc
import
Collection
,
Iterable
from
collections.abc
import
Collection
,
Iterable
,
Sequence
from
textwrap
import
dedent
import
numpy
as
np
...
...
@@ -1926,7 +1926,7 @@ def logspace(
def
broadcast_to
(
x
:
Tensor
Variable
,
shape
:
TensorVariable
|
tuple
[
Variable
,
...
]
x
:
Tensor
Like
,
shape
:
TensorLike
|
Sequence
[
TensorLike
]
)
->
TensorVariable
:
"""Broadcast an array to a new shape.
...
...
pytensor/xtensor/basic.py
浏览文件 @
87470065
...
...
@@ -18,7 +18,9 @@ class XOp(Op):
def
do_constant_folding
(
self
,
fgraph
,
node
):
return
False
def
vectorize_node
(
self
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
:
str
|
None
)
->
Sequence
[
Variable
]:
raise
NotImplementedError
(
f
"Vectorized node not implemented for {self}"
)
...
...
@@ -31,7 +33,9 @@ class XTypeCastOp(TypeCastingOp):
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
return
input_shapes
def
vectorize_node
(
self
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
:
str
|
None
)
->
Sequence
[
Variable
]:
raise
NotImplementedError
(
f
"Vectorized node not implemented for {self}"
)
...
...
@@ -49,12 +53,13 @@ class TensorFromXTensor(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
xtensor_from_tensor
(
g_out
,
dims
=
x
.
type
.
dims
)]
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
[
old_x
]
=
node
.
inputs
if
(
new_x
.
ndim
-
old_x
.
ndim
)
>
1
:
raise
NotImplementedError
(
f
"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
"You can call vectorize_graph one batch dimension at a time."
"You can call vectorize_graph one batch dimension at a time, "
"or pytensor.xtensor.vectorization.vectorize_graph instead."
)
new_x
=
new_x
.
transpose
(
...
,
*
old_x
.
dims
)
return
[
self
(
new_x
)]
...
...
@@ -80,14 +85,17 @@ class XTensorFromTensor(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
tensor_from_xtensor
(
g_out
)]
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
[
old_x
]
=
node
.
inputs
if
new_x
.
ndim
!=
old_x
.
ndim
:
raise
NotImplementedError
(
f
"Vectorization of {self} with batched inputs not implemented, "
"as it can't infer new dimension labels"
)
return
[
self
(
new_x
)]
if
new_dim
is
None
:
raise
NotImplementedError
(
f
"Vectorization of {self} cannot infer the new dimension labels. "
"Use pytensor.xtensor.vectorization.vectorize_graph instead."
)
return
[
type
(
self
)(
dims
=
(
new_dim
,
*
self
.
dims
))(
new_x
)]
else
:
return
[
self
(
new_x
)]
def
xtensor_from_tensor
(
x
,
dims
,
name
=
None
):
...
...
@@ -111,7 +119,7 @@ class Rename(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
rename
(
g_out
,
dims
=
x
.
type
.
dims
)]
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
[
old_x
]
=
node
.
inputs
old_dim_mapping
=
dict
(
zip
(
old_x
.
dims
,
self
.
new_dims
,
strict
=
True
))
...
...
pytensor/xtensor/indexing.py
浏览文件 @
87470065
...
...
@@ -197,7 +197,7 @@ class Index(XOp):
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
,
dims
=
out_dims
)
return
Apply
(
self
,
[
x
,
*
idxs
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_idxs
):
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_idxs
,
new_dim
):
# new_x may have dims in different order
# we pair each pre-existing dim to the respective index
# with new dims having simply a slice(None)
...
...
@@ -237,7 +237,7 @@ class IndexUpdate(XOp):
out
=
x
.
type
()
return
Apply
(
self
,
[
x
,
y
,
*
idxs
],
[
out
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
# If y or the indices have new dimensions we need to broadcast_x
exclude
:
set
[
str
]
=
set
(
chain
.
from_iterable
(
...
...
pytensor/xtensor/reduction.py
浏览文件 @
87470065
...
...
@@ -46,7 +46,7 @@ class XReduce(XOp):
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
,
dims
=
out_dims
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
@@ -120,7 +120,7 @@ class XCumReduce(XOp):
out
=
x
.
type
()
return
Apply
(
self
,
[
x
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
pytensor/xtensor/shape.py
浏览文件 @
87470065
...
...
@@ -68,7 +68,7 @@ class Stack(XOp):
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
@@ -149,7 +149,7 @@ class UnStack(XOp):
)
return
Apply
(
self
,
[
x
,
*
unstacked_lengths
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_unstacked_length
):
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_unstacked_length
,
new_dim
):
new_unstacked_length
=
[
ul
.
squeeze
()
for
ul
in
new_unstacked_length
]
if
not
all
(
ul
.
type
.
ndim
==
0
for
ul
in
new_unstacked_length
):
raise
NotImplementedError
(
...
...
@@ -200,7 +200,7 @@ class Transpose(XOp):
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
old_dims
=
self
.
dims
new_dims
=
tuple
(
dim
for
dim
in
new_x
.
dims
if
dim
not
in
old_dims
)
return
[
type
(
self
)(
dims
=
(
*
new_dims
,
*
old_dims
))(
new_x
)]
...
...
@@ -318,7 +318,7 @@ class Concat(XOp):
output
=
xtensor
(
dtype
=
dtype
,
dims
=
dims
,
shape
=
shape
)
return
Apply
(
self
,
inputs
,
[
output
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
return
[
self
(
*
new_inputs
)]
...
...
@@ -402,7 +402,7 @@ class Squeeze(XOp):
)
return
Apply
(
self
,
[
x
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
@@ -464,7 +464,7 @@ class ExpandDims(XOp):
)
return
Apply
(
self
,
[
x
,
size
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
,
new_size
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_size
,
new_dim
):
new_size
=
new_size
.
squeeze
()
if
new_size
.
type
.
ndim
!=
0
:
raise
NotImplementedError
(
...
...
@@ -567,7 +567,7 @@ class Broadcast(XOp):
return
Apply
(
self
,
inputs
,
outputs
)
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
if
exclude_set
:
=
set
(
self
.
exclude
):
for
new_x
,
old_x
in
zip
(
node
.
inputs
,
new_inputs
,
strict
=
True
):
if
invalid_excluded
:
=
(
...
...
pytensor/xtensor/vectorization.py
浏览文件 @
87470065
差异被折叠。
点击展开。
tests/xtensor/test_basic.py
浏览文件 @
87470065
...
...
@@ -3,10 +3,12 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
import
re
import
numpy
as
np
from
pytensor
import
function
from
pytensor.graph
import
vectorize_graph
from
pytensor.graph
import
vectorize_graph
as
tensor_vectorize_graph
from
pytensor.tensor
import
matrix
,
vector
from
pytensor.xtensor.basic
import
(
Rename
,
...
...
@@ -14,10 +16,9 @@ from pytensor.xtensor.basic import (
tensor_from_xtensor
,
xtensor_from_tensor
,
)
from
pytensor.xtensor.type
import
xtensor
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
tests.unittest_tools
import
assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from
tests.xtensor.util
import
check_vectorization
...
...
@@ -53,9 +54,15 @@ def test_xtensor_from_tensor_vectorize():
t_batched
=
matrix
(
"t_batched"
)
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
NotImplementedError
,
match
=
re
.
escape
(
"cannot infer the new dimension labels. Use pytensor.xtensor.vectorization.vectorize_graph instead."
),
):
vectorize_graph
([
x
],
{
t
:
t_batched
})
tensor_vectorize_graph
(
x
,
{
t
:
t_batched
})
vec_x
=
vectorize_graph
(
x
,
{
t
:
t_batched
},
new_tensor_dims
=
(
"b"
,))
assert_equal_computations
([
vec_x
],
[
as_xtensor
(
t_batched
,
dims
=
(
"b"
,
"a"
))])
def
test_tensor_from_xtensor_vectorize
():
...
...
@@ -64,7 +71,7 @@ def test_tensor_from_xtensor_vectorize():
x_batched
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
3
,
5
))
y_batched
=
vectorize_graph
(
y
,
{
x
:
x_batched
})
y_batched
=
tensor_
vectorize_graph
(
y
,
{
x
:
x_batched
})
# vectorize_graph should place output batch dimension on the left
assert
y_batched
.
type
.
shape
==
(
5
,
3
)
assert_equal_computations
([
y_batched
],
[
x_batched
.
transpose
(
"b"
,
...
)
.
values
])
...
...
@@ -72,4 +79,9 @@ def test_tensor_from_xtensor_vectorize():
x_batched
=
xtensor
(
"x"
,
dims
=
(
"c"
,
"a"
,
"b"
),
shape
=
(
7
,
3
,
5
))
# vectorize_graph can't handle multiple batch dimensions safely
with
pytest
.
raises
(
NotImplementedError
):
vectorize_graph
(
y
,
{
x
:
x_batched
})
tensor_vectorize_graph
(
y
,
{
x
:
x_batched
})
# xtensor vectorize_graph can handle this graph safely
y_batched
=
vectorize_graph
(
y
,
{
x
:
x_batched
})
assert
y_batched
.
type
.
shape
==
(
7
,
5
,
3
)
assert_equal_computations
([
y_batched
],
[
x_batched
.
transpose
(
"c"
,
"b"
,
"a"
)
.
values
])
tests/xtensor/test_shape.py
浏览文件 @
87470065
...
...
@@ -15,7 +15,6 @@ from xarray import full_like as xr_full_like
from
xarray
import
ones_like
as
xr_ones_like
from
xarray
import
zeros_like
as
xr_zeros_like
from
pytensor.graph
import
vectorize_graph
from
pytensor.tensor
import
scalar
,
vector
from
pytensor.xtensor.shape
import
(
broadcast
,
...
...
@@ -27,6 +26,7 @@ from pytensor.xtensor.shape import (
zeros_like
,
)
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
tests.xtensor.util
import
(
check_vectorization
,
xr_arange_like
,
...
...
@@ -874,7 +874,7 @@ def test_expand_dims_batch_length_vectorize():
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
):
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
})
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
}
,
new_tensor_dims
=
[
"batch"
]
)
def
test_unstack_batch_length_vectorize
():
...
...
@@ -888,4 +888,4 @@ def test_unstack_batch_length_vectorize():
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
):
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
})
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
}
,
new_tensor_dims
=
[
"batch"
]
)
tests/xtensor/test_vectorization.py
0 → 100644
浏览文件 @
87470065
import
numpy
as
np
import
pytest
from
pytensor.tensor
import
TensorVariable
,
broadcast_to
,
tensor
from
pytensor.xtensor.basic
import
xtensor_from_tensor
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
tests.unittest_tools
import
assert_equal_computations
class
TestVectorizeGraph
:
def
test_pure_xtensor_graph
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
out
=
x
+
1
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"c"
,
"a"
,
"b"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"c"
,
"b"
,
"a"
)
expected
=
x_new
.
transpose
(
"c"
,
"b"
,
"a"
)
+
1
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_pure_tensor_graph
(
self
):
x
=
tensor
(
"x"
,
shape
=
())
out
=
x
+
1
x_new
=
tensor
(
"x_new"
,
shape
=
(
5
,))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
},
new_tensor_dims
=
[
"b"
])
assert
isinstance
(
out_vec
,
TensorVariable
)
assert
out_vec
.
ndim
==
1
expected
=
x_new
+
1
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_intermediate_tensor_graph
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
t
=
x
.
values
# Convert to TensorVariable
t2
=
t
+
np
.
ones
(
1
)
out
=
xtensor_from_tensor
(
t2
,
dims
=
(
"a"
,))
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"b"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"b"
,
"a"
)
expected
=
as_xtensor
(
x_new
.
transpose
(
"b"
,
"a"
)
.
values
+
np
.
ones
(
1
),
dims
=
(
"b"
,
"a"
)
)
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_intermediate_tensor_multiple_inputs_graph
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
y
=
xtensor
(
"y"
,
dims
=
(
"a"
,))
t
=
x
.
values
+
y
.
values
out
=
xtensor_from_tensor
(
t
,
dims
=
(
"a"
,))
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"c"
))
# Both inputs have the same batch dims
y_new
=
xtensor
(
"y_new"
,
dims
=
(
"c"
,
"a"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"c"
,
"a"
)
expected
=
as_xtensor
(
(
x_new
.
transpose
(
"c"
,
"a"
)
.
values
+
y_new
.
transpose
(
"c"
,
"a"
)
.
values
),
dims
=
(
"c"
,
"a"
),
)
assert_equal_computations
([
out_vec
],
[
expected
])
# Inputs have different batch dims
y_new
=
xtensor
(
"y_new"
,
dims
=
(
"b"
,
"a"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"c"
,
"b"
,
"a"
)
expected
=
as_xtensor
(
(
x_new
.
transpose
(
"c"
,
"a"
)
.
values
[:,
None
]
+
y_new
.
transpose
(
"b"
,
"a"
)
.
values
[
None
,
:]
),
dims
=
(
"c"
,
"b"
,
"a"
),
)
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_intermediate_xtensor_graph
(
self
):
x
=
tensor
(
"x"
,
shape
=
(
3
,))
t
=
as_xtensor
(
x
,
dims
=
(
"a"
,))
t2
=
t
+
1
out
=
t2
.
values
x_new
=
tensor
(
"x_new"
,
shape
=
(
5
,
3
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
},
new_tensor_dims
=
[
"b"
])
assert
isinstance
(
out_vec
,
TensorVariable
)
assert
out_vec
.
ndim
==
2
expected
=
(
as_xtensor
(
x_new
,
dims
=
(
"b"
,
"a"
))
+
1
)
.
values
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_mixed_type_inputs
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,),
shape
=
(
3
,))
y
=
tensor
(
"y"
,
shape
=
(
5
,))
out
=
as_xtensor
(
y
[
2
:],
dims
=
(
"b"
,))
+
x
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"d"
),
shape
=
(
3
,
7
))
y_new
=
tensor
(
"y_new"
,
shape
=
(
7
,
5
))
# New dimension of y is aligned with the new dimension of x
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
},
new_tensor_dims
=
[
"d"
])
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"d"
,
"b"
,
"a"
)
expected
=
as_xtensor
(
y_new
[:,
2
:],
dims
=
(
"d"
,
"b"
))
+
x_new
.
transpose
(
"d"
,
"a"
)
assert_equal_computations
([
out_vec
],
[
expected
])
# New dimension of y is distinct from that of x
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
},
new_tensor_dims
=
[
"c"
])
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"d"
,
"c"
,
"b"
,
"a"
)
# x introduced a new dimension "d" which causes y to be broadcasted
y_broadcasted
=
broadcast_to
(
y_new
,
(
x_new
.
sizes
[
"d"
],
y_new
.
shape
[
0
],
y_new
.
shape
[
1
])
)
expected
=
as_xtensor
(
y_broadcasted
[:,
:,
2
:],
dims
=
(
"d"
,
"c"
,
"b"
)
)
+
x_new
.
transpose
(
"d"
,
"a"
)
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_mixed_type_inputs_complex_broadcasting
(
self
):
a
=
xtensor
(
"a"
,
dims
=
(
"a"
,),
shape
=
(
3
,))
b
=
xtensor
(
"b"
,
dims
=
(
"b"
),
shape
=
(
5
,))
y
=
tensor
(
"y"
,
shape
=
(
7
,))
z
=
tensor
(
"z"
,
shape
=
(
11
,))
out
=
a
+
b
+
y
.
sum
()
+
z
.
sum
()
assert
out
.
dims
==
(
"a"
,
"b"
)
a_new
=
xtensor
(
"a_new"
,
dims
=
(
"a*"
,
"a"
),
shape
=
(
33
,
3
))
b_new
=
xtensor
(
"b_new"
,
dims
=
(
"b*"
,
"b"
),
shape
=
(
55
,
5
))
y_new
=
tensor
(
"y_new"
,
shape
=
(
1
,
55
,
2
,
1
,
7
))
z_new
=
tensor
(
"z_new"
,
shape
=
(
33
,
1
,
1
,
2
,
11
))
[
out_vec
]
=
vectorize_graph
(
[
out
],
{
a
:
a_new
,
b
:
b_new
,
y
:
y_new
,
z
:
z_new
},
new_tensor_dims
=
[
"a*"
,
"b*"
,
"y*"
,
"z*"
],
)
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"a*"
,
"b*"
,
"y*"
,
"z*"
,
"a"
,
"b"
)
batch_shape_truth
=
(
a_new
.
sizes
[
"a*"
],
b_new
.
sizes
[
"b*"
],
y_new
.
shape
[
2
],
z_new
.
shape
[
3
],
)
y_new_bcast
=
broadcast_to
(
y_new
,
(
*
batch_shape_truth
,
y_new
.
shape
[
4
]))
z_new_bcast
=
broadcast_to
(
z_new
,
(
*
batch_shape_truth
,
z_new
.
shape
[
4
]))
expected_out
=
(
(
a_new
+
b_new
)
+
as_xtensor
(
y_new_bcast
.
sum
(
axis
=-
1
),
dims
=
(
"a*"
,
"b*"
,
"y*"
,
"z*"
))
+
as_xtensor
(
z_new_bcast
.
sum
(
axis
=-
1
),
dims
=
(
"a*"
,
"b*"
,
"y*"
,
"z*"
))
)
.
transpose
(
"a*"
,
"b*"
,
"y*"
,
"z*"
,
...
)
assert_equal_computations
([
out_vec
],
[
expected_out
])
def
test_invalid_cases
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
out
=
x
+
1
# Missing xtensor dims
x_bad
=
xtensor
(
"x_bad"
,
dims
=
(
"b"
,))
# Missing "a"
with
pytest
.
raises
(
ValueError
,
match
=
"missing pre-existing dims"
):
vectorize_graph
([
out
],
{
x
:
x_bad
})
# New xtensor dims that were present in original graph
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,))
out2
=
x
+
y
x_new_conflict
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"b"
))
# "b" is new to x, but present in graph (in y)
with
pytest
.
raises
(
ValueError
,
match
=
"new dimensions that were present"
):
vectorize_graph
([
out2
],
{
x
:
x_new_conflict
})
# Missing tensor dims
t
=
tensor
(
"t"
,
shape
=
(
3
,))
out_t
=
t
+
1
# Replacement has fewer dims (rank 0)
t_bad_rank
=
tensor
(
"t_bad"
,
shape
=
())
with
pytest
.
raises
(
ValueError
,
match
=
"missing pre-existing dims"
):
vectorize_graph
([
out_t
],
{
t
:
t_bad_rank
})
# Missing new_tensor_dims
t_new
=
tensor
(
"t_new"
,
shape
=
(
5
,
5
,
3
))
with
pytest
.
raises
(
ValueError
,
match
=
"You must specify `new_tensor_dims`"
):
vectorize_graph
([
out_t
],
{
t
:
t_new
})
with
pytest
.
raises
(
ValueError
,
match
=
r"but only .* were specified"
):
vectorize_graph
([
out_t
],
{
t
:
t_new
},
new_tensor_dims
=
[
"a"
])
# Excess new_tensor_dims
# Replacement adds 1 dim, but 2 are specified
t_new_1dim
=
tensor
(
"t_new_1dim"
,
shape
=
(
5
,
3
))
with
pytest
.
raises
(
ValueError
,
match
=
"tensor dims were specified, but only"
):
vectorize_graph
([
out_t
],
{
t
:
t_new_1dim
},
new_tensor_dims
=
[
"a"
,
"b"
])
tests/xtensor/util.py
浏览文件 @
87470065
...
...
@@ -10,8 +10,8 @@ from xarray import DataArray
from
xarray.testing
import
assert_allclose
from
pytensor
import
function
from
pytensor.graph
import
vectorize_graph
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
def
xr_function
(
*
args
,
**
kwargs
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论