Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
20ff202e
提交
20ff202e
authored
6月 12, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 23, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Define all batched dot operations as matmul
A new rewrite is added that converts unpaired batched row/column matvec or vec products into equivalent matmul products.
上级
e265debc
显示空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
279 行增加
和
102 行删除
+279
-102
math.py
pytensor/tensor/math.py
+16
-32
blas.py
pytensor/tensor/rewriting/blas.py
+2
-2
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+2
-0
linalg.py
pytensor/tensor/rewriting/linalg.py
+2
-2
math.py
pytensor/tensor/rewriting/math.py
+139
-45
test_blas.py
tests/tensor/rewriting/test_blas.py
+28
-13
test_math.py
tests/tensor/rewriting/test_math.py
+85
-2
test_math.py
tests/tensor/test_math.py
+5
-6
没有找到文件。
pytensor/tensor/math.py
浏览文件 @
20ff202e
...
...
@@ -3921,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False):
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
# Predefine all batched variations of Dot
_inner_prod
=
Blockwise
(
_dot
,
signature
=
"(n),(n)->()"
,
)
_matrix_vec_prod
=
Blockwise
(
_dot
,
signature
=
"(m,k),(k)->(m)"
,
)
_vec_matrix_prod
=
Blockwise
(
_dot
,
signature
=
"(k),(k,n)->(n)"
,
)
_matrix_matrix_matmul
=
Blockwise
(
_matmul
=
Blockwise
(
_dot
,
signature
=
"(m,k),(k,n)->(m,n)"
,
gufunc_spec
=
(
"numpy.matmul"
,
2
,
1
),
...
...
@@ -3993,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if
x1
.
type
.
ndim
==
1
and
x2
.
type
.
ndim
==
1
:
out
=
_dot
(
x1
,
x2
)
elif
x1
.
type
.
ndim
==
1
:
out
=
_matrix_matrix_matmul
(
x1
[
None
],
x2
)
.
squeeze
(
-
2
)
out
=
vecmat
(
x1
,
x
2
)
elif
x2
.
type
.
ndim
==
1
:
out
=
_matrix_matrix_matmul
(
x1
,
x2
[:,
None
])
.
squeeze
(
-
1
)
out
=
matvec
(
x1
,
x2
)
else
:
out
=
_mat
rix_matrix_mat
mul
(
x1
,
x2
)
out
=
_matmul
(
x1
,
x2
)
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
...
...
@@ -4047,7 +4031,7 @@ def vecdot(
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch)
"""
out
=
_inner_prod
(
x1
,
x2
)
out
=
matmul
(
x1
[
...
,
None
,
:],
x2
[
...
,
:,
None
])
.
squeeze
((
-
2
,
-
1
)
)
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
...
...
@@ -4096,7 +4080,7 @@ def matvec(
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> # Equivalent to numpy.matvec(batched_A, batched_v)
"""
out
=
_matrix_vec_prod
(
x1
,
x2
)
out
=
matmul
(
x1
,
x2
[
...
,
None
])
.
squeeze
(
-
1
)
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
...
...
@@ -4134,18 +4118,18 @@ def vecmat(
--------
>>> import pytensor.tensor as pt
>>> # Vector-matrix product
>>> v = pt.vector("v", shape=(3,))
# shape (3,)
>>> A = pt.matrix("A", shape=(3, 4))
# shape (3, 4)
>>> v = pt.vector("v", shape=(3,))
>>> A = pt.matrix("A", shape=(3, 4))
>>> result = pt.vecmat(v, A) # shape (4,)
>>> # Equivalent to numpy.vecmat(v, A)
>>>
>>> # Batched vector-matrix product
>>> batched_v = pt.matrix("v", shape=(2, 3))
# shape (2, 3)
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
# shape (2, 3, 4)
>>> batched_v = pt.matrix("v", shape=(2, 3))
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A)
"""
out
=
_vec_matrix_prod
(
x1
,
x2
)
out
=
matmul
(
x2
.
mT
,
x1
[
...
,
None
])
.
squeeze
(
-
1
)
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
...
...
@@ -4160,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
old_y_ndim
=
old_y
.
type
.
ndim
match
(
old_x_ndim
,
old_y_ndim
):
case
(
1
,
1
):
batch_
op
=
_inner_prod
batch_
fn
=
vecdot
case
(
2
,
1
):
batch_
op
=
_matrix_vec_prod
batch_
fn
=
matvec
case
(
1
,
2
):
batch_
op
=
_vec_matrix_prod
batch_
fn
=
vecmat
case
(
2
,
2
):
batch_
op
=
_matrix_matrix_
matmul
batch_
fn
=
matmul
case
_
:
raise
ValueError
(
f
"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return
batch_
op
(
batched_x
,
batched_y
)
.
owner
return
batch_
fn
(
batched_x
,
batched_y
)
.
owner
def
nan_to_num
(
x
,
nan
=
0.0
,
posinf
=
None
,
neginf
=
None
):
...
...
pytensor/tensor/rewriting/blas.py
浏览文件 @
20ff202e
...
...
@@ -98,7 +98,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
(
Dot
,
_mat
rix_matrix_mat
mul
,
_matmul
,
add
,
mul
,
neg
,
...
...
@@ -908,7 +908,7 @@ blas_optdb.register(
@register_specialize
@node_rewriter
([
_mat
rix_matrix_mat
mul
])
@node_rewriter
([
_matmul
])
def
specialize_matmul_to_batched_dot
(
fgraph
,
node
):
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
20ff202e
...
...
@@ -39,6 +39,7 @@ from pytensor.tensor.rewriting.basic import (
broadcasted_by
,
register_canonicalize
,
register_specialize
,
register_stabilize
,
)
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
...
@@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter
([
DimShuffle
])
def
local_dimshuffle_lift
(
fgraph
,
node
):
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
20ff202e
...
...
@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.math
import
Dot
,
Prod
,
_mat
rix_matrix_mat
mul
,
log
,
outer
,
prod
from
pytensor.tensor.math
import
Dot
,
Prod
,
_matmul
,
log
,
outer
,
prod
from
pytensor.tensor.nlinalg
import
(
SVD
,
KroneckerProduct
,
...
...
@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
# This rewrite only applies to matrix Dot
and
A
.
owner
.
inputs
[
0
]
.
type
.
ndim
==
2
)
or
(
A
.
owner
.
op
==
_mat
rix_matrix_mat
mul
)
or
(
A
.
owner
.
op
==
_matmul
)
)
):
return
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
20ff202e
...
...
@@ -28,6 +28,7 @@ from pytensor.tensor.basic import (
as_tensor_variable
,
cast
,
constant
,
expand_dims
,
get_underlying_scalar_constant_value
,
moveaxis
,
ones_like
,
...
...
@@ -35,7 +36,6 @@ from pytensor.tensor.basic import (
switch
,
zeros_like
,
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.extra_ops
import
broadcast_arrays
...
...
@@ -45,10 +45,7 @@ from pytensor.tensor.math import (
Sum
,
_conj
,
_dot
,
_inner_prod
,
_matrix_matrix_matmul
,
_matrix_vec_prod
,
_vec_matrix_prod
,
_matmul
,
add
,
digamma
,
dot
,
...
...
@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node):
if
not
(
is_matrix_transpose
(
node
.
outputs
[
0
])
and
node
.
inputs
[
0
]
.
owner
and
((
dot_op
:
=
node
.
inputs
[
0
]
.
owner
.
op
)
in
(
_dot
,
_mat
rix_matrix_mat
mul
))
and
((
dot_op
:
=
node
.
inputs
[
0
]
.
owner
.
op
)
in
(
_dot
,
_matmul
))
):
return
False
...
...
@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node):
return
ret
@register_stabilize
@register_specialize
@node_rewriter
(
tracks
=
[
Blockwise
])
def
local_batched_matmul_to_core_matmul
(
fgraph
,
node
):
"""Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
def
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
:
bool
):
"""Move batch dimensions of matmul operands to core matmul
Example, if x has batch dimensions
, but y not:
Example, if x has batch dimensions
that don't overlap with batch dimensions of y
x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
It also works when y has batch dimensions, but x not.
"""
It also works for batch dimensions of y that don't overlap with batch dimensions of x
# Check whether we have a matmul operation in this node
if
not
(
isinstance
(
node
.
op
.
core_op
,
Dot
)
and
len
(
node
.
op
.
inputs_sig
[
0
])
==
2
and
len
(
node
.
op
.
inputs_sig
[
1
])
==
2
):
return
None
The rewrite only uses reshape when mixing dimensions, and it can refuse to apply if `allow_reshape=False`
"""
x
,
y
=
node
.
inputs
batch_ndim
=
node
.
op
.
batch_ndim
(
node
)
# Check if x has batch dimensions, but y not (or only broadcastable dimensions)
if
any
(
not
b_dim
for
b_dim
in
x
.
type
.
broadcastable
[:
-
2
])
and
all
(
y
.
type
.
broadcastable
[:
-
2
]
):
x_stacked
=
x
.
reshape
((
-
1
,
x
.
shape
[
-
1
]))
out_stacked
=
x_stacked
@
y
.
squeeze
(
tuple
(
range
(
batch_ndim
)))
out
=
out_stacked
.
reshape
((
*
x
.
shape
[:
-
1
],
y
.
shape
[
-
1
]))
return
[
out
]
x_axis_to_merge
=
[
i
for
i
,
(
bcast_x
,
bcast_y
)
in
enumerate
(
zip
(
x
.
type
.
broadcastable
[:
-
2
],
y
.
type
.
broadcastable
[:
-
2
])
)
if
bcast_y
and
not
bcast_x
]
# Otherwise, check if y has batch dimension, but x not
elif
any
(
not
b_dim
for
b_dim
in
y
.
type
.
broadcastable
[:
-
2
])
and
all
(
x
.
type
.
broadcastable
[:
-
2
]
):
# For the y batch case we need to first move the batch axes and then reshape
# y.shape == (*b, k, n)
y_tr
=
moveaxis
(
y
,
-
2
,
0
)
# (k, *b, n)
y_stacked
=
y_tr
.
reshape
((
y
.
shape
[
-
2
],
-
1
))
# (k, *b * n)
out_stacked
=
x
.
squeeze
(
tuple
(
range
(
batch_ndim
)))
@
y_stacked
# (m, *b * n)
out_stacked_tr
=
out_stacked
.
reshape
(
(
x
.
shape
[
-
2
],
*
y
.
shape
[:
-
2
],
y
.
shape
[
-
1
])
)
# (m, *b, n)
out
=
moveaxis
(
out_stacked_tr
,
0
,
-
2
)
# (*b, m, n)
return
[
out
]
y_axis_to_merge
=
[
i
for
i
,
(
bcast_x
,
bcast_y
)
in
enumerate
(
zip
(
x
.
type
.
broadcastable
[:
-
2
],
y
.
type
.
broadcastable
[:
-
2
])
)
if
bcast_x
and
not
bcast_y
]
if
not
(
x_axis_to_merge
or
y_axis_to_merge
):
return
None
# Both x and y have batch dimensions, nothing to do here
x_shape
=
tuple
(
x
.
shape
)
y_shape
=
tuple
(
y
.
shape
)
x_is_row
=
x
.
type
.
broadcastable
[
-
2
]
y_is_col
=
y
.
type
.
broadcastable
[
-
1
]
n_x_axis_to_merge
=
len
(
x_axis_to_merge
)
n_y_axis_to_merge
=
len
(
y_axis_to_merge
)
n_axis_to_merge
=
n_x_axis_to_merge
+
n_y_axis_to_merge
x_stacked
,
y_stacked
=
x
,
y
dims_were_merged
=
False
if
n_x_axis_to_merge
:
# ravel batch dimensions of x on the core (m) axis
x_axis_destination
=
tuple
(
range
(
-
n_x_axis_to_merge
-
2
,
-
2
))
x_stacked
=
moveaxis
(
x
,
x_axis_to_merge
,
x_axis_destination
)
if
x_is_row
:
# x was a row matrix, squeeze it to clean up the graph
x_stacked
=
x_stacked
.
squeeze
(
-
2
)
if
n_x_axis_to_merge
>
1
or
not
x_is_row
:
if
not
allow_reshape
:
# TODO: We could allow the y rewrite to go on
# Or just move one axis (the largest) if x is row
return
None
# Ravel moved batch dims together with (m) if needed
x_stacked_shape
=
tuple
(
x_stacked
.
shape
)
x_stacked
=
x_stacked
.
reshape
(
(
*
x_stacked_shape
[:
batch_ndim
-
n_x_axis_to_merge
],
-
1
,
x_shape
[
-
1
])
)
dims_were_merged
=
True
if
n_y_axis_to_merge
:
# ravel batch dimensions of y on the core (n) axis
y_axis_destination
=
tuple
(
range
(
-
n_y_axis_to_merge
-
1
,
-
1
))
y_stacked
=
moveaxis
(
y
,
y_axis_to_merge
,
y_axis_destination
)
if
y_is_col
:
# y was a column matrix, squeeze it to clean up the graph
y_stacked
=
y_stacked
.
squeeze
(
-
1
)
if
n_y_axis_to_merge
>
1
or
not
y_is_col
:
if
not
allow_reshape
:
# TODO: We could allow the x rewrite to go on
# Or just move one axis (the largest) if y is col
return
None
# Ravel moved batch dims together with (n) if needed
y_stacked_shape
=
tuple
(
y_stacked
.
shape
)
y_stacked
=
y_stacked
.
reshape
(
(
*
y_stacked_shape
[:
batch_ndim
-
n_y_axis_to_merge
],
y_shape
[
-
2
],
-
1
)
)
dims_were_merged
=
True
# Squeeze x_dims corresponding to merged dimensions of y
x_axis_to_squeeze
=
np
.
array
(
y_axis_to_merge
)
for
i
in
reversed
(
x_axis_to_merge
):
# The corresponding dimensions of y may have shifted when we merged dimensions of x
x_axis_to_squeeze
[
x_axis_to_squeeze
>
i
]
-=
1
x_stacked
=
x_stacked
.
squeeze
(
tuple
(
x_axis_to_squeeze
))
# Same for y
y_axis_to_squeeze
=
np
.
array
(
x_axis_to_merge
)
for
i
in
reversed
(
y_axis_to_merge
):
y_axis_to_squeeze
[
y_axis_to_squeeze
>
i
]
-=
1
y_stacked
=
y_stacked
.
squeeze
(
tuple
(
y_axis_to_squeeze
))
out_stacked
=
x_stacked
@
y_stacked
# Split back any merged dimensions
if
dims_were_merged
:
x_merged_shapes
=
[
x_shape
[
i
]
for
i
in
x_axis_to_merge
]
if
not
x_is_row
:
# Otherwise we handle that later with expand_dims, which is cleaner
x_merged_shapes
.
append
(
x_shape
[
-
2
])
y_merged_shapes
=
[
y_shape
[
i
]
for
i
in
y_axis_to_merge
]
if
not
y_is_col
:
# Otherwise we handle that later with expand_dims, which is cleaner
y_merged_shapes
.
append
(
y_shape
[
-
1
])
out_stacked_shape
=
tuple
(
out_stacked
.
shape
)
out_unstacked
=
out_stacked
.
reshape
(
(
*
out_stacked_shape
[:
batch_ndim
-
n_axis_to_merge
],
*
x_merged_shapes
,
*
y_merged_shapes
,
)
)
else
:
out_unstacked
=
out_stacked
# Add back dummy row, col axis
# We do this separately to avoid the reshape as much as we can
if
y_is_col
and
(
n_y_axis_to_merge
or
dims_were_merged
):
out_unstacked
=
expand_dims
(
out_unstacked
,
-
1
)
if
x_is_row
and
(
n_x_axis_to_merge
or
dims_were_merged
):
out_unstacked
=
expand_dims
(
out_unstacked
,
-
n_y_axis_to_merge
-
2
)
# Move batch axis back to their original location
source
=
range
(
-
n_axis_to_merge
-
2
,
0
)
destination
=
(
*
x_axis_to_merge
,
-
2
,
*
y_axis_to_merge
,
-
1
)
out
=
moveaxis
(
out_unstacked
,
source
,
destination
)
return
[
out
]
@register_canonicalize
@node_rewriter
(
tracks
=
[
_matmul
])
def
local_batched_matmul_to_core_matmul
(
fgraph
,
node
):
# Allow passing batch dimensions of matmul to core vector / column matrices
return
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
=
False
)
@register_specialize
@node_rewriter
(
tracks
=
[
_matmul
])
def
local_batched_matmul_to_core_matmul_with_reshape
(
fgraph
,
node
):
# Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
# We only apply this in specialize, because graphs with reshape are hard to work with
return
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
=
True
)
@register_canonicalize
@register_specialize
@node_rewriter
([
_
inner_prod
,
_matrix_vec_prod
,
_vec_matrix_prod
,
_matrix_matrix_
matmul
])
@node_rewriter
([
_matmul
])
def
local_blockwise_dot_to_mul
(
fgraph
,
node
):
"""Rewrite blockwise dots that correspond to multiplication without summation.
...
...
tests/tensor/rewriting/test_blas.py
浏览文件 @
20ff202e
import
numpy
as
np
import
pytest
from
pytensor
import
function
from
pytensor
import
config
,
function
from
pytensor
import
tensor
as
pt
from
pytensor.compile
import
get_default_mode
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
,
ancestors
from
pytensor.tensor
import
(
col
,
dscalar
,
...
...
@@ -21,7 +21,6 @@ from pytensor.tensor import (
vectorize
,
)
from
pytensor.tensor.blas
import
BatchedDot
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.rewriting.blas
import
(
_as_scalar
,
...
...
@@ -37,8 +36,11 @@ def XYZab():
return
matrix
(),
matrix
(),
matrix
(),
scalar
(),
scalar
()
@pytest.mark.parametrize
(
"valid_case"
,
(
True
,
False
))
def
test_specialize_matmul_to_batched_dot
(
valid_case
):
@pytest.mark.skipif
(
config
.
mode
==
"FAST_COMPILE"
,
reason
=
"Test requires specialization rewrites"
)
@pytest.mark.parametrize
(
"aligned"
,
(
True
,
False
))
def
test_specialize_matmul_to_batched_dot
(
aligned
):
signature
=
BatchedDot
.
gufunc_signature
rewrite
=
specialize_matmul_to_batched_dot
.
__name__
...
...
@@ -49,23 +51,36 @@ def test_specialize_matmul_to_batched_dot(valid_case):
return
np
.
matmul
(
x
,
y
)
x
=
tensor
(
shape
=
(
7
,
5
,
3
,
3
))
if
valid_case
:
if
aligned
:
y
=
tensor
(
shape
=
(
7
,
5
,
3
,
3
))
else
:
y
=
tensor
(
shape
=
(
5
,
3
,
3
))
out
=
vectorize
(
core_pt
,
signature
=
signature
)(
x
,
y
)
assert
(
sum
(
isinstance
(
var
.
owner
.
op
,
BatchedDot
)
for
var
in
ancestors
([
out
])
if
var
.
owner
)
==
0
)
vectorize_pt
=
function
(
[
x
,
y
],
vectorize
(
core_pt
,
signature
=
signature
)(
x
,
y
)
,
out
,
mode
=
get_default_mode
()
.
including
(
rewrite
),
)
blocwkise_node
=
any
(
isinstance
(
node
.
op
,
Blockwise
)
for
node
in
vectorize_pt
.
maker
.
fgraph
.
apply_nodes
assert
(
sum
(
isinstance
(
var
.
owner
.
op
,
BatchedDot
)
for
var
in
ancestors
(
vectorize_pt
.
maker
.
fgraph
.
outputs
)
if
var
.
owner
)
==
1
)
if
valid_case
:
assert
not
blocwkise_node
else
:
assert
blocwkise_node
x_test
=
np
.
random
.
normal
(
size
=
x
.
type
.
shape
)
.
astype
(
x
.
type
.
dtype
)
y_test
=
np
.
random
.
normal
(
size
=
y
.
type
.
shape
)
.
astype
(
y
.
type
.
dtype
)
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
20ff202e
...
...
@@ -42,6 +42,7 @@ from pytensor.tensor.math import (
Prod
,
Sum
,
_conj
,
_matmul
,
add
,
arccosh
,
arcsinh
,
...
...
@@ -4566,6 +4567,88 @@ def test_local_batched_matmul_to_core_matmul():
np
.
testing
.
assert_allclose
(
fn
(
x_test
,
y_test
),
x_test
@
y_test
)
@pytest.mark.parametrize
(
"mat_shape, vec_shape"
,
[
[(
1
,
2
,
2
),
(
5
,
2
)],
[(
5
,
2
,
2
),
(
1
,
2
)],
[(
1
,
1
,
2
,
2
),
(
7
,
5
,
2
)],
[(
7
,
5
,
2
,
2
),
(
1
,
1
,
5
,
2
)],
[(
1
,
5
,
1
,
2
,
2
),
(
7
,
5
,
7
,
2
)],
[(
7
,
5
,
7
,
2
,
2
),
(
1
,
5
,
1
,
2
)],
[(
5
,
1
,
3
,
1
,
2
,
2
),
(
1
,
7
,
3
,
7
,
2
)],
[(
1
,
7
,
3
,
7
,
2
,
2
),
(
5
,
1
,
3
,
1
,
2
)],
],
ids
=
str
,
)
@pytest.mark.parametrize
(
"func"
,
(
"matvec"
,
"vecmat"
,
"vecdot"
))
def
test_batch_matvec_to_matmul
(
func
,
mat_shape
,
vec_shape
):
def
count_matvec_nodes
(
graph
):
# Counts how many matmul nodes actually correspond to matvec or vecmat
return
len
(
[
var
for
var
in
ancestors
([
graph
])
if
(
var
.
owner
is
not
None
and
var
.
owner
.
op
==
_matmul
and
(
(
var
.
owner
.
inputs
[
0
]
.
type
.
shape
[
-
2
]
==
1
)
or
(
var
.
owner
.
inputs
[
1
]
.
type
.
shape
[
-
1
]
==
1
)
)
)
]
)
mat
=
pt
.
tensor
(
"mat"
,
shape
=
mat_shape
,
dtype
=
"float64"
)
vec
=
pt
.
tensor
(
"vec"
,
shape
=
vec_shape
,
dtype
=
"float64"
)
if
func
==
"matvec"
:
out
=
pt
.
matvec
(
mat
,
vec
)
elif
func
==
"vecmat"
:
out
=
pt
.
vecmat
(
vec
,
mat
)
elif
func
==
"vecdot"
:
out
=
pt
.
vecdot
(
mat
[
...
,
0
],
vec
)
else
:
raise
NotImplementedError
(
func
)
assert
count_matvec_nodes
(
out
)
==
1
rewritten_out
=
rewrite_graph
(
out
,
include
=
(
"canonicalize"
,
"specialize"
,
),
exclude
=
(
"local_eager_useless_unbatched_blockwise"
,
"specialize_matmul_to_batched_dot"
,
),
)
# No `matvec` in the rewritten out if one of the vectors can be treated as a matrix
expected
=
not
any
(
mat_dim
==
1
and
vec_dim
!=
1
for
vec_dim
,
mat_dim
in
zip
(
vec_shape
[:
-
1
],
mat_shape
[:
-
2
])
)
if
not
expected
and
func
==
"vecdot"
:
# In this case there are two vectors, so we may still end up with a `matvec` unless the second vec can also be treated as a matrix
expected
=
not
any
(
mat_dim
!=
1
and
vec_dim
==
1
for
vec_dim
,
mat_dim
in
zip
(
vec_shape
[:
-
1
],
mat_shape
[:
-
2
])
)
assert
count_matvec_nodes
(
rewritten_out
)
==
expected
rng
=
np
.
random
.
default_rng
(
mat_shape
+
vec_shape
)
eval_dict
=
{
mat
:
rng
.
random
(
mat
.
type
.
shape
),
vec
:
rng
.
random
(
vec
.
type
.
shape
)}
# Check that the evaluated results are correct without further rewrites
no_optimization
=
Mode
(
linker
=
"py"
,
optimizer
=
None
)
np
.
testing
.
assert_allclose
(
rewritten_out
.
eval
(
eval_dict
,
mode
=
no_optimization
),
out
.
eval
(
eval_dict
,
mode
=
no_optimization
),
)
def
test_log_kv_stabilization
():
x
=
pt
.
scalar
(
"x"
)
out
=
log
(
kv
(
4.5
,
x
))
...
...
@@ -4616,8 +4699,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
out
=
dot
(
a
,
b
)
if
batched
:
batch_a
=
tensor
(
"batch_a"
,
shape
=
(
1
,
5
,
*
a_shape
))
batch_b
=
tensor
(
"batch_b"
,
shape
=
(
7
,
1
,
*
b_shape
))
batch_a
=
tensor
(
"batch_a"
,
shape
=
(
2
,
1
,
5
,
*
a_shape
))
batch_b
=
tensor
(
"batch_b"
,
shape
=
(
2
,
7
,
1
,
*
b_shape
))
out
=
vectorize_graph
(
out
,
{
a
:
batch_a
,
b
:
batch_b
})
a
=
batch_a
b
=
batch_b
...
...
tests/tensor/test_math.py
浏览文件 @
20ff202e
...
...
@@ -2092,9 +2092,9 @@ class TestDot:
def
test_matrix_vector_ops
():
"""Test vecdot, matvec, and vecmat helper functions."""
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
()
)
rng
=
np
.
random
.
default_rng
(
2089
)
# Create test data with batch dimension (2)
atol
=
1e-7
if
config
.
floatX
==
"float32"
else
1e-15
batch_size
=
2
dim_k
=
4
# Common dimension
dim_m
=
3
# Matrix rows
...
...
@@ -2109,7 +2109,6 @@ def test_matrix_vector_ops():
mat_kn_val
=
random
(
batch_size
,
dim_k
,
dim_n
,
rng
=
rng
)
.
astype
(
config
.
floatX
)
vec_k_val
=
random
(
batch_size
,
dim_k
,
rng
=
rng
)
.
astype
(
config
.
floatX
)
# Create tensor variables with matching dtype
mat_mk
=
tensor
(
name
=
"mat_mk"
,
shape
=
(
batch_size
,
dim_m
,
dim_k
),
dtype
=
config
.
floatX
)
...
...
@@ -2130,7 +2129,7 @@ def test_matrix_vector_ops():
expected_vecdot
=
np
.
zeros
((
batch_size
,),
dtype
=
np
.
int32
)
for
i
in
range
(
batch_size
):
expected_vecdot
[
i
]
=
np
.
sum
(
vec_k_val
[
i
]
*
vec_k_val
[
i
])
np
.
testing
.
assert_allclose
(
result
,
expected_vecdot
)
np
.
testing
.
assert_allclose
(
result
,
expected_vecdot
,
atol
=
atol
)
# Test 2: matvec - matrix-vector product
matvec_out
=
matvec
(
mat_mk
,
vec_k
)
...
...
@@ -2141,7 +2140,7 @@ def test_matrix_vector_ops():
expected_matvec
=
np
.
zeros
((
batch_size
,
dim_m
),
dtype
=
config
.
floatX
)
for
i
in
range
(
batch_size
):
expected_matvec
[
i
]
=
np
.
dot
(
mat_mk_val
[
i
],
vec_k_val
[
i
])
np
.
testing
.
assert_allclose
(
result_matvec
,
expected_matvec
)
np
.
testing
.
assert_allclose
(
result_matvec
,
expected_matvec
,
atol
=
atol
)
# Test 3: vecmat - vector-matrix product
vecmat_out
=
vecmat
(
vec_k
,
mat_kn
)
...
...
@@ -2152,7 +2151,7 @@ def test_matrix_vector_ops():
expected_vecmat
=
np
.
zeros
((
batch_size
,
dim_n
),
dtype
=
config
.
floatX
)
for
i
in
range
(
batch_size
):
expected_vecmat
[
i
]
=
np
.
dot
(
vec_k_val
[
i
],
mat_kn_val
[
i
])
np
.
testing
.
assert_allclose
(
result_vecmat
,
expected_vecmat
)
np
.
testing
.
assert_allclose
(
result_vecmat
,
expected_vecmat
,
atol
=
atol
)
class
TestTensordot
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论