Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
87470065
提交
87470065
authored
2月 07, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
2月 19, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement dim-aware vectorize_graph
上级
9d99267c
显示空白字符变更
内嵌
并排
正在显示
11 个修改的文件
包含
596 行增加
和
55 行删除
+596
-55
replace.py
pytensor/graph/replace.py
+7
-0
extra_ops.py
pytensor/tensor/extra_ops.py
+2
-2
basic.py
pytensor/xtensor/basic.py
+16
-8
indexing.py
pytensor/xtensor/indexing.py
+2
-2
reduction.py
pytensor/xtensor/reduction.py
+2
-2
shape.py
pytensor/xtensor/shape.py
+7
-7
vectorization.py
pytensor/xtensor/vectorization.py
+334
-22
test_basic.py
tests/xtensor/test_basic.py
+20
-8
test_shape.py
tests/xtensor/test_shape.py
+3
-3
test_vectorization.py
tests/xtensor/test_vectorization.py
+202
-0
util.py
tests/xtensor/util.py
+1
-1
没有找到文件。
pytensor/graph/replace.py
浏览文件 @
87470065
...
...
@@ -283,6 +283,13 @@ def vectorize_graph(
# [array([-10., -11.]), array([10., 11.])]
"""
# TODO: Move this to tensor.vectorize, and make this helper type agnostic.
#
# This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
# The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
#
# xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
# as it is by design unaware of xtensors and their semantics.
if
isinstance
(
outputs
,
Sequence
):
seq_outputs
=
outputs
else
:
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
87470065
import
warnings
from
collections.abc
import
Collection
,
Iterable
from
collections.abc
import
Collection
,
Iterable
,
Sequence
from
textwrap
import
dedent
import
numpy
as
np
...
...
@@ -1926,7 +1926,7 @@ def logspace(
def
broadcast_to
(
x
:
Tensor
Variable
,
shape
:
TensorVariable
|
tuple
[
Variable
,
...
]
x
:
Tensor
Like
,
shape
:
TensorLike
|
Sequence
[
TensorLike
]
)
->
TensorVariable
:
"""Broadcast an array to a new shape.
...
...
pytensor/xtensor/basic.py
浏览文件 @
87470065
...
...
@@ -18,7 +18,9 @@ class XOp(Op):
def
do_constant_folding
(
self
,
fgraph
,
node
):
return
False
def
vectorize_node
(
self
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
:
str
|
None
)
->
Sequence
[
Variable
]:
raise
NotImplementedError
(
f
"Vectorized node not implemented for {self}"
)
...
...
@@ -31,7 +33,9 @@ class XTypeCastOp(TypeCastingOp):
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
return
input_shapes
def
vectorize_node
(
self
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
:
str
|
None
)
->
Sequence
[
Variable
]:
raise
NotImplementedError
(
f
"Vectorized node not implemented for {self}"
)
...
...
@@ -49,12 +53,13 @@ class TensorFromXTensor(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
xtensor_from_tensor
(
g_out
,
dims
=
x
.
type
.
dims
)]
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
[
old_x
]
=
node
.
inputs
if
(
new_x
.
ndim
-
old_x
.
ndim
)
>
1
:
raise
NotImplementedError
(
f
"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
"You can call vectorize_graph one batch dimension at a time."
"You can call vectorize_graph one batch dimension at a time, "
"or pytensor.xtensor.vectorization.vectorize_graph instead."
)
new_x
=
new_x
.
transpose
(
...
,
*
old_x
.
dims
)
return
[
self
(
new_x
)]
...
...
@@ -80,13 +85,16 @@ class XTensorFromTensor(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
tensor_from_xtensor
(
g_out
)]
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
[
old_x
]
=
node
.
inputs
if
new_x
.
ndim
!=
old_x
.
ndim
:
if
new_dim
is
None
:
raise
NotImplementedError
(
f
"Vectorization of {self} with batched inputs not implemented,
"
"as it can't infer new dimension labels
"
f
"Vectorization of {self} cannot infer the new dimension labels.
"
"Use pytensor.xtensor.vectorization.vectorize_graph instead.
"
)
return
[
type
(
self
)(
dims
=
(
new_dim
,
*
self
.
dims
))(
new_x
)]
else
:
return
[
self
(
new_x
)]
...
...
@@ -111,7 +119,7 @@ class Rename(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
rename
(
g_out
,
dims
=
x
.
type
.
dims
)]
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
[
old_x
]
=
node
.
inputs
old_dim_mapping
=
dict
(
zip
(
old_x
.
dims
,
self
.
new_dims
,
strict
=
True
))
...
...
pytensor/xtensor/indexing.py
浏览文件 @
87470065
...
...
@@ -197,7 +197,7 @@ class Index(XOp):
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
,
dims
=
out_dims
)
return
Apply
(
self
,
[
x
,
*
idxs
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_idxs
):
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_idxs
,
new_dim
):
# new_x may have dims in different order
# we pair each pre-existing dim to the respective index
# with new dims having simply a slice(None)
...
...
@@ -237,7 +237,7 @@ class IndexUpdate(XOp):
out
=
x
.
type
()
return
Apply
(
self
,
[
x
,
y
,
*
idxs
],
[
out
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
# If y or the indices have new dimensions we need to broadcast_x
exclude
:
set
[
str
]
=
set
(
chain
.
from_iterable
(
...
...
pytensor/xtensor/reduction.py
浏览文件 @
87470065
...
...
@@ -46,7 +46,7 @@ class XReduce(XOp):
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
,
dims
=
out_dims
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
@@ -120,7 +120,7 @@ class XCumReduce(XOp):
out
=
x
.
type
()
return
Apply
(
self
,
[
x
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
pytensor/xtensor/shape.py
浏览文件 @
87470065
...
...
@@ -68,7 +68,7 @@ class Stack(XOp):
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
@@ -149,7 +149,7 @@ class UnStack(XOp):
)
return
Apply
(
self
,
[
x
,
*
unstacked_lengths
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_unstacked_length
):
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_unstacked_length
,
new_dim
):
new_unstacked_length
=
[
ul
.
squeeze
()
for
ul
in
new_unstacked_length
]
if
not
all
(
ul
.
type
.
ndim
==
0
for
ul
in
new_unstacked_length
):
raise
NotImplementedError
(
...
...
@@ -200,7 +200,7 @@ class Transpose(XOp):
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
old_dims
=
self
.
dims
new_dims
=
tuple
(
dim
for
dim
in
new_x
.
dims
if
dim
not
in
old_dims
)
return
[
type
(
self
)(
dims
=
(
*
new_dims
,
*
old_dims
))(
new_x
)]
...
...
@@ -318,7 +318,7 @@ class Concat(XOp):
output
=
xtensor
(
dtype
=
dtype
,
dims
=
dims
,
shape
=
shape
)
return
Apply
(
self
,
inputs
,
[
output
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
return
[
self
(
*
new_inputs
)]
...
...
@@ -402,7 +402,7 @@ class Squeeze(XOp):
)
return
Apply
(
self
,
[
x
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_dim
):
return
[
self
(
new_x
)]
...
...
@@ -464,7 +464,7 @@ class ExpandDims(XOp):
)
return
Apply
(
self
,
[
x
,
size
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
,
new_size
):
def
vectorize_node
(
self
,
node
,
new_x
,
new_size
,
new_dim
):
new_size
=
new_size
.
squeeze
()
if
new_size
.
type
.
ndim
!=
0
:
raise
NotImplementedError
(
...
...
@@ -567,7 +567,7 @@ class Broadcast(XOp):
return
Apply
(
self
,
inputs
,
outputs
)
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
if
exclude_set
:
=
set
(
self
.
exclude
):
for
new_x
,
old_x
in
zip
(
node
.
inputs
,
new_inputs
,
strict
=
True
):
if
invalid_excluded
:
=
(
...
...
pytensor/xtensor/vectorization.py
浏览文件 @
87470065
from
collections.abc
import
Sequence
from
collections.abc
import
Mapping
,
Sequence
from
functools
import
singledispatch
from
itertools
import
chain
from
typing
import
Literal
from
typing
import
cast
as
typing_cast
import
numpy
as
np
...
...
@@ -8,18 +11,22 @@ from pytensor import shared
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.graph.traversal
import
toposort
,
truncated_graph_inputs
from
pytensor.graph.type
import
HasShape
from
pytensor.scalar
import
discrete_dtypes
from
pytensor.tensor
import
tensor
from
pytensor.tensor
import
(
TensorVariable
,
broadcast_shape
,
broadcast_to
,
tensor
,
)
from
pytensor.tensor.random.op
import
RNGConsumerOp
from
pytensor.tensor.random.type
import
RandomType
from
pytensor.tensor.utils
import
(
get_static_shape_from_size_variables
,
)
from
pytensor.utils
import
unzip
from
pytensor.xtensor.basic
import
(
XOp
,
XTypeCastOp
,
)
from
pytensor.xtensor.basic
import
XOp
,
XTypeCastOp
from
pytensor.xtensor.type
import
XTensorType
,
XTensorVariable
,
as_xtensor
,
xtensor
...
...
@@ -79,7 +86,7 @@ class XElemwise(XOp):
]
return
Apply
(
self
,
inputs
,
outputs
)
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
return
self
(
*
new_inputs
,
return_list
=
True
)
...
...
@@ -149,7 +156,7 @@ class XBlockwise(XOp):
]
return
Apply
(
self
,
inputs
,
outputs
)
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
return
self
(
*
new_inputs
,
return_list
=
True
)
...
...
@@ -300,7 +307,7 @@ class XRV(XOp, RNGConsumerOp):
return
Apply
(
self
,
[
rng
,
*
extra_dim_lengths
,
*
params
],
[
rng
.
type
(),
out
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
def
vectorize_node
(
self
,
node
,
*
new_inputs
,
new_dim
):
new_rng
,
*
new_extra_dim_lengths_and_params
=
new_inputs
k
=
len
(
self
.
extra_dims
)
new_extra_dim_lengths
,
new_params
=
(
...
...
@@ -319,24 +326,36 @@ class XRV(XOp, RNGConsumerOp):
@_vectorize_node.register
(
XOp
)
@_vectorize_node.register
(
XTypeCastOp
)
def
vectorize_xop
(
op
:
XOp
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
old_inp_dims
=
[
inp
.
dims
for
inp
in
node
.
inputs
if
isinstance
(
inp
.
type
,
XTensorType
)
]
old_out_dims
=
[
out
.
dims
for
out
in
node
.
outputs
if
isinstance
(
out
.
type
,
XTensorType
)
]
all_old_dims_set
=
set
(
chain
.
from_iterable
((
*
old_inp_dims
,
old_out_dims
)))
for
new_inp
,
old_inp
in
zip
(
new_inputs
,
node
.
inputs
,
strict
=
True
):
def
vectorize_xop
(
op
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
# This gets called by regular graph_replace, which isn't aware of xtensor and doesn't have a concept of `new_dim`
return
_vectorize_xnode
(
node
.
op
,
node
,
*
new_inputs
,
new_dim
=
None
)
@singledispatch
def
_vectorize_xnode
(
op
:
XOp
|
XTypeCastOp
,
node
:
Apply
,
*
batched_inputs
:
Variable
,
new_dim
:
str
|
None
=
None
,
)
->
Sequence
[
Variable
]:
"""Returns vectorized version of node with new batched inputs."""
all_old_dims_set
=
set
(
chain
.
from_iterable
(
x
.
type
.
dims
for
x
in
(
*
node
.
inputs
,
*
node
.
outputs
)
if
isinstance
(
x
.
type
,
XTensorType
)
)
)
for
new_inp
,
old_inp
in
zip
(
batched_inputs
,
node
.
inputs
,
strict
=
True
):
if
not
(
isinstance
(
new_inp
.
type
,
XTensorType
)
and
isinstance
(
old_inp
.
type
,
XTensorType
)
):
continue
old_dims_set
=
set
(
old_inp
.
dims
)
new_dims_set
=
set
(
new_inp
.
dims
)
old_dims_set
=
set
(
old_inp
.
type
.
dims
)
new_dims_set
=
set
(
new_inp
.
type
.
dims
)
# Validate that new inputs didn't drop pre-existing dims
if
missing_dims
:
=
old_dims_set
-
new_dims_set
:
...
...
@@ -349,4 +368,297 @@ def vectorize_xop(op: XOp, node, *new_inputs) -> Sequence[Variable]:
f
"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
)
return
op
.
vectorize_node
(
node
,
*
new_inputs
)
return
op
.
vectorize_node
(
node
,
*
batched_inputs
,
new_dim
=
new_dim
)
def
_vectorize_single_dim
(
outputs
,
replace
,
new_dim
:
str
):
inputs
=
truncated_graph_inputs
(
outputs
,
ancestors_to_include
=
replace
.
keys
())
new_inputs
=
[
replace
.
get
(
inp
,
inp
)
for
inp
in
inputs
]
vect_vars
=
dict
(
zip
(
inputs
,
new_inputs
,
strict
=
True
))
for
node
in
toposort
(
outputs
,
blockers
=
inputs
):
vect_inputs
=
[
vect_vars
.
get
(
inp
,
inp
)
for
inp
in
node
.
inputs
]
if
isinstance
(
node
.
op
,
XOp
|
XTypeCastOp
):
node_vect_outs
=
_vectorize_xnode
(
node
.
op
,
node
,
*
vect_inputs
,
new_dim
=
new_dim
)
else
:
node_vect_outs_or_apply
=
_vectorize_node
(
node
.
op
,
node
,
*
vect_inputs
)
# Old API compatibility
node_vect_outs
=
(
node_vect_outs_or_apply
.
outputs
if
isinstance
(
node_vect_outs_or_apply
,
Apply
)
else
node_vect_outs_or_apply
)
for
output
,
vect_output
in
zip
(
node
.
outputs
,
node_vect_outs
,
strict
=
True
):
if
output
in
vect_vars
:
# This can happen when some outputs of a multi-output node are given a replacement,
# while some of the remaining outputs are still needed in the graph.
# We make sure we don't overwrite the provided replacement with the newly vectorized output
continue
vect_vars
[
output
]
=
vect_output
return
[
vect_vars
[
out
]
for
out
in
outputs
]
def
vectorize_graph
(
outputs
:
Variable
|
Sequence
[
Variable
],
replace
:
Mapping
[
Variable
,
Variable
],
*
,
new_tensor_dims
:
Sequence
[
str
]
=
(),
):
"""Dimension-aware vectorize_graph.
This is an extension to :func:`pytensor.graph.replace.vectorize_graph` that correctly handles
mixed XTensor/TensorVariable graphs.
Vectorization rule for batch TensorVariables works like regular ``vectorize_graph``,
with batched axes assumed to be aligned positionally and present on the left of the new inputs.
They must be given labels with ``new_tensor_dims`` argument (left to right),
for correct interaction with XTensorVariables (and even if there are no XTensorVariables in the graph).
Batched XTensorVariables may contain new dimensions anywhere.
These can include dimensions in ``new_tensor_dims``, as well as other new dimensions
implied by the variable's ``dims``. New dimensions for a given input should not have
existed in the original graph.
The vectorized outputs will have the new dimensions on the left.
The order of new dimensions is:
1. New dimensions introduced by XTensorVariables (that are not in ``new_tensor_dims``).
2. Dimensions specified in ``new_tensor_dims``.
Parameters
----------
outputs: Variable or Sequence of Variable
The output variable(s) of the graph to be vectorized.
replace: Mapping of Variable to Variable
A dictionary mapping original variables to their vectorized counterparts.
new_tensor_dims: Sequence of str, optional
A sequence of string labels for the new batch dimensions introduced by ``TensorVariable``
replacements. These dimensions correspond to the leading axes of the new tensor variables.
This argument is required if any ``TensorVariable`` replacements introduce new dimensions.
Returns
-------
vectorized_outputs: Variable or Sequence of Variable
Vectorized output variable(s).
Examples
--------
Vectorize a graph with XTensor variables:
.. testcode:: python
from pytensor.xtensor import xtensor
from pytensor.xtensor.vectorization import vectorize_graph
x = xtensor("x", dims=("a",))
y = xtensor("y", dims=("a",))
out = x + y
# We want to vectorize over new dimensions "c" and "b"
# For XTensor, new dimensions can be anywhere
x_new = xtensor("x_new", dims=("c", "a"))
y_new = xtensor("y_new", dims=("a", "b"))
out_vec = vectorize_graph(out, {x: x_new, y: y_new})
# Output batch dimensions are always on the left
assert out_vec.type.dims == ("c", "b", "a")
Vectorize a graph with standard Tensor variables:
.. testcode:: python
from pytensor.tensor import tensor, TensorVariable
from pytensor.xtensor.vectorization import vectorize_graph
x = tensor("x", shape=(3,))
y = tensor("y", shape=(3,))
out = x + y
# We vectorize over new dimension of "a", and "b".
# These must be on the left and broadcast correctly
x_new = tensor("x_new", shape=(5, 3))
y_new = tensor("y_new", shape=(7, 1, 3))
out_vec = vectorize_graph(out, {x: x_new, y: y_new}, new_tensor_dims=["a", "b"])
assert isinstance(out_vec, TensorVariable)
assert out_vec.type.shape == (7, 5, 3)
Vectorize a mixed graph:
.. testcode:: python
from pytensor.tensor import tensor
from pytensor.xtensor import as_xtensor, xtensor
from pytensor.xtensor.vectorization import vectorize_graph
x = xtensor("x", shape=(5,), dims=("a",))
y = tensor("y", shape=(5,))
out = x + as_xtensor(y, dims=("a",))
# Vectorize over a new dimension "c"
x_new = xtensor("x_new", dims=("a", "c"), shape=(5, 3))
y_new = tensor("y_new", shape=(3, 5)) # Leading dim corresponds to "c" (size 3)
out_vec = vectorize_graph(out, {x: x_new, y: y_new}, new_tensor_dims=["c"])
assert out_vec.type.dims == ("c", "a")
# Treat the new dimension of y_new as being "b" (size 3)
# x_new introduces "c" (size 3)
# Result has XTensor-only new dims first ("c"), then new_tensor_dims ("b")
out_vec = vectorize_graph(out, {x: x_new, y: y_new}, new_tensor_dims=["b"])
assert out_vec.type.dims == ("c", "b", "a")
"""
seq_outputs
=
outputs
if
isinstance
(
outputs
,
Sequence
)
else
(
outputs
,)
if
not
all
(
isinstance
(
key
,
Variable
)
and
isinstance
(
value
,
Variable
)
for
key
,
value
in
replace
.
items
()
):
raise
ValueError
(
f
"Some of the replaced items are not Variables: {replace}"
)
# Collect new dimensions and sizes, and validate
new_xtensor_sizes
:
dict
[
str
,
TensorVariable
]
=
{}
new_tensor_dim_lengths
:
list
[
tuple
[
TensorVariable
|
Literal
[
1
],
...
]]
=
[]
for
old
,
new
in
replace
.
items
():
if
isinstance
(
new
,
XTensorVariable
):
old_var_dims_set
=
set
(
old
.
type
.
dims
)
new_var_dims_set
=
set
(
new
.
type
.
dims
)
if
missing_dims
:
=
old_var_dims_set
-
new_var_dims_set
:
raise
ValueError
(
f
"Vectorized input {new} is missing pre-existing dims: {sorted(missing_dims)}"
)
new_xtensor_sizes
.
update
(
{
d
:
s
for
d
,
s
in
new
.
sizes
.
items
()
if
d
not
in
old_var_dims_set
}
)
elif
isinstance
(
new
,
TensorVariable
):
n_new_dims
=
new
.
type
.
ndim
-
old
.
type
.
ndim
if
n_new_dims
<
0
:
raise
ValueError
(
f
"Vectorized input {new} is missing pre-existing dims {new.type.ndim=}, {old.type.ndim=}"
)
if
n_new_dims
>
len
(
new_tensor_dims
):
if
not
new_tensor_dims
:
raise
ValueError
(
f
"TensorVariable replacement {new} has {n_new_dims} batch dimensions. "
f
"You must specify `new_tensor_dims` to label these."
)
else
:
raise
ValueError
(
f
"TensorVariable replacement {new} has {n_new_dims} batch dimensions "
f
"but only {new_tensor_dims=} were specified. "
)
new_tensor_dim_lengths
.
append
(
tuple
(
1
if
b
else
s
for
s
,
b
in
zip
(
tuple
(
new
.
shape
)[:
n_new_dims
],
new
.
type
.
broadcastable
[:
n_new_dims
],
)
)
)
elif
isinstance
(
new
.
type
,
HasShape
)
and
new
.
type
.
ndim
!=
old
.
type
.
ndim
:
raise
NotImplementedError
(
f
"vectorize_graph does not know how to handle batched input {new} of type {new.type}"
)
# Align xtensor batch dimensions on the left, and broadcast tensor batch dimensions
new_dims
=
(
*
(
dim
for
dim
in
new_xtensor_sizes
if
dim
not
in
new_tensor_dims
),
*
new_tensor_dims
,
)
# Create a mapping from new_tensor_dims -> broadcasted shape from tensors
new_tensor_sizes
:
dict
[
str
,
Variable
]
=
{}
if
new_tensor_dims
:
new_tensor_bcast_dim_lengths
=
broadcast_shape
(
*
new_tensor_dim_lengths
,
arrays_are_shapes
=
True
)
del
new_tensor_dim_lengths
if
len
(
new_tensor_bcast_dim_lengths
)
!=
len
(
new_tensor_dims
):
raise
ValueError
(
f
"{len(new_tensor_dims)} tensor dims were specified, but only {len(new_tensor_bcast_dim_lengths)} were found in the new inputs"
)
new_tensor_sizes
=
dict
(
zip
(
new_tensor_dims
,
new_tensor_bcast_dim_lengths
))
# Give preference to tensor sizes to avoid unnecessary broadcasting (Alloc)
# XTensor sizes are implicitly handled by transpose and dim names, so they don't need strict size equality
new_sizes
=
tuple
(
new_xtensor_sizes
.
get
(
dim
,
new_tensor_sizes
.
get
(
dim
,
1
))
for
dim
in
new_dims
)
# Align batch dimensions on the left (*xtensor_unique_batch_dims, *tensor_batch_dims, ...)
# We broadcast tensor batch dims as they may have been length 1
aligned_replace
=
{}
for
old
,
new
in
replace
.
items
():
if
isinstance
(
new
,
XTensorVariable
):
new
=
new
.
transpose
(
*
new_dims
,
...
,
missing_dims
=
"ignore"
)
elif
isinstance
(
new
,
TensorVariable
):
n_existing_batch_dims
=
new
.
type
.
ndim
-
old
.
type
.
ndim
if
n_existing_batch_dims
<
len
(
new_dims
)
or
any
(
new
.
type
.
broadcastable
[:
len
(
new_dims
)]
):
new
=
broadcast_to
(
new
,
shape
=
(
*
new_sizes
,
*
tuple
(
new
.
shape
)[
n_existing_batch_dims
:]),
)
aligned_replace
[
old
]
=
new
del
replace
seq_vect_outputs
=
seq_outputs
remaining_new_dims
=
list
(
new_dims
)
while
remaining_new_dims
:
new_dim
=
remaining_new_dims
.
pop
()
if
remaining_new_dims
:
# We need to use a dummy inputs to batch graph once at a time
# We drop all the dims that are still in `remaining_new_dims`
# Create a mapping: original -> intermediate_batched
single_dim_replace
=
{}
for
old
,
new
in
aligned_replace
.
items
():
n_remaining_dims
=
len
(
remaining_new_dims
)
if
isinstance
(
new
,
XTensorVariable
):
intermediate_dims
,
intermediate_shape
=
unzip
(
(
(
d
,
s
)
for
d
,
s
in
zip
(
new
.
type
.
dims
,
new
.
type
.
shape
)
if
d
not
in
remaining_new_dims
),
n
=
2
,
)
intermediate_type
=
new
.
type
.
clone
(
dims
=
intermediate_dims
,
shape
=
intermediate_shape
)
elif
isinstance
(
new
,
TensorVariable
):
intermediate_type
=
new
.
type
.
clone
(
shape
=
new
.
type
.
shape
[
n_remaining_dims
:]
)
else
:
intermediate_type
=
new
.
type
single_dim_replace
[
old
]
=
intermediate_type
()
# Updated aligned replace mapping: intermediate_batched -> final_batched
aligned_replace
=
dict
(
zip
(
single_dim_replace
.
values
(),
aligned_replace
.
values
())
)
else
:
single_dim_replace
=
aligned_replace
seq_vect_outputs
=
_vectorize_single_dim
(
seq_vect_outputs
,
single_dim_replace
,
new_dim
)
aligned_seq_vect_outputs
=
[
new
.
transpose
(
*
new_dims
,
*
typing_cast
(
XTensorVariable
,
old
)
.
dims
)
if
isinstance
(
new
,
XTensorVariable
)
else
new
for
new
,
old
in
zip
(
seq_vect_outputs
,
seq_outputs
)
]
return
(
aligned_seq_vect_outputs
if
isinstance
(
outputs
,
Sequence
)
else
aligned_seq_vect_outputs
[
0
]
)
tests/xtensor/test_basic.py
浏览文件 @
87470065
...
...
@@ -3,10 +3,12 @@ import pytest
pytest
.
importorskip
(
"xarray"
)
import
re
import
numpy
as
np
from
pytensor
import
function
from
pytensor.graph
import
vectorize_graph
from
pytensor.graph
import
vectorize_graph
as
tensor_vectorize_graph
from
pytensor.tensor
import
matrix
,
vector
from
pytensor.xtensor.basic
import
(
Rename
,
...
...
@@ -14,10 +16,9 @@ from pytensor.xtensor.basic import (
tensor_from_xtensor
,
xtensor_from_tensor
,
)
from
pytensor.xtensor.type
import
xtensor
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
tests.unittest_tools
import
assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from
tests.xtensor.util
import
check_vectorization
...
...
@@ -53,9 +54,15 @@ def test_xtensor_from_tensor_vectorize():
t_batched
=
matrix
(
"t_batched"
)
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
NotImplementedError
,
match
=
re
.
escape
(
"cannot infer the new dimension labels. Use pytensor.xtensor.vectorization.vectorize_graph instead."
),
):
vectorize_graph
([
x
],
{
t
:
t_batched
})
tensor_vectorize_graph
(
x
,
{
t
:
t_batched
})
vec_x
=
vectorize_graph
(
x
,
{
t
:
t_batched
},
new_tensor_dims
=
(
"b"
,))
assert_equal_computations
([
vec_x
],
[
as_xtensor
(
t_batched
,
dims
=
(
"b"
,
"a"
))])
def
test_tensor_from_xtensor_vectorize
():
...
...
@@ -64,7 +71,7 @@ def test_tensor_from_xtensor_vectorize():
x_batched
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
3
,
5
))
y_batched
=
vectorize_graph
(
y
,
{
x
:
x_batched
})
y_batched
=
tensor_
vectorize_graph
(
y
,
{
x
:
x_batched
})
# vectorize_graph should place output batch dimension on the left
assert
y_batched
.
type
.
shape
==
(
5
,
3
)
assert_equal_computations
([
y_batched
],
[
x_batched
.
transpose
(
"b"
,
...
)
.
values
])
...
...
@@ -72,4 +79,9 @@ def test_tensor_from_xtensor_vectorize():
x_batched
=
xtensor
(
"x"
,
dims
=
(
"c"
,
"a"
,
"b"
),
shape
=
(
7
,
3
,
5
))
# vectorize_graph can't handle multiple batch dimensions safely
with
pytest
.
raises
(
NotImplementedError
):
vectorize_graph
(
y
,
{
x
:
x_batched
})
tensor_vectorize_graph
(
y
,
{
x
:
x_batched
})
# xtensor vectorize_graph can handle this graph safely
y_batched
=
vectorize_graph
(
y
,
{
x
:
x_batched
})
assert
y_batched
.
type
.
shape
==
(
7
,
5
,
3
)
assert_equal_computations
([
y_batched
],
[
x_batched
.
transpose
(
"c"
,
"b"
,
"a"
)
.
values
])
tests/xtensor/test_shape.py
浏览文件 @
87470065
...
...
@@ -15,7 +15,6 @@ from xarray import full_like as xr_full_like
from
xarray
import
ones_like
as
xr_ones_like
from
xarray
import
zeros_like
as
xr_zeros_like
from
pytensor.graph
import
vectorize_graph
from
pytensor.tensor
import
scalar
,
vector
from
pytensor.xtensor.shape
import
(
broadcast
,
...
...
@@ -27,6 +26,7 @@ from pytensor.xtensor.shape import (
zeros_like
,
)
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
tests.xtensor.util
import
(
check_vectorization
,
xr_arange_like
,
...
...
@@ -874,7 +874,7 @@ def test_expand_dims_batch_length_vectorize():
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
):
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
})
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
}
,
new_tensor_dims
=
[
"batch"
]
)
def
test_unstack_batch_length_vectorize
():
...
...
@@ -888,4 +888,4 @@ def test_unstack_batch_length_vectorize():
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
):
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
})
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
}
,
new_tensor_dims
=
[
"batch"
]
)
tests/xtensor/test_vectorization.py
0 → 100644
浏览文件 @
87470065
import
numpy
as
np
import
pytest
from
pytensor.tensor
import
TensorVariable
,
broadcast_to
,
tensor
from
pytensor.xtensor.basic
import
xtensor_from_tensor
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
,
xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
from
tests.unittest_tools
import
assert_equal_computations
class
TestVectorizeGraph
:
def
test_pure_xtensor_graph
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
out
=
x
+
1
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"c"
,
"a"
,
"b"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"c"
,
"b"
,
"a"
)
expected
=
x_new
.
transpose
(
"c"
,
"b"
,
"a"
)
+
1
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_pure_tensor_graph
(
self
):
x
=
tensor
(
"x"
,
shape
=
())
out
=
x
+
1
x_new
=
tensor
(
"x_new"
,
shape
=
(
5
,))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
},
new_tensor_dims
=
[
"b"
])
assert
isinstance
(
out_vec
,
TensorVariable
)
assert
out_vec
.
ndim
==
1
expected
=
x_new
+
1
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_intermediate_tensor_graph
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
t
=
x
.
values
# Convert to TensorVariable
t2
=
t
+
np
.
ones
(
1
)
out
=
xtensor_from_tensor
(
t2
,
dims
=
(
"a"
,))
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"b"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"b"
,
"a"
)
expected
=
as_xtensor
(
x_new
.
transpose
(
"b"
,
"a"
)
.
values
+
np
.
ones
(
1
),
dims
=
(
"b"
,
"a"
)
)
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_intermediate_tensor_multiple_inputs_graph
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
y
=
xtensor
(
"y"
,
dims
=
(
"a"
,))
t
=
x
.
values
+
y
.
values
out
=
xtensor_from_tensor
(
t
,
dims
=
(
"a"
,))
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"c"
))
# Both inputs have the same batch dims
y_new
=
xtensor
(
"y_new"
,
dims
=
(
"c"
,
"a"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"c"
,
"a"
)
expected
=
as_xtensor
(
(
x_new
.
transpose
(
"c"
,
"a"
)
.
values
+
y_new
.
transpose
(
"c"
,
"a"
)
.
values
),
dims
=
(
"c"
,
"a"
),
)
assert_equal_computations
([
out_vec
],
[
expected
])
# Inputs have different batch dims
y_new
=
xtensor
(
"y_new"
,
dims
=
(
"b"
,
"a"
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
})
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"c"
,
"b"
,
"a"
)
expected
=
as_xtensor
(
(
x_new
.
transpose
(
"c"
,
"a"
)
.
values
[:,
None
]
+
y_new
.
transpose
(
"b"
,
"a"
)
.
values
[
None
,
:]
),
dims
=
(
"c"
,
"b"
,
"a"
),
)
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_intermediate_xtensor_graph
(
self
):
x
=
tensor
(
"x"
,
shape
=
(
3
,))
t
=
as_xtensor
(
x
,
dims
=
(
"a"
,))
t2
=
t
+
1
out
=
t2
.
values
x_new
=
tensor
(
"x_new"
,
shape
=
(
5
,
3
))
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
},
new_tensor_dims
=
[
"b"
])
assert
isinstance
(
out_vec
,
TensorVariable
)
assert
out_vec
.
ndim
==
2
expected
=
(
as_xtensor
(
x_new
,
dims
=
(
"b"
,
"a"
))
+
1
)
.
values
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_mixed_type_inputs
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,),
shape
=
(
3
,))
y
=
tensor
(
"y"
,
shape
=
(
5
,))
out
=
as_xtensor
(
y
[
2
:],
dims
=
(
"b"
,))
+
x
x_new
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"d"
),
shape
=
(
3
,
7
))
y_new
=
tensor
(
"y_new"
,
shape
=
(
7
,
5
))
# New dimension of y is aligned with the new dimension of x
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
},
new_tensor_dims
=
[
"d"
])
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"d"
,
"b"
,
"a"
)
expected
=
as_xtensor
(
y_new
[:,
2
:],
dims
=
(
"d"
,
"b"
))
+
x_new
.
transpose
(
"d"
,
"a"
)
assert_equal_computations
([
out_vec
],
[
expected
])
# New dimension of y is distinct from that of x
[
out_vec
]
=
vectorize_graph
([
out
],
{
x
:
x_new
,
y
:
y_new
},
new_tensor_dims
=
[
"c"
])
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"d"
,
"c"
,
"b"
,
"a"
)
# x introduced a new dimension "d" which causes y to be broadcasted
y_broadcasted
=
broadcast_to
(
y_new
,
(
x_new
.
sizes
[
"d"
],
y_new
.
shape
[
0
],
y_new
.
shape
[
1
])
)
expected
=
as_xtensor
(
y_broadcasted
[:,
:,
2
:],
dims
=
(
"d"
,
"c"
,
"b"
)
)
+
x_new
.
transpose
(
"d"
,
"a"
)
assert_equal_computations
([
out_vec
],
[
expected
])
def
test_mixed_type_inputs_complex_broadcasting
(
self
):
a
=
xtensor
(
"a"
,
dims
=
(
"a"
,),
shape
=
(
3
,))
b
=
xtensor
(
"b"
,
dims
=
(
"b"
),
shape
=
(
5
,))
y
=
tensor
(
"y"
,
shape
=
(
7
,))
z
=
tensor
(
"z"
,
shape
=
(
11
,))
out
=
a
+
b
+
y
.
sum
()
+
z
.
sum
()
assert
out
.
dims
==
(
"a"
,
"b"
)
a_new
=
xtensor
(
"a_new"
,
dims
=
(
"a*"
,
"a"
),
shape
=
(
33
,
3
))
b_new
=
xtensor
(
"b_new"
,
dims
=
(
"b*"
,
"b"
),
shape
=
(
55
,
5
))
y_new
=
tensor
(
"y_new"
,
shape
=
(
1
,
55
,
2
,
1
,
7
))
z_new
=
tensor
(
"z_new"
,
shape
=
(
33
,
1
,
1
,
2
,
11
))
[
out_vec
]
=
vectorize_graph
(
[
out
],
{
a
:
a_new
,
b
:
b_new
,
y
:
y_new
,
z
:
z_new
},
new_tensor_dims
=
[
"a*"
,
"b*"
,
"y*"
,
"z*"
],
)
assert
isinstance
(
out_vec
.
type
,
XTensorType
)
assert
out_vec
.
type
.
dims
==
(
"a*"
,
"b*"
,
"y*"
,
"z*"
,
"a"
,
"b"
)
batch_shape_truth
=
(
a_new
.
sizes
[
"a*"
],
b_new
.
sizes
[
"b*"
],
y_new
.
shape
[
2
],
z_new
.
shape
[
3
],
)
y_new_bcast
=
broadcast_to
(
y_new
,
(
*
batch_shape_truth
,
y_new
.
shape
[
4
]))
z_new_bcast
=
broadcast_to
(
z_new
,
(
*
batch_shape_truth
,
z_new
.
shape
[
4
]))
expected_out
=
(
(
a_new
+
b_new
)
+
as_xtensor
(
y_new_bcast
.
sum
(
axis
=-
1
),
dims
=
(
"a*"
,
"b*"
,
"y*"
,
"z*"
))
+
as_xtensor
(
z_new_bcast
.
sum
(
axis
=-
1
),
dims
=
(
"a*"
,
"b*"
,
"y*"
,
"z*"
))
)
.
transpose
(
"a*"
,
"b*"
,
"y*"
,
"z*"
,
...
)
assert_equal_computations
([
out_vec
],
[
expected_out
])
def
test_invalid_cases
(
self
):
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,))
out
=
x
+
1
# Missing xtensor dims
x_bad
=
xtensor
(
"x_bad"
,
dims
=
(
"b"
,))
# Missing "a"
with
pytest
.
raises
(
ValueError
,
match
=
"missing pre-existing dims"
):
vectorize_graph
([
out
],
{
x
:
x_bad
})
# New xtensor dims that were present in original graph
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,))
out2
=
x
+
y
x_new_conflict
=
xtensor
(
"x_new"
,
dims
=
(
"a"
,
"b"
))
# "b" is new to x, but present in graph (in y)
with
pytest
.
raises
(
ValueError
,
match
=
"new dimensions that were present"
):
vectorize_graph
([
out2
],
{
x
:
x_new_conflict
})
# Missing tensor dims
t
=
tensor
(
"t"
,
shape
=
(
3
,))
out_t
=
t
+
1
# Replacement has fewer dims (rank 0)
t_bad_rank
=
tensor
(
"t_bad"
,
shape
=
())
with
pytest
.
raises
(
ValueError
,
match
=
"missing pre-existing dims"
):
vectorize_graph
([
out_t
],
{
t
:
t_bad_rank
})
# Missing new_tensor_dims
t_new
=
tensor
(
"t_new"
,
shape
=
(
5
,
5
,
3
))
with
pytest
.
raises
(
ValueError
,
match
=
"You must specify `new_tensor_dims`"
):
vectorize_graph
([
out_t
],
{
t
:
t_new
})
with
pytest
.
raises
(
ValueError
,
match
=
r"but only .* were specified"
):
vectorize_graph
([
out_t
],
{
t
:
t_new
},
new_tensor_dims
=
[
"a"
])
# Excess new_tensor_dims
# Replacement adds 1 dim, but 2 are specified
t_new_1dim
=
tensor
(
"t_new_1dim"
,
shape
=
(
5
,
3
))
with
pytest
.
raises
(
ValueError
,
match
=
"tensor dims were specified, but only"
):
vectorize_graph
([
out_t
],
{
t
:
t_new_1dim
},
new_tensor_dims
=
[
"a"
,
"b"
])
tests/xtensor/util.py
浏览文件 @
87470065
...
...
@@ -10,8 +10,8 @@ from xarray import DataArray
from
xarray.testing
import
assert_allclose
from
pytensor
import
function
from
pytensor.graph
import
vectorize_graph
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
from
pytensor.xtensor.vectorization
import
vectorize_graph
def
xr_function
(
*
args
,
**
kwargs
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论