Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9d99267c
提交
9d99267c
authored
2月 05, 2026
作者:
ricardoV94
提交者:
Ricardo Vieira
2月 19, 2026
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement vectorize_node for XOps
上级
9dd929ab
隐藏空白字符变更
内嵌
并排
正在显示
16 个修改的文件
包含
522 行增加
和
13 行删除
+522
-13
pyproject.toml
pyproject.toml
+1
-1
basic.py
pytensor/xtensor/basic.py
+36
-0
indexing.py
pytensor/xtensor/indexing.py
+34
-0
reduction.py
pytensor/xtensor/reduction.py
+6
-0
shape.py
pytensor/xtensor/shape.py
+43
-0
type.py
pytensor/xtensor/type.py
+1
-1
vectorization.py
pytensor/xtensor/vectorization.py
+64
-2
test_basic.py
tests/xtensor/test_basic.py
+50
-1
test_indexing.py
tests/xtensor/test_indexing.py
+41
-0
test_linalg.py
tests/xtensor/test_linalg.py
+20
-1
test_math.py
tests/xtensor/test_math.py
+14
-1
test_random.py
tests/xtensor/test_random.py
+26
-0
test_reduction.py
tests/xtensor/test_reduction.py
+19
-1
test_shape.py
tests/xtensor/test_shape.py
+91
-2
test_signal.py
tests/xtensor/test_signal.py
+17
-1
util.py
tests/xtensor/util.py
+59
-2
没有找到文件。
pyproject.toml
浏览文件 @
9d99267c
...
...
@@ -163,7 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py"
=
["E402"]
"tests/link/pytorch/**/test_*.py"
=
["E402"]
"tests/link/mlx/**/test_*.py"
=
["E402"]
"tests/xtensor/**/
test_
*.py"
=
["E402"]
"tests/xtensor/**/*.py"
=
["E402"]
...
...
pytensor/xtensor/basic.py
浏览文件 @
9d99267c
...
...
@@ -2,6 +2,7 @@ from collections.abc import Sequence
from
pytensor.compile.ops
import
TypeCastingOp
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph.basic
import
Variable
from
pytensor.tensor.type
import
TensorType
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
,
xtensor
...
...
@@ -17,6 +18,9 @@ class XOp(Op):
def
do_constant_folding
(
self
,
fgraph
,
node
):
return
False
def
vectorize_node
(
self
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
raise
NotImplementedError
(
f
"Vectorized node not implemented for {self}"
)
class
XTypeCastOp
(
TypeCastingOp
):
"""Base class for Ops that type cast between TensorType and XTensorType.
...
...
@@ -27,6 +31,9 @@ class XTypeCastOp(TypeCastingOp):
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
return
input_shapes
def
vectorize_node
(
self
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
raise
NotImplementedError
(
f
"Vectorized node not implemented for {self}"
)
class
TensorFromXTensor
(
XTypeCastOp
):
__props__
=
()
...
...
@@ -42,6 +49,16 @@ class TensorFromXTensor(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
xtensor_from_tensor
(
g_out
,
dims
=
x
.
type
.
dims
)]
def
vectorize_node
(
self
,
node
,
new_x
):
[
old_x
]
=
node
.
inputs
if
(
new_x
.
ndim
-
old_x
.
ndim
)
>
1
:
raise
NotImplementedError
(
f
"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
"You can call vectorize_graph one batch dimension at a time."
)
new_x
=
new_x
.
transpose
(
...
,
*
old_x
.
dims
)
return
[
self
(
new_x
)]
tensor_from_xtensor
=
TensorFromXTensor
()
...
...
@@ -63,6 +80,15 @@ class XTensorFromTensor(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
tensor_from_xtensor
(
g_out
)]
def
vectorize_node
(
self
,
node
,
new_x
):
[
old_x
]
=
node
.
inputs
if
new_x
.
ndim
!=
old_x
.
ndim
:
raise
NotImplementedError
(
f
"Vectorization of {self} with batched inputs not implemented, "
"as it can't infer new dimension labels"
)
return
[
self
(
new_x
)]
def
xtensor_from_tensor
(
x
,
dims
,
name
=
None
):
return
XTensorFromTensor
(
dims
=
dims
)(
x
,
name
=
name
)
...
...
@@ -85,6 +111,16 @@ class Rename(XTypeCastOp):
[
g_out
]
=
g_outs
return
[
rename
(
g_out
,
dims
=
x
.
type
.
dims
)]
def
vectorize_node
(
self
,
node
,
new_x
):
[
old_x
]
=
node
.
inputs
old_dim_mapping
=
dict
(
zip
(
old_x
.
dims
,
self
.
new_dims
,
strict
=
True
))
# new_dims may include a mix of old dims (possibly re-ordered), and new dims which won't be renamed
new_dims
=
tuple
(
old_dim_mapping
.
get
(
new_dim
,
new_dim
)
for
new_dim
in
new_x
.
dims
)
return
[
type
(
self
)(
new_dims
)(
new_x
)]
def
rename
(
x
,
name_dict
:
dict
[
str
,
str
]
|
None
=
None
,
**
names
:
str
):
if
name_dict
is
not
None
:
...
...
pytensor/xtensor/indexing.py
浏览文件 @
9d99267c
...
...
@@ -4,6 +4,7 @@
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from
itertools
import
chain
from
typing
import
Literal
from
pytensor.graph.basic
import
Apply
,
Constant
,
Variable
...
...
@@ -11,6 +12,7 @@ from pytensor.scalar.basic import discrete_dtypes
from
pytensor.tensor.basic
import
as_tensor
from
pytensor.tensor.type_other
import
NoneTypeT
,
SliceType
,
make_slice
from
pytensor.xtensor.basic
import
XOp
,
xtensor_from_tensor
from
pytensor.xtensor.shape
import
broadcast
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
,
xtensor
...
...
@@ -195,6 +197,15 @@ class Index(XOp):
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
,
dims
=
out_dims
)
return
Apply
(
self
,
[
x
,
*
idxs
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_idxs
):
# new_x may have dims in different order
# we pair each pre-existing dim to the respective index
# with new dims having simply a slice(None)
old_x
,
*
_
=
node
.
inputs
dims_to_idxs
=
dict
(
zip
(
old_x
.
dims
,
new_idxs
,
strict
=
False
))
new_idxs
=
tuple
(
dims_to_idxs
.
get
(
dim
,
slice
(
None
))
for
dim
in
new_x
.
dims
)
return
[
self
(
new_x
,
*
new_idxs
)]
index
=
Index
()
...
...
@@ -226,6 +237,29 @@ class IndexUpdate(XOp):
out
=
x
.
type
()
return
Apply
(
self
,
[
x
,
y
,
*
idxs
],
[
out
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
# If y or the indices have new dimensions we need to broadcast_x
exclude
:
set
[
str
]
=
set
(
chain
.
from_iterable
(
old_inp
.
dims
for
old_inp
in
node
.
inputs
if
isinstance
(
old_inp
.
type
,
XTensorType
)
)
)
old_x
,
*
_
=
node
.
inputs
new_x
,
*
_
=
broadcast
(
*
[
new_inp
for
new_inp
in
new_inputs
if
isinstance
(
new_inp
.
type
,
XTensorType
)
],
exclude
=
tuple
(
exclude
),
)
# New batch dimensions must go on the right since indices map to indexed dimensions positionally in the Op
new_x
=
new_x
.
transpose
(
*
old_x
.
dims
,
...
)
_
,
new_y
,
*
new_idxs
=
new_inputs
return
[
self
(
new_x
,
new_y
,
*
new_idxs
)]
index_assignment
=
IndexUpdate
(
"set"
)
index_increment
=
IndexUpdate
(
"inc"
)
pytensor/xtensor/reduction.py
浏览文件 @
9d99267c
...
...
@@ -46,6 +46,9 @@ class XReduce(XOp):
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
shape
=
out_shape
,
dims
=
out_dims
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
return
[
self
(
new_x
)]
def
_process_user_dims
(
x
:
XTensorVariable
,
dim
:
REDUCE_DIM
)
->
Sequence
[
str
]:
if
isinstance
(
dim
,
str
):
...
...
@@ -117,6 +120,9 @@ class XCumReduce(XOp):
out
=
x
.
type
()
return
Apply
(
self
,
[
x
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
):
return
[
self
(
new_x
)]
def
cumreduce
(
x
,
dim
:
REDUCE_DIM
,
*
,
binary_op
):
x
=
as_xtensor
(
x
)
...
...
pytensor/xtensor/shape.py
浏览文件 @
9d99267c
...
...
@@ -68,6 +68,9 @@ class Stack(XOp):
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
return
[
self
(
new_x
)]
def
stack
(
x
,
dim
:
dict
[
str
,
Sequence
[
str
]]
|
None
=
None
,
**
dims
:
Sequence
[
str
]):
if
dim
is
not
None
:
...
...
@@ -146,6 +149,14 @@ class UnStack(XOp):
)
return
Apply
(
self
,
[
x
,
*
unstacked_lengths
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
,
*
new_unstacked_length
):
new_unstacked_length
=
[
ul
.
squeeze
()
for
ul
in
new_unstacked_length
]
if
not
all
(
ul
.
type
.
ndim
==
0
for
ul
in
new_unstacked_length
):
raise
NotImplementedError
(
f
"Vectorization of {self} with batched unstacked_length not implemented, "
)
return
[
self
(
new_x
,
*
new_unstacked_length
)]
def
unstack
(
x
,
dim
:
dict
[
str
,
dict
[
str
,
int
]]
|
None
=
None
,
**
dims
:
dict
[
str
,
int
]):
if
dim
is
not
None
:
...
...
@@ -189,6 +200,11 @@ class Transpose(XOp):
)
return
Apply
(
self
,
[
x
],
[
output
])
def
vectorize_node
(
self
,
node
,
new_x
):
old_dims
=
self
.
dims
new_dims
=
tuple
(
dim
for
dim
in
new_x
.
dims
if
dim
not
in
old_dims
)
return
[
type
(
self
)(
dims
=
(
*
new_dims
,
*
old_dims
))(
new_x
)]
def
transpose
(
x
,
...
...
@@ -302,6 +318,9 @@ class Concat(XOp):
output
=
xtensor
(
dtype
=
dtype
,
dims
=
dims
,
shape
=
shape
)
return
Apply
(
self
,
inputs
,
[
output
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
return
[
self
(
*
new_inputs
)]
def
concat
(
xtensors
,
dim
:
str
):
"""Concatenate a sequence of XTensorVariables along a specified dimension.
...
...
@@ -383,6 +402,9 @@ class Squeeze(XOp):
)
return
Apply
(
self
,
[
x
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
):
return
[
self
(
new_x
)]
def
squeeze
(
x
,
dim
:
str
|
Sequence
[
str
]
|
None
=
None
):
"""Remove dimensions of size 1 from an XTensorVariable."""
...
...
@@ -442,6 +464,14 @@ class ExpandDims(XOp):
)
return
Apply
(
self
,
[
x
,
size
],
[
out
])
def
vectorize_node
(
self
,
node
,
new_x
,
new_size
):
new_size
=
new_size
.
squeeze
()
if
new_size
.
type
.
ndim
!=
0
:
raise
NotImplementedError
(
f
"Vectorization of {self} with batched new_size not implemented, "
)
return
[
self
(
new_x
,
new_size
)]
def
expand_dims
(
x
,
dim
=
None
,
axis
=
None
,
**
dim_kwargs
):
"""Add one or more new dimensions to an XTensorVariable."""
...
...
@@ -537,6 +567,19 @@ class Broadcast(XOp):
return
Apply
(
self
,
inputs
,
outputs
)
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
if
exclude_set
:
=
set
(
self
.
exclude
):
for
new_x
,
old_x
in
zip
(
node
.
inputs
,
new_inputs
,
strict
=
True
):
if
invalid_excluded
:
=
(
(
set
(
new_x
.
dims
)
-
set
(
old_x
.
dims
))
&
exclude_set
):
raise
NotImplementedError
(
f
"Vectorize of {self} is undefined because one of the inputs {new_x} "
f
"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
)
return
self
(
*
new_inputs
,
return_list
=
True
)
def
broadcast
(
*
args
,
exclude
:
str
|
Sequence
[
str
]
|
None
=
None
...
...
pytensor/xtensor/type.py
浏览文件 @
9d99267c
...
...
@@ -1044,7 +1044,7 @@ def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None)
if
isinstance
(
x
,
Variable
):
if
isinstance
(
x
.
type
,
XTensorType
):
if
(
dims
is
None
)
or
(
x
.
type
.
dims
==
dims
):
if
(
dims
is
None
)
or
(
x
.
type
.
dims
==
tuple
(
dims
)
):
return
x
else
:
raise
ValueError
(
...
...
pytensor/xtensor/vectorization.py
浏览文件 @
9d99267c
...
...
@@ -6,6 +6,8 @@ import numpy as np
from
pytensor
import
scalar
as
ps
from
pytensor
import
shared
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.scalar
import
discrete_dtypes
from
pytensor.tensor
import
tensor
from
pytensor.tensor.random.op
import
RNGConsumerOp
...
...
@@ -14,8 +16,11 @@ from pytensor.tensor.utils import (
get_static_shape_from_size_variables
,
)
from
pytensor.utils
import
unzip
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.type
import
XTensorVariable
,
as_xtensor
,
xtensor
from
pytensor.xtensor.basic
import
(
XOp
,
XTypeCastOp
,
)
from
pytensor.xtensor.type
import
XTensorType
,
XTensorVariable
,
as_xtensor
,
xtensor
def
combine_dims_and_shape
(
...
...
@@ -74,6 +79,9 @@ class XElemwise(XOp):
]
return
Apply
(
self
,
inputs
,
outputs
)
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
return
self
(
*
new_inputs
,
return_list
=
True
)
class
XBlockwise
(
XOp
):
__props__
=
(
"core_op"
,
"core_dims"
)
...
...
@@ -141,6 +149,9 @@ class XBlockwise(XOp):
]
return
Apply
(
self
,
inputs
,
outputs
)
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
return
self
(
*
new_inputs
,
return_list
=
True
)
class
XRV
(
XOp
,
RNGConsumerOp
):
"""Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
...
...
@@ -288,3 +299,54 @@ class XRV(XOp, RNGConsumerOp):
)
return
Apply
(
self
,
[
rng
,
*
extra_dim_lengths
,
*
params
],
[
rng
.
type
(),
out
])
def
vectorize_node
(
self
,
node
,
*
new_inputs
):
new_rng
,
*
new_extra_dim_lengths_and_params
=
new_inputs
k
=
len
(
self
.
extra_dims
)
new_extra_dim_lengths
,
new_params
=
(
new_extra_dim_lengths_and_params
[:
k
],
new_extra_dim_lengths_and_params
[
k
:],
)
new_extra_dim_lengths
=
[
dl
.
squeeze
()
for
dl
in
new_extra_dim_lengths
]
if
not
all
(
dl
.
type
.
ndim
==
0
for
dl
in
new_extra_dim_lengths
):
raise
NotImplementedError
(
f
"Vectorization of {self} with batched extra_dim_lengths not implemented, "
)
return
self
.
make_node
(
new_rng
,
*
new_extra_dim_lengths
,
*
new_params
)
.
outputs
@_vectorize_node.register
(
XOp
)
@_vectorize_node.register
(
XTypeCastOp
)
def
vectorize_xop
(
op
:
XOp
,
node
,
*
new_inputs
)
->
Sequence
[
Variable
]:
old_inp_dims
=
[
inp
.
dims
for
inp
in
node
.
inputs
if
isinstance
(
inp
.
type
,
XTensorType
)
]
old_out_dims
=
[
out
.
dims
for
out
in
node
.
outputs
if
isinstance
(
out
.
type
,
XTensorType
)
]
all_old_dims_set
=
set
(
chain
.
from_iterable
((
*
old_inp_dims
,
old_out_dims
)))
for
new_inp
,
old_inp
in
zip
(
new_inputs
,
node
.
inputs
,
strict
=
True
):
if
not
(
isinstance
(
new_inp
.
type
,
XTensorType
)
and
isinstance
(
old_inp
.
type
,
XTensorType
)
):
continue
old_dims_set
=
set
(
old_inp
.
dims
)
new_dims_set
=
set
(
new_inp
.
dims
)
# Validate that new inputs didn't drop pre-existing dims
if
missing_dims
:
=
old_dims_set
-
new_dims_set
:
raise
ValueError
(
f
"Vectorized input {new_inp} is missing pre-existing dims: {sorted(missing_dims)}"
)
# Or have new dimensions that were already in the graph
if
new_core_dims
:
=
((
new_dims_set
-
old_dims_set
)
&
all_old_dims_set
):
raise
ValueError
(
f
"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
)
return
op
.
vectorize_node
(
node
,
*
new_inputs
)
tests/xtensor/test_basic.py
浏览文件 @
9d99267c
import
pytest
pytest
.
importorskip
(
"xarray"
)
import
numpy
as
np
from
pytensor
import
function
from
pytensor.xtensor.basic
import
Rename
from
pytensor.graph
import
vectorize_graph
from
pytensor.tensor
import
matrix
,
vector
from
pytensor.xtensor.basic
import
(
Rename
,
rename
,
tensor_from_xtensor
,
xtensor_from_tensor
,
)
from
pytensor.xtensor.type
import
xtensor
from
tests.unittest_tools
import
assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from
tests.xtensor.util
import
check_vectorization
def
test_shape_feature_does_not_see_xop
():
...
...
@@ -24,3 +40,36 @@ def test_shape_feature_does_not_see_xop():
fn
=
function
([
x
],
out
)
np
.
testing
.
assert_allclose
(
fn
([
1
,
2
,
3
]),
[
0
,
0
,
0
])
assert
not
CALLED
def
test_rename_vectorize
():
ab
=
xtensor
(
"ab"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
),
dtype
=
"float64"
)
check_vectorization
(
ab
,
rename
(
ab
,
a
=
"c"
))
def
test_xtensor_from_tensor_vectorize
():
t
=
vector
(
"t"
)
x
=
xtensor_from_tensor
(
t
,
dims
=
(
"a"
,))
t_batched
=
matrix
(
"t_batched"
)
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
):
vectorize_graph
([
x
],
{
t
:
t_batched
})
def
test_tensor_from_xtensor_vectorize
():
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,),
shape
=
(
3
,))
y
=
tensor_from_xtensor
(
x
)
x_batched
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
3
,
5
))
y_batched
=
vectorize_graph
(
y
,
{
x
:
x_batched
})
# vectorize_graph should place output batch dimension on the left
assert
y_batched
.
type
.
shape
==
(
5
,
3
)
assert_equal_computations
([
y_batched
],
[
x_batched
.
transpose
(
"b"
,
...
)
.
values
])
x_batched
=
xtensor
(
"x"
,
dims
=
(
"c"
,
"a"
,
"b"
),
shape
=
(
7
,
3
,
5
))
# vectorize_graph can't handle multiple batch dimensions safely
with
pytest
.
raises
(
NotImplementedError
):
vectorize_graph
(
y
,
{
x
:
x_batched
})
tests/xtensor/test_indexing.py
浏览文件 @
9d99267c
...
...
@@ -14,6 +14,7 @@ from pytensor.tensor import tensor
from
pytensor.xtensor
import
xtensor
from
tests.unittest_tools
import
assert_equal_computations
from
tests.xtensor.util
import
(
check_vectorization
,
xr_arange_like
,
xr_assert_allclose
,
xr_function
,
...
...
@@ -542,3 +543,43 @@ def test_empty_update_index():
fn
=
xr_function
([
x
],
out1
)
x_test
=
xr_random_like
(
x
)
xr_assert_allclose
(
fn
(
x_test
),
x_test
+
1
)
def
test_indexing_vectorize
():
abc
=
xtensor
(
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
3
,
5
,
7
))
a_idx
=
xtensor
(
dims
=
(
"a"
,),
shape
=
(
5
,),
dtype
=
"int64"
)
c_idx
=
xtensor
(
dims
=
(
"c"
,),
shape
=
(
3
,),
dtype
=
"int64"
)
abc_val
=
xr_random_like
(
abc
)
a_idx_val
=
DataArray
([
0
,
1
,
0
,
2
,
0
],
dims
=
(
"a"
,))
c_idx_val
=
DataArray
([
0
,
5
,
6
],
dims
=
(
"c"
,))
check_vectorization
([
abc
,
a_idx
],
[
abc
.
isel
(
a
=
a_idx
)],
[
abc_val
,
a_idx_val
])
check_vectorization
(
[
abc
,
a_idx
],
[
abc
.
isel
(
a
=
a_idx
.
rename
(
a
=
"b"
))],
[
abc_val
,
a_idx_val
]
)
check_vectorization
(
[
abc
,
a_idx
],
[
abc
.
isel
(
a
=
a_idx
.
rename
(
a
=
"d"
))],
[
abc_val
,
a_idx_val
]
)
check_vectorization
([
abc
,
a_idx
],
[
abc
.
isel
(
c
=
a_idx
[:
3
])],
[
abc_val
,
a_idx_val
])
check_vectorization
(
[
abc
,
a_idx
],
[
abc
.
isel
(
a
=
a_idx
,
c
=
a_idx
)],
[
abc_val
,
a_idx_val
]
)
check_vectorization
(
[
abc
,
a_idx
,
c_idx
],
[
abc
.
isel
(
a
=
a_idx
,
c
=
c_idx
)],
[
abc_val
,
a_idx_val
,
c_idx_val
],
)
def
test_index_update_vectorize
():
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
3
,
5
))
idx
=
xtensor
(
"idx"
,
dims
=
(
"b*"
,),
shape
=
(
7
,),
dtype
=
int
)
y
=
xtensor
(
"y"
,
dims
=
(
"b*"
,),
shape
=
(
7
,))
x_val
=
xr_random_like
(
x
)
idx_val
=
DataArray
([
2
,
0
,
4
,
0
,
1
,
0
,
3
],
dims
=
(
"b*"
,))
y_val
=
xr_random_like
(
y
)
check_vectorization
([
x
,
idx
,
y
],
[
x
.
isel
(
b
=
idx
)
.
set
(
y
)],
[
x_val
,
idx_val
,
y_val
])
check_vectorization
([
x
,
idx
,
y
],
[
x
.
isel
(
b
=
idx
)
.
inc
(
y
)],
[
x_val
,
idx_val
,
y_val
])
tests/xtensor/test_linalg.py
浏览文件 @
9d99267c
...
...
@@ -16,7 +16,7 @@ from xarray_einstats.linalg import (
from
pytensor.xtensor.linalg
import
cholesky
,
solve
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
xr_assert_allclose
,
xr_function
from
tests.xtensor.util
import
check_vectorization
,
xr_assert_allclose
,
xr_function
def
test_cholesky
():
...
...
@@ -74,3 +74,22 @@ def test_solve_matrix_b():
fn
(
a_test
,
b_test
),
xr_solve
(
a_test
,
b_test
,
dims
=
[
"country"
,
"city"
,
"district"
]),
)
def
test_linalg_vectorize
():
# Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
a
=
xtensor
(
"b"
,
dims
=
(
"a"
,),
shape
=
(
3
,))
ab
=
xtensor
(
"a"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
3
,
3
))
test_spd
=
np
.
random
.
randn
(
3
,
3
)
test_spd
=
test_spd
@
test_spd
.
T
check_vectorization
(
[
ab
],
[
cholesky
(
ab
,
dims
=
(
"b"
,
"a"
))],
input_vals
=
[
DataArray
(
test_spd
,
dims
=
(
"a"
,
"b"
))],
)
check_vectorization
(
[
ab
,
a
],
[
solve
(
ab
,
a
,
dims
=
(
"a"
,
"b"
))],
)
tests/xtensor/test_math.py
浏览文件 @
9d99267c
...
...
@@ -17,7 +17,12 @@ from pytensor.scalar import ScalarOp
from
pytensor.xtensor.basic
import
rename
from
pytensor.xtensor.math
import
add
,
exp
,
logsumexp
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
xr_arange_like
,
xr_assert_allclose
,
xr_function
from
tests.xtensor.util
import
(
check_vectorization
,
xr_arange_like
,
xr_assert_allclose
,
xr_function
,
)
def
test_all_scalar_ops_are_wrapped
():
...
...
@@ -340,3 +345,11 @@ def test_dot_errors():
match
=
r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)"
,
):
fn
(
x_test
,
y_test
)
def
test_xelemwise_vectorize
():
ab
=
xtensor
(
"ab"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
))
bc
=
xtensor
(
"bc"
,
dims
=
(
"b"
,
"c"
),
shape
=
(
3
,
5
))
check_vectorization
([
ab
],
[
exp
(
ab
)])
check_vectorization
([
ab
,
bc
],
[
ab
+
bc
])
tests/xtensor/test_random.py
浏览文件 @
9d99267c
...
...
@@ -9,6 +9,7 @@ import re
from
copy
import
deepcopy
import
numpy
as
np
from
xarray
import
DataArray
import
pytensor.tensor.random
as
ptr
import
pytensor.xtensor.random
as
pxr
...
...
@@ -26,6 +27,7 @@ from pytensor.xtensor.random import (
normal
,
)
from
pytensor.xtensor.vectorization
import
XRV
from
tests.xtensor.util
import
check_vectorization
def
lower_rewrite
(
vars
):
...
...
@@ -438,3 +440,27 @@ def test_multivariate_normal():
):
# cov must have both core_dims
multivariate_normal
(
mu_xr
,
cov_xr
,
core_dims
=
(
"rows"
,
"missing_cols"
))
def
test_xrv_vectorize
():
# Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
n
=
xtensor
(
"n"
,
dims
=
(
"n"
,),
shape
=
(
3
,),
dtype
=
int
)
pna
=
xtensor
(
"p"
,
dims
=
(
"p"
,
"n"
,
"a"
),
shape
=
(
5
,
3
,
2
))
out
=
multinomial
(
n
,
pna
,
core_dims
=
(
"p"
,),
extra_dims
=
{
"extra"
:
5
})
check_vectorization
(
[
n
,
pna
],
[
out
],
input_vals
=
[
DataArray
([
3
,
5
,
10
],
dims
=
(
"n"
,)),
DataArray
(
np
.
random
.
multinomial
(
n
=
1
,
pvals
=
np
.
ones
(
5
)
/
5
,
size
=
(
2
,
3
))
.
T
,
dims
=
(
"p"
,
"n"
,
"a"
),
),
],
)
def
test_xrv_batch_extra_dim_vectorize
():
# TODO: Check it raises NotImplementedError when we try to batch the extra_dim of an xrv
pass
tests/xtensor/test_reduction.py
浏览文件 @
9d99267c
...
...
@@ -8,7 +8,12 @@ import numpy as np
import
xarray
as
xr
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
xr_arange_like
,
xr_assert_allclose
,
xr_function
from
tests.xtensor.util
import
(
check_vectorization
,
xr_arange_like
,
xr_assert_allclose
,
xr_function
,
)
@pytest.mark.parametrize
(
...
...
@@ -99,3 +104,16 @@ def test_discrete_reduction_upcasting(signed):
res
=
fn
(
x_val
)
np
.
testing
.
assert_allclose
(
res
,
[
test_val
,
test_val
**
2
])
xr_assert_allclose
(
res
,
x_val
.
cumprod
())
def
test_reduction_vectorize
():
# Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
abc
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
3
,
5
,
7
))
check_vectorization
([
abc
],
[
abc
.
sum
(
dim
=
"a"
)])
check_vectorization
([
abc
],
[
abc
.
max
(
dim
=
(
"a"
,
"c"
))])
check_vectorization
([
abc
],
[
abc
.
all
()])
check_vectorization
([
abc
],
[
abc
.
cumsum
(
dim
=
"b"
)])
check_vectorization
([
abc
],
[
abc
.
cumsum
(
dim
=
(
"c"
,
"b"
))])
check_vectorization
([
abc
],
[
abc
.
cumprod
()])
tests/xtensor/test_shape.py
浏览文件 @
9d99267c
...
...
@@ -15,7 +15,8 @@ from xarray import full_like as xr_full_like
from
xarray
import
ones_like
as
xr_ones_like
from
xarray
import
zeros_like
as
xr_zeros_like
from
pytensor.tensor
import
scalar
from
pytensor.graph
import
vectorize_graph
from
pytensor.tensor
import
scalar
,
vector
from
pytensor.xtensor.shape
import
(
broadcast
,
concat
,
...
...
@@ -25,8 +26,9 @@ from pytensor.xtensor.shape import (
unstack
,
zeros_like
,
)
from
pytensor.xtensor.type
import
xtensor
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
tests.xtensor.util
import
(
check_vectorization
,
xr_arange_like
,
xr_assert_allclose
,
xr_function
,
...
...
@@ -800,3 +802,90 @@ def test_zeros_like():
expected1
=
xr_zeros_like
(
x_test
)
xr_assert_allclose
(
result1
,
expected1
)
assert
result1
.
dtype
==
expected1
.
dtype
def
test_shape_ops_vectorize
():
a1
=
xtensor
(
"a1"
,
dims
=
(
"a"
,
"1"
),
shape
=
(
2
,
1
),
dtype
=
"float64"
)
ab
=
xtensor
(
"ab"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
2
,
3
),
dtype
=
"float64"
)
abc
=
xtensor
(
"abc"
,
dims
=
(
"a"
,
"b"
,
"c"
),
shape
=
(
2
,
3
,
5
),
dtype
=
"float64"
)
a_bc_d
=
xtensor
(
"a_bc_d"
,
dims
=
(
"a"
,
"bc"
,
"d"
),
shape
=
(
4
,
15
,
7
))
check_vectorization
(
abc
,
abc
.
transpose
(
"b"
,
"c"
,
"a"
))
check_vectorization
(
abc
,
abc
.
transpose
(
"b"
,
...
))
check_vectorization
(
abc
,
stack
(
abc
,
new_dim
=
(
"a"
,
"c"
)))
check_vectorization
(
a_bc_d
,
unstack
(
a_bc_d
,
bc
=
dict
(
b
=
3
,
c
=
5
)))
check_vectorization
([
abc
,
ab
],
concat
([
abc
,
ab
],
dim
=
"a"
))
check_vectorization
(
a1
,
a1
.
squeeze
(
"1"
))
check_vectorization
(
abc
,
abc
.
expand_dims
(
d
=
5
))
check_vectorization
([
ab
,
abc
],
broadcast
(
ab
,
abc
))
check_vectorization
([
ab
,
abc
,
a1
],
broadcast
(
ab
,
abc
,
a1
,
exclude
=
"1"
))
# a is longer in a_bc_d than in ab and abc, helper can't handle that
# check_vectorization([ab, abc, a_bc_d], broadcast(ab, abc, a_bc_d, exclude="a"))
def
test_broadcast_exclude_vectorize
():
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
3
,
4
))
y
=
xtensor
(
"y"
,
dims
=
(
"b"
,
"c"
),
shape
=
(
7
,
5
))
out_x
,
out_y
=
broadcast
(
x
,
y
,
exclude
=
(
"b"
,))
x_val
=
xr_random_like
(
x
)
y_val
=
xr_random_like
(
y
)
x_batch_val
=
x_val
.
expand_dims
({
"batch"
:
2
})
y_batch_val
=
y_val
.
expand_dims
({
"batch"
:
2
})
x_batch
=
as_xtensor
(
x_batch_val
)
.
type
(
"x_batch"
)
y_batch
=
as_xtensor
(
y_batch_val
)
.
type
(
"y_batch"
)
[
out_x_vec
,
out_y_vec
]
=
vectorize_graph
([
out_x
,
out_y
],
{
x
:
x_batch
,
y
:
y_batch
})
fn
=
xr_function
([
x_batch
,
y_batch
],
[
out_x_vec
,
out_y_vec
])
res_x
,
res_y
=
fn
(
x_batch_val
,
y_batch_val
)
expected_x
=
[]
expected_y
=
[]
for
i
in
range
(
2
):
ex_x
,
ex_y
=
xr_broadcast
(
x_batch_val
.
isel
(
batch
=
i
),
y_batch_val
.
isel
(
batch
=
i
),
exclude
=
(
"b"
,)
)
expected_x
.
append
(
ex_x
)
expected_y
.
append
(
ex_y
)
expected_x
=
xr_concat
(
expected_x
,
dim
=
"batch"
)
expected_y
=
xr_concat
(
expected_y
,
dim
=
"batch"
)
xr_assert_allclose
(
res_x
,
expected_x
)
xr_assert_allclose
(
res_y
,
expected_y
)
def
test_expand_dims_batch_length_vectorize
():
x
=
xtensor
(
"x"
,
dims
=
(
"a"
,),
shape
=
(
3
,))
l
=
scalar
(
"l"
,
dtype
=
"int64"
)
y
=
x
.
expand_dims
(
b
=
l
)
x_batch
=
as_xtensor
(
xr_random_like
(
x
)
.
expand_dims
(
batch
=
2
))
.
type
(
"x_batch"
)
l_batch
=
vector
(
"l_batch"
,
dtype
=
"int64"
)
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
):
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
})
def
test_unstack_batch_length_vectorize
():
x
=
xtensor
(
"x"
,
dims
=
(
"ab"
,),
shape
=
(
12
,))
l
=
scalar
(
"l"
,
dtype
=
"int64"
)
y
=
unstack
(
x
,
ab
=
{
"a"
:
l
,
"b"
:
x
.
sizes
[
"ab"
]
//
l
})
x_batch
=
as_xtensor
(
xr_random_like
(
x
)
.
expand_dims
(
batch
=
2
))
.
type
(
"x_batch"
)
l_batch
=
vector
(
"l_batch"
,
dtype
=
"int64"
)
with
pytest
.
raises
(
NotImplementedError
,
match
=
r"Vectorization of .* not implemented"
):
vectorize_graph
([
y
],
{
x
:
x_batch
,
l
:
l_batch
})
tests/xtensor/test_signal.py
浏览文件 @
9d99267c
...
...
@@ -12,7 +12,12 @@ from xarray import apply_ufunc
from
pytensor.xtensor.signal
import
convolve1d
from
pytensor.xtensor.type
import
xtensor
from
tests.xtensor.util
import
xr_arange_like
,
xr_assert_allclose
,
xr_function
from
tests.xtensor.util
import
(
check_vectorization
,
xr_arange_like
,
xr_assert_allclose
,
xr_function
,
)
@pytest.mark.parametrize
(
"mode"
,
(
"full"
,
"valid"
,
"same"
))
...
...
@@ -68,3 +73,14 @@ def test_convolve_1d_invalid():
match
=
re
.
escape
(
"Input 1 has invalid core dims ['time']. Allowed: ('kernel',)"
),
):
convolve1d
(
in1
,
in2
.
rename
({
"batch"
:
"time"
}),
dims
=
(
"time"
,
"kernel"
))
def
test_signal_vectorize
():
# Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
ab
=
xtensor
(
"a"
,
dims
=
(
"a"
,
"b"
),
shape
=
(
3
,
3
))
c
=
xtensor
(
name
=
"c"
,
dims
=
(
"c"
,),
shape
=
(
7
,))
check_vectorization
(
[
ab
,
c
],
[
convolve1d
(
ab
,
c
,
dims
=
(
"a"
,
"c"
))],
)
tests/xtensor/util.py
浏览文件 @
9d99267c
import
pytest
pytest
.
importorskip
(
"xarray"
)
xr
=
pytest
.
importorskip
(
"xarray"
)
from
itertools
import
chain
import
numpy
as
np
from
xarray
import
DataArray
from
xarray.testing
import
assert_allclose
from
pytensor
import
function
from
pytensor.xtensor.type
import
XTensorType
from
pytensor.graph
import
vectorize_graph
from
pytensor.xtensor.type
import
XTensorType
,
as_xtensor
def
xr_function
(
*
args
,
**
kwargs
):
...
...
@@ -76,3 +79,57 @@ def xr_random_like(x, rng=None):
return
DataArray
(
rng
.
standard_normal
(
size
=
x
.
type
.
shape
,
dtype
=
x
.
type
.
dtype
),
dims
=
x
.
type
.
dims
)
def
check_vectorization
(
inputs
,
outputs
,
input_vals
=
None
,
rng
=
None
):
# Create core graph and function
if
not
isinstance
(
inputs
,
list
|
tuple
):
inputs
=
(
inputs
,)
if
not
isinstance
(
outputs
,
list
|
tuple
):
outputs
=
(
outputs
,)
# apply_ufunc isn't happy with list output or single entry
_core_fn
=
function
(
inputs
,
outputs
)
def
core_fn
(
*
args
,
_core_fn
=
_core_fn
):
res
=
_core_fn
(
*
args
)
if
len
(
res
)
==
1
:
return
res
[
0
]
else
:
return
tuple
(
res
)
if
input_vals
is
None
:
rng
=
np
.
random
.
default_rng
(
rng
)
input_vals
=
[
xr_random_like
(
inp
,
rng
)
for
inp
in
inputs
]
# Create vectorized inputs
batch_inputs
=
[]
batch_input_vals
=
[]
for
i
,
(
inp
,
val
)
in
enumerate
(
zip
(
inputs
,
input_vals
)):
new_val
=
val
.
expand_dims
({
f
"batch_{i}"
:
2
**
(
i
+
1
)})
new_inp
=
as_xtensor
(
new_val
)
.
type
(
f
"batch_{inp.name or f'input{i}'}"
)
batch_inputs
.
append
(
new_inp
)
batch_input_vals
.
append
(
new_val
)
# Create vectorized function
new_outputs
=
vectorize_graph
(
outputs
,
dict
(
zip
(
inputs
,
batch_inputs
)))
vec_fn
=
xr_function
(
batch_inputs
,
new_outputs
)
vec_res
=
vec_fn
(
*
batch_input_vals
)
# xarray.apply_ufunc with vectorize=True loops over non-core dims
input_core_dims
=
[
i
.
dims
for
i
in
inputs
]
output_core_dims
=
[
o
.
dims
for
o
in
outputs
]
expected_res
=
xr
.
apply_ufunc
(
core_fn
,
*
batch_input_vals
,
input_core_dims
=
input_core_dims
,
output_core_dims
=
output_core_dims
,
exclude_dims
=
set
(
chain
.
from_iterable
((
*
input_core_dims
,
*
output_core_dims
))),
vectorize
=
True
,
)
if
not
isinstance
(
expected_res
,
list
|
tuple
):
expected_res
=
(
expected_res
,)
for
v_r
,
e_r
in
zip
(
vec_res
,
expected_res
):
xr_assert_allclose
(
v_r
,
e_r
.
transpose
(
*
v_r
.
dims
))
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论