Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
20ff202e
提交
20ff202e
authored
6月 12, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
7月 23, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Define all batched dot operations as matmul
New rewrite is added to convert unpaired batched row/column matvec or vec products as equivalent matmul products.
上级
e265debc
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
282 行增加
和
105 行删除
+282
-105
math.py
pytensor/tensor/math.py
+16
-32
blas.py
pytensor/tensor/rewriting/blas.py
+2
-2
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+2
-0
linalg.py
pytensor/tensor/rewriting/linalg.py
+2
-2
math.py
pytensor/tensor/rewriting/math.py
+142
-48
test_blas.py
tests/tensor/rewriting/test_blas.py
+28
-13
test_math.py
tests/tensor/rewriting/test_math.py
+85
-2
test_math.py
tests/tensor/test_math.py
+5
-6
没有找到文件。
pytensor/tensor/math.py
浏览文件 @
20ff202e
...
@@ -3921,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False):
...
@@ -3921,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False):
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
return
log
(
sum
(
exp
(
x
),
axis
=
axis
,
keepdims
=
keepdims
))
# Predefine all batched variations of Dot
_matmul
=
Blockwise
(
_inner_prod
=
Blockwise
(
_dot
,
signature
=
"(n),(n)->()"
,
)
_matrix_vec_prod
=
Blockwise
(
_dot
,
signature
=
"(m,k),(k)->(m)"
,
)
_vec_matrix_prod
=
Blockwise
(
_dot
,
signature
=
"(k),(k,n)->(n)"
,
)
_matrix_matrix_matmul
=
Blockwise
(
_dot
,
_dot
,
signature
=
"(m,k),(k,n)->(m,n)"
,
signature
=
"(m,k),(k,n)->(m,n)"
,
gufunc_spec
=
(
"numpy.matmul"
,
2
,
1
),
gufunc_spec
=
(
"numpy.matmul"
,
2
,
1
),
...
@@ -3993,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
...
@@ -3993,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if
x1
.
type
.
ndim
==
1
and
x2
.
type
.
ndim
==
1
:
if
x1
.
type
.
ndim
==
1
and
x2
.
type
.
ndim
==
1
:
out
=
_dot
(
x1
,
x2
)
out
=
_dot
(
x1
,
x2
)
elif
x1
.
type
.
ndim
==
1
:
elif
x1
.
type
.
ndim
==
1
:
out
=
_matrix_matrix_matmul
(
x1
[
None
],
x2
)
.
squeeze
(
-
2
)
out
=
vecmat
(
x1
,
x
2
)
elif
x2
.
type
.
ndim
==
1
:
elif
x2
.
type
.
ndim
==
1
:
out
=
_matrix_matrix_matmul
(
x1
,
x2
[:,
None
])
.
squeeze
(
-
1
)
out
=
matvec
(
x1
,
x2
)
else
:
else
:
out
=
_mat
rix_matrix_mat
mul
(
x1
,
x2
)
out
=
_matmul
(
x1
,
x2
)
if
dtype
is
not
None
:
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
out
=
out
.
astype
(
dtype
)
...
@@ -4047,7 +4031,7 @@ def vecdot(
...
@@ -4047,7 +4031,7 @@ def vecdot(
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch)
"""
"""
out
=
_inner_prod
(
x1
,
x2
)
out
=
matmul
(
x1
[
...
,
None
,
:],
x2
[
...
,
:,
None
])
.
squeeze
((
-
2
,
-
1
)
)
if
dtype
is
not
None
:
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
out
=
out
.
astype
(
dtype
)
...
@@ -4096,7 +4080,7 @@ def matvec(
...
@@ -4096,7 +4080,7 @@ def matvec(
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> # Equivalent to numpy.matvec(batched_A, batched_v)
>>> # Equivalent to numpy.matvec(batched_A, batched_v)
"""
"""
out
=
_matrix_vec_prod
(
x1
,
x2
)
out
=
matmul
(
x1
,
x2
[
...
,
None
])
.
squeeze
(
-
1
)
if
dtype
is
not
None
:
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
out
=
out
.
astype
(
dtype
)
...
@@ -4134,18 +4118,18 @@ def vecmat(
...
@@ -4134,18 +4118,18 @@ def vecmat(
--------
--------
>>> import pytensor.tensor as pt
>>> import pytensor.tensor as pt
>>> # Vector-matrix product
>>> # Vector-matrix product
>>> v = pt.vector("v", shape=(3,))
# shape (3,)
>>> v = pt.vector("v", shape=(3,))
>>> A = pt.matrix("A", shape=(3, 4))
# shape (3, 4)
>>> A = pt.matrix("A", shape=(3, 4))
>>> result = pt.vecmat(v, A) # shape (4,)
>>> result = pt.vecmat(v, A) # shape (4,)
>>> # Equivalent to numpy.vecmat(v, A)
>>> # Equivalent to numpy.vecmat(v, A)
>>>
>>>
>>> # Batched vector-matrix product
>>> # Batched vector-matrix product
>>> batched_v = pt.matrix("v", shape=(2, 3))
# shape (2, 3)
>>> batched_v = pt.matrix("v", shape=(2, 3))
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
# shape (2, 3, 4)
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A)
"""
"""
out
=
_vec_matrix_prod
(
x1
,
x2
)
out
=
matmul
(
x2
.
mT
,
x1
[
...
,
None
])
.
squeeze
(
-
1
)
if
dtype
is
not
None
:
if
dtype
is
not
None
:
out
=
out
.
astype
(
dtype
)
out
=
out
.
astype
(
dtype
)
...
@@ -4160,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
...
@@ -4160,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
old_y_ndim
=
old_y
.
type
.
ndim
old_y_ndim
=
old_y
.
type
.
ndim
match
(
old_x_ndim
,
old_y_ndim
):
match
(
old_x_ndim
,
old_y_ndim
):
case
(
1
,
1
):
case
(
1
,
1
):
batch_
op
=
_inner_prod
batch_
fn
=
vecdot
case
(
2
,
1
):
case
(
2
,
1
):
batch_
op
=
_matrix_vec_prod
batch_
fn
=
matvec
case
(
1
,
2
):
case
(
1
,
2
):
batch_
op
=
_vec_matrix_prod
batch_
fn
=
vecmat
case
(
2
,
2
):
case
(
2
,
2
):
batch_
op
=
_matrix_matrix_
matmul
batch_
fn
=
matmul
case
_
:
case
_
:
raise
ValueError
(
raise
ValueError
(
f
"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
f
"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
)
return
batch_
op
(
batched_x
,
batched_y
)
.
owner
return
batch_
fn
(
batched_x
,
batched_y
)
.
owner
def
nan_to_num
(
x
,
nan
=
0.0
,
posinf
=
None
,
neginf
=
None
):
def
nan_to_num
(
x
,
nan
=
0.0
,
posinf
=
None
,
neginf
=
None
):
...
...
pytensor/tensor/rewriting/blas.py
浏览文件 @
20ff202e
...
@@ -98,7 +98,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
...
@@ -98,7 +98,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
(
from
pytensor.tensor.math
import
(
Dot
,
Dot
,
_mat
rix_matrix_mat
mul
,
_matmul
,
add
,
add
,
mul
,
mul
,
neg
,
neg
,
...
@@ -908,7 +908,7 @@ blas_optdb.register(
...
@@ -908,7 +908,7 @@ blas_optdb.register(
@register_specialize
@register_specialize
@node_rewriter
([
_mat
rix_matrix_mat
mul
])
@node_rewriter
([
_matmul
])
def
specialize_matmul_to_batched_dot
(
fgraph
,
node
):
def
specialize_matmul_to_batched_dot
(
fgraph
,
node
):
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
20ff202e
...
@@ -39,6 +39,7 @@ from pytensor.tensor.rewriting.basic import (
...
@@ -39,6 +39,7 @@ from pytensor.tensor.rewriting.basic import (
broadcasted_by
,
broadcasted_by
,
register_canonicalize
,
register_canonicalize
,
register_specialize
,
register_specialize
,
register_stabilize
,
)
)
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
from
pytensor.tensor.variable
import
TensorConstant
,
TensorVariable
...
@@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):
...
@@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):
@register_canonicalize
@register_canonicalize
@register_stabilize
@register_specialize
@register_specialize
@node_rewriter
([
DimShuffle
])
@node_rewriter
([
DimShuffle
])
def
local_dimshuffle_lift
(
fgraph
,
node
):
def
local_dimshuffle_lift
(
fgraph
,
node
):
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
20ff202e
...
@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
...
@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.blas
import
Dot22
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
DimShuffle
,
Elemwise
from
pytensor.tensor.math
import
Dot
,
Prod
,
_mat
rix_matrix_mat
mul
,
log
,
outer
,
prod
from
pytensor.tensor.math
import
Dot
,
Prod
,
_matmul
,
log
,
outer
,
prod
from
pytensor.tensor.nlinalg
import
(
from
pytensor.tensor.nlinalg
import
(
SVD
,
SVD
,
KroneckerProduct
,
KroneckerProduct
,
...
@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
...
@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
# This rewrite only applies to matrix Dot
# This rewrite only applies to matrix Dot
and
A
.
owner
.
inputs
[
0
]
.
type
.
ndim
==
2
and
A
.
owner
.
inputs
[
0
]
.
type
.
ndim
==
2
)
)
or
(
A
.
owner
.
op
==
_mat
rix_matrix_mat
mul
)
or
(
A
.
owner
.
op
==
_matmul
)
)
)
):
):
return
return
...
...
pytensor/tensor/rewriting/math.py
浏览文件 @
20ff202e
...
@@ -28,6 +28,7 @@ from pytensor.tensor.basic import (
...
@@ -28,6 +28,7 @@ from pytensor.tensor.basic import (
as_tensor_variable
,
as_tensor_variable
,
cast
,
cast
,
constant
,
constant
,
expand_dims
,
get_underlying_scalar_constant_value
,
get_underlying_scalar_constant_value
,
moveaxis
,
moveaxis
,
ones_like
,
ones_like
,
...
@@ -35,7 +36,6 @@ from pytensor.tensor.basic import (
...
@@ -35,7 +36,6 @@ from pytensor.tensor.basic import (
switch
,
switch
,
zeros_like
,
zeros_like
,
)
)
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.extra_ops
import
broadcast_arrays
from
pytensor.tensor.extra_ops
import
broadcast_arrays
...
@@ -45,10 +45,7 @@ from pytensor.tensor.math import (
...
@@ -45,10 +45,7 @@ from pytensor.tensor.math import (
Sum
,
Sum
,
_conj
,
_conj
,
_dot
,
_dot
,
_inner_prod
,
_matmul
,
_matrix_matrix_matmul
,
_matrix_vec_prod
,
_vec_matrix_prod
,
add
,
add
,
digamma
,
digamma
,
dot
,
dot
,
...
@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node):
...
@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node):
if
not
(
if
not
(
is_matrix_transpose
(
node
.
outputs
[
0
])
is_matrix_transpose
(
node
.
outputs
[
0
])
and
node
.
inputs
[
0
]
.
owner
and
node
.
inputs
[
0
]
.
owner
and
((
dot_op
:
=
node
.
inputs
[
0
]
.
owner
.
op
)
in
(
_dot
,
_mat
rix_matrix_mat
mul
))
and
((
dot_op
:
=
node
.
inputs
[
0
]
.
owner
.
op
)
in
(
_dot
,
_matmul
))
):
):
return
False
return
False
...
@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node):
...
@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node):
return
ret
return
ret
@register_stabilize
def
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
:
bool
):
@register_specialize
"""Move batch dimensions of matmul operands to core matmul
@node_rewriter
(
tracks
=
[
Blockwise
])
def
local_batched_matmul_to_core_matmul
(
fgraph
,
node
):
"""Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
Example, if x has batch dimensions
, but y not:
Example, if x has batch dimensions
that don't overlap with batch dimensions of y
x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
It also works when y has batch dimensions, but x not.
It also works for batch dimensions of y that don't overlap with batch dimensions of x
"""
# Check whether we have a matmul operation in this node
The rewrite only uses reshape when mixing dimensions, and it can refuse to apply if `allow_reshape=False`
if
not
(
"""
isinstance
(
node
.
op
.
core_op
,
Dot
)
and
len
(
node
.
op
.
inputs_sig
[
0
])
==
2
and
len
(
node
.
op
.
inputs_sig
[
1
])
==
2
):
return
None
x
,
y
=
node
.
inputs
x
,
y
=
node
.
inputs
batch_ndim
=
node
.
op
.
batch_ndim
(
node
)
batch_ndim
=
node
.
op
.
batch_ndim
(
node
)
# Check if x has batch dimensions, but y not (or only broadcastable dimensions)
x_axis_to_merge
=
[
if
any
(
not
b_dim
for
b_dim
in
x
.
type
.
broadcastable
[:
-
2
])
and
all
(
i
y
.
type
.
broadcastable
[:
-
2
]
for
i
,
(
bcast_x
,
bcast_y
)
in
enumerate
(
):
zip
(
x
.
type
.
broadcastable
[:
-
2
],
y
.
type
.
broadcastable
[:
-
2
])
x_stacked
=
x
.
reshape
((
-
1
,
x
.
shape
[
-
1
]))
)
out_stacked
=
x_stacked
@
y
.
squeeze
(
tuple
(
range
(
batch_ndim
)))
if
bcast_y
and
not
bcast_x
out
=
out_stacked
.
reshape
((
*
x
.
shape
[:
-
1
],
y
.
shape
[
-
1
]))
]
return
[
out
]
y_axis_to_merge
=
[
# Otherwise, check if y has batch dimension, but x not
i
elif
any
(
not
b_dim
for
b_dim
in
y
.
type
.
broadcastable
[:
-
2
])
and
all
(
for
i
,
(
bcast_x
,
bcast_y
)
in
enumerate
(
x
.
type
.
broadcastable
[:
-
2
]
zip
(
x
.
type
.
broadcastable
[:
-
2
],
y
.
type
.
broadcastable
[:
-
2
])
):
)
# For the y batch case we need to first move the batch axes and then reshape
if
bcast_x
and
not
bcast_y
# y.shape == (*b, k, n)
]
y_tr
=
moveaxis
(
y
,
-
2
,
0
)
# (k, *b, n)
y_stacked
=
y_tr
.
reshape
((
y
.
shape
[
-
2
],
-
1
))
# (k, *b * n)
if
not
(
x_axis_to_merge
or
y_axis_to_merge
):
out_stacked
=
x
.
squeeze
(
tuple
(
range
(
batch_ndim
)))
@
y_stacked
# (m, *b * n)
return
None
out_stacked_tr
=
out_stacked
.
reshape
(
(
x
.
shape
[
-
2
],
*
y
.
shape
[:
-
2
],
y
.
shape
[
-
1
])
x_shape
=
tuple
(
x
.
shape
)
)
# (m, *b, n)
y_shape
=
tuple
(
y
.
shape
)
out
=
moveaxis
(
out_stacked_tr
,
0
,
-
2
)
# (*b, m, n)
x_is_row
=
x
.
type
.
broadcastable
[
-
2
]
return
[
out
]
y_is_col
=
y
.
type
.
broadcastable
[
-
1
]
n_x_axis_to_merge
=
len
(
x_axis_to_merge
)
# Both x and y have batch dimensions, nothing to do here
n_y_axis_to_merge
=
len
(
y_axis_to_merge
)
return
None
n_axis_to_merge
=
n_x_axis_to_merge
+
n_y_axis_to_merge
x_stacked
,
y_stacked
=
x
,
y
dims_were_merged
=
False
if
n_x_axis_to_merge
:
# ravel batch dimensions of x on the core (m) axis
x_axis_destination
=
tuple
(
range
(
-
n_x_axis_to_merge
-
2
,
-
2
))
x_stacked
=
moveaxis
(
x
,
x_axis_to_merge
,
x_axis_destination
)
if
x_is_row
:
# x was a row matrix, squeeze it to clean up the graph
x_stacked
=
x_stacked
.
squeeze
(
-
2
)
if
n_x_axis_to_merge
>
1
or
not
x_is_row
:
if
not
allow_reshape
:
# TODO: We could allow the y rewrite to go on
# Or just move one axis (the largest) if x is row
return
None
# Ravel moved batch dims together with (m) if needed
x_stacked_shape
=
tuple
(
x_stacked
.
shape
)
x_stacked
=
x_stacked
.
reshape
(
(
*
x_stacked_shape
[:
batch_ndim
-
n_x_axis_to_merge
],
-
1
,
x_shape
[
-
1
])
)
dims_were_merged
=
True
if
n_y_axis_to_merge
:
# ravel batch dimensions of y on the core (n) axis
y_axis_destination
=
tuple
(
range
(
-
n_y_axis_to_merge
-
1
,
-
1
))
y_stacked
=
moveaxis
(
y
,
y_axis_to_merge
,
y_axis_destination
)
if
y_is_col
:
# y was a column matrix, squeeze it to clean up the graph
y_stacked
=
y_stacked
.
squeeze
(
-
1
)
if
n_y_axis_to_merge
>
1
or
not
y_is_col
:
if
not
allow_reshape
:
# TODO: We could allow the x rewrite to go on
# Or just move one axis (the largest) if y is col
return
None
# Ravel moved batch dims together with (n) if needed
y_stacked_shape
=
tuple
(
y_stacked
.
shape
)
y_stacked
=
y_stacked
.
reshape
(
(
*
y_stacked_shape
[:
batch_ndim
-
n_y_axis_to_merge
],
y_shape
[
-
2
],
-
1
)
)
dims_were_merged
=
True
# Squeeze x_dims corresponding to merged dimensions of y
x_axis_to_squeeze
=
np
.
array
(
y_axis_to_merge
)
for
i
in
reversed
(
x_axis_to_merge
):
# The corresponding dimensions of y may have shifted when we merged dimensions of x
x_axis_to_squeeze
[
x_axis_to_squeeze
>
i
]
-=
1
x_stacked
=
x_stacked
.
squeeze
(
tuple
(
x_axis_to_squeeze
))
# Same for y
y_axis_to_squeeze
=
np
.
array
(
x_axis_to_merge
)
for
i
in
reversed
(
y_axis_to_merge
):
y_axis_to_squeeze
[
y_axis_to_squeeze
>
i
]
-=
1
y_stacked
=
y_stacked
.
squeeze
(
tuple
(
y_axis_to_squeeze
))
out_stacked
=
x_stacked
@
y_stacked
# Split back any merged dimensions
if
dims_were_merged
:
x_merged_shapes
=
[
x_shape
[
i
]
for
i
in
x_axis_to_merge
]
if
not
x_is_row
:
# Otherwise we handle that later with expand_dims, which is cleaner
x_merged_shapes
.
append
(
x_shape
[
-
2
])
y_merged_shapes
=
[
y_shape
[
i
]
for
i
in
y_axis_to_merge
]
if
not
y_is_col
:
# Otherwise we handle that later with expand_dims, which is cleaner
y_merged_shapes
.
append
(
y_shape
[
-
1
])
out_stacked_shape
=
tuple
(
out_stacked
.
shape
)
out_unstacked
=
out_stacked
.
reshape
(
(
*
out_stacked_shape
[:
batch_ndim
-
n_axis_to_merge
],
*
x_merged_shapes
,
*
y_merged_shapes
,
)
)
else
:
out_unstacked
=
out_stacked
# Add back dummy row, col axis
# We do this separately to avoid the reshape as much as we can
if
y_is_col
and
(
n_y_axis_to_merge
or
dims_were_merged
):
out_unstacked
=
expand_dims
(
out_unstacked
,
-
1
)
if
x_is_row
and
(
n_x_axis_to_merge
or
dims_were_merged
):
out_unstacked
=
expand_dims
(
out_unstacked
,
-
n_y_axis_to_merge
-
2
)
# Move batch axis back to their original location
source
=
range
(
-
n_axis_to_merge
-
2
,
0
)
destination
=
(
*
x_axis_to_merge
,
-
2
,
*
y_axis_to_merge
,
-
1
)
out
=
moveaxis
(
out_unstacked
,
source
,
destination
)
return
[
out
]
@register_canonicalize
@node_rewriter
(
tracks
=
[
_matmul
])
def
local_batched_matmul_to_core_matmul
(
fgraph
,
node
):
# Allow passing batch dimensions of matmul to core vector / column matrices
return
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
=
False
)
@register_specialize
@node_rewriter
(
tracks
=
[
_matmul
])
def
local_batched_matmul_to_core_matmul_with_reshape
(
fgraph
,
node
):
# Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
# We only apply this in specialize, because grahs with reshape are hard to work with
return
_batched_matmul_to_core_matmul
(
fgraph
,
node
,
allow_reshape
=
True
)
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
@node_rewriter
([
_
inner_prod
,
_matrix_vec_prod
,
_vec_matrix_prod
,
_matrix_matrix_
matmul
])
@node_rewriter
([
_matmul
])
def
local_blockwise_dot_to_mul
(
fgraph
,
node
):
def
local_blockwise_dot_to_mul
(
fgraph
,
node
):
"""Rewrite blockwise dots that correspond to multiplication without summation.
"""Rewrite blockwise dots that correspond to multiplication without summation.
...
...
tests/tensor/rewriting/test_blas.py
浏览文件 @
20ff202e
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
pytensor
import
function
from
pytensor
import
config
,
function
from
pytensor
import
tensor
as
pt
from
pytensor
import
tensor
as
pt
from
pytensor.compile
import
get_default_mode
from
pytensor.compile
import
get_default_mode
from
pytensor.graph
import
FunctionGraph
from
pytensor.graph
import
FunctionGraph
,
ancestors
from
pytensor.tensor
import
(
from
pytensor.tensor
import
(
col
,
col
,
dscalar
,
dscalar
,
...
@@ -21,7 +21,6 @@ from pytensor.tensor import (
...
@@ -21,7 +21,6 @@ from pytensor.tensor import (
vectorize
,
vectorize
,
)
)
from
pytensor.tensor.blas
import
BatchedDot
from
pytensor.tensor.blas
import
BatchedDot
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.rewriting.blas
import
(
from
pytensor.tensor.rewriting.blas
import
(
_as_scalar
,
_as_scalar
,
...
@@ -37,8 +36,11 @@ def XYZab():
...
@@ -37,8 +36,11 @@ def XYZab():
return
matrix
(),
matrix
(),
matrix
(),
scalar
(),
scalar
()
return
matrix
(),
matrix
(),
matrix
(),
scalar
(),
scalar
()
@pytest.mark.parametrize
(
"valid_case"
,
(
True
,
False
))
@pytest.mark.skipif
(
def
test_specialize_matmul_to_batched_dot
(
valid_case
):
config
.
mode
==
"FAST_COMPILE"
,
reason
=
"Test requires specialization rewrites"
)
@pytest.mark.parametrize
(
"aligned"
,
(
True
,
False
))
def
test_specialize_matmul_to_batched_dot
(
aligned
):
signature
=
BatchedDot
.
gufunc_signature
signature
=
BatchedDot
.
gufunc_signature
rewrite
=
specialize_matmul_to_batched_dot
.
__name__
rewrite
=
specialize_matmul_to_batched_dot
.
__name__
...
@@ -49,23 +51,36 @@ def test_specialize_matmul_to_batched_dot(valid_case):
...
@@ -49,23 +51,36 @@ def test_specialize_matmul_to_batched_dot(valid_case):
return
np
.
matmul
(
x
,
y
)
return
np
.
matmul
(
x
,
y
)
x
=
tensor
(
shape
=
(
7
,
5
,
3
,
3
))
x
=
tensor
(
shape
=
(
7
,
5
,
3
,
3
))
if
valid_case
:
if
aligned
:
y
=
tensor
(
shape
=
(
7
,
5
,
3
,
3
))
y
=
tensor
(
shape
=
(
7
,
5
,
3
,
3
))
else
:
else
:
y
=
tensor
(
shape
=
(
5
,
3
,
3
))
y
=
tensor
(
shape
=
(
5
,
3
,
3
))
out
=
vectorize
(
core_pt
,
signature
=
signature
)(
x
,
y
)
assert
(
sum
(
isinstance
(
var
.
owner
.
op
,
BatchedDot
)
for
var
in
ancestors
([
out
])
if
var
.
owner
)
==
0
)
vectorize_pt
=
function
(
vectorize_pt
=
function
(
[
x
,
y
],
[
x
,
y
],
vectorize
(
core_pt
,
signature
=
signature
)(
x
,
y
)
,
out
,
mode
=
get_default_mode
()
.
including
(
rewrite
),
mode
=
get_default_mode
()
.
including
(
rewrite
),
)
)
blocwkise_node
=
any
(
isinstance
(
node
.
op
,
Blockwise
)
for
node
in
vectorize_pt
.
maker
.
fgraph
.
apply_nodes
assert
(
sum
(
isinstance
(
var
.
owner
.
op
,
BatchedDot
)
for
var
in
ancestors
(
vectorize_pt
.
maker
.
fgraph
.
outputs
)
if
var
.
owner
)
==
1
)
)
if
valid_case
:
assert
not
blocwkise_node
else
:
assert
blocwkise_node
x_test
=
np
.
random
.
normal
(
size
=
x
.
type
.
shape
)
.
astype
(
x
.
type
.
dtype
)
x_test
=
np
.
random
.
normal
(
size
=
x
.
type
.
shape
)
.
astype
(
x
.
type
.
dtype
)
y_test
=
np
.
random
.
normal
(
size
=
y
.
type
.
shape
)
.
astype
(
y
.
type
.
dtype
)
y_test
=
np
.
random
.
normal
(
size
=
y
.
type
.
shape
)
.
astype
(
y
.
type
.
dtype
)
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
20ff202e
...
@@ -42,6 +42,7 @@ from pytensor.tensor.math import (
...
@@ -42,6 +42,7 @@ from pytensor.tensor.math import (
Prod
,
Prod
,
Sum
,
Sum
,
_conj
,
_conj
,
_matmul
,
add
,
add
,
arccosh
,
arccosh
,
arcsinh
,
arcsinh
,
...
@@ -4566,6 +4567,88 @@ def test_local_batched_matmul_to_core_matmul():
...
@@ -4566,6 +4567,88 @@ def test_local_batched_matmul_to_core_matmul():
np
.
testing
.
assert_allclose
(
fn
(
x_test
,
y_test
),
x_test
@
y_test
)
np
.
testing
.
assert_allclose
(
fn
(
x_test
,
y_test
),
x_test
@
y_test
)
@pytest.mark.parametrize
(
"mat_shape, vec_shape"
,
[
[(
1
,
2
,
2
),
(
5
,
2
)],
[(
5
,
2
,
2
),
(
1
,
2
)],
[(
1
,
1
,
2
,
2
),
(
7
,
5
,
2
)],
[(
7
,
5
,
2
,
2
),
(
1
,
1
,
5
,
2
)],
[(
1
,
5
,
1
,
2
,
2
),
(
7
,
5
,
7
,
2
)],
[(
7
,
5
,
7
,
2
,
2
),
(
1
,
5
,
1
,
2
)],
[(
5
,
1
,
3
,
1
,
2
,
2
),
(
1
,
7
,
3
,
7
,
2
)],
[(
1
,
7
,
3
,
7
,
2
,
2
),
(
5
,
1
,
3
,
1
,
2
)],
],
ids
=
str
,
)
@pytest.mark.parametrize
(
"func"
,
(
"matvec"
,
"vecmat"
,
"vecdot"
))
def
test_batch_matvec_to_matmul
(
func
,
mat_shape
,
vec_shape
):
def
count_matvec_nodes
(
graph
):
# Counts how many matmul nodes actually correspond to matvec or vecmat
return
len
(
[
var
for
var
in
ancestors
([
graph
])
if
(
var
.
owner
is
not
None
and
var
.
owner
.
op
==
_matmul
and
(
(
var
.
owner
.
inputs
[
0
]
.
type
.
shape
[
-
2
]
==
1
)
or
(
var
.
owner
.
inputs
[
1
]
.
type
.
shape
[
-
1
]
==
1
)
)
)
]
)
mat
=
pt
.
tensor
(
"mat"
,
shape
=
mat_shape
,
dtype
=
"float64"
)
vec
=
pt
.
tensor
(
"vec"
,
shape
=
vec_shape
,
dtype
=
"float64"
)
if
func
==
"matvec"
:
out
=
pt
.
matvec
(
mat
,
vec
)
elif
func
==
"vecmat"
:
out
=
pt
.
vecmat
(
vec
,
mat
)
elif
func
==
"vecdot"
:
out
=
pt
.
vecdot
(
mat
[
...
,
0
],
vec
)
else
:
raise
NotImplementedError
(
func
)
assert
count_matvec_nodes
(
out
)
==
1
rewritten_out
=
rewrite_graph
(
out
,
include
=
(
"canonicalize"
,
"specialize"
,
),
exclude
=
(
"local_eager_useless_unbatched_blockwise"
,
"specialize_matmul_to_batched_dot"
,
),
)
# No `matvec` in the rewritten out if one of the vector can be treated as a matrix
expected
=
not
any
(
mat_dim
==
1
and
vec_dim
!=
1
for
vec_dim
,
mat_dim
in
zip
(
vec_shape
[:
-
1
],
mat_shape
[:
-
2
])
)
if
not
expected
and
func
==
"vecdot"
:
# In this case there are two vectors, so we may still end up with a `matvec` unless the second vec can also be treated as matrix
expected
=
not
any
(
mat_dim
!=
1
and
vec_dim
==
1
for
vec_dim
,
mat_dim
in
zip
(
vec_shape
[:
-
1
],
mat_shape
[:
-
2
])
)
assert
count_matvec_nodes
(
rewritten_out
)
==
expected
rng
=
np
.
random
.
default_rng
(
mat_shape
+
vec_shape
)
eval_dict
=
{
mat
:
rng
.
random
(
mat
.
type
.
shape
),
vec
:
rng
.
random
(
vec
.
type
.
shape
)}
# Evaluate results are correct without further rewrites
no_optimization
=
Mode
(
linker
=
"py"
,
optimizer
=
None
)
np
.
testing
.
assert_allclose
(
rewritten_out
.
eval
(
eval_dict
,
mode
=
no_optimization
),
out
.
eval
(
eval_dict
,
mode
=
no_optimization
),
)
def
test_log_kv_stabilization
():
def
test_log_kv_stabilization
():
x
=
pt
.
scalar
(
"x"
)
x
=
pt
.
scalar
(
"x"
)
out
=
log
(
kv
(
4.5
,
x
))
out
=
log
(
kv
(
4.5
,
x
))
...
@@ -4616,8 +4699,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
...
@@ -4616,8 +4699,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
out
=
dot
(
a
,
b
)
out
=
dot
(
a
,
b
)
if
batched
:
if
batched
:
batch_a
=
tensor
(
"batch_a"
,
shape
=
(
1
,
5
,
*
a_shape
))
batch_a
=
tensor
(
"batch_a"
,
shape
=
(
2
,
1
,
5
,
*
a_shape
))
batch_b
=
tensor
(
"batch_b"
,
shape
=
(
7
,
1
,
*
b_shape
))
batch_b
=
tensor
(
"batch_b"
,
shape
=
(
2
,
7
,
1
,
*
b_shape
))
out
=
vectorize_graph
(
out
,
{
a
:
batch_a
,
b
:
batch_b
})
out
=
vectorize_graph
(
out
,
{
a
:
batch_a
,
b
:
batch_b
})
a
=
batch_a
a
=
batch_a
b
=
batch_b
b
=
batch_b
...
...
tests/tensor/test_math.py
浏览文件 @
20ff202e
...
@@ -2092,9 +2092,9 @@ class TestDot:
...
@@ -2092,9 +2092,9 @@ class TestDot:
def
test_matrix_vector_ops
():
def
test_matrix_vector_ops
():
"""Test vecdot, matvec, and vecmat helper functions."""
"""Test vecdot, matvec, and vecmat helper functions."""
rng
=
np
.
random
.
default_rng
(
seed
=
utt
.
fetch_seed
()
)
rng
=
np
.
random
.
default_rng
(
2089
)
# Create test data with batch dimension (2)
atol
=
1e-7
if
config
.
floatX
==
"float32"
else
1e-15
batch_size
=
2
batch_size
=
2
dim_k
=
4
# Common dimension
dim_k
=
4
# Common dimension
dim_m
=
3
# Matrix rows
dim_m
=
3
# Matrix rows
...
@@ -2109,7 +2109,6 @@ def test_matrix_vector_ops():
...
@@ -2109,7 +2109,6 @@ def test_matrix_vector_ops():
mat_kn_val
=
random
(
batch_size
,
dim_k
,
dim_n
,
rng
=
rng
)
.
astype
(
config
.
floatX
)
mat_kn_val
=
random
(
batch_size
,
dim_k
,
dim_n
,
rng
=
rng
)
.
astype
(
config
.
floatX
)
vec_k_val
=
random
(
batch_size
,
dim_k
,
rng
=
rng
)
.
astype
(
config
.
floatX
)
vec_k_val
=
random
(
batch_size
,
dim_k
,
rng
=
rng
)
.
astype
(
config
.
floatX
)
# Create tensor variables with matching dtype
mat_mk
=
tensor
(
mat_mk
=
tensor
(
name
=
"mat_mk"
,
shape
=
(
batch_size
,
dim_m
,
dim_k
),
dtype
=
config
.
floatX
name
=
"mat_mk"
,
shape
=
(
batch_size
,
dim_m
,
dim_k
),
dtype
=
config
.
floatX
)
)
...
@@ -2130,7 +2129,7 @@ def test_matrix_vector_ops():
...
@@ -2130,7 +2129,7 @@ def test_matrix_vector_ops():
expected_vecdot
=
np
.
zeros
((
batch_size
,),
dtype
=
np
.
int32
)
expected_vecdot
=
np
.
zeros
((
batch_size
,),
dtype
=
np
.
int32
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
expected_vecdot
[
i
]
=
np
.
sum
(
vec_k_val
[
i
]
*
vec_k_val
[
i
])
expected_vecdot
[
i
]
=
np
.
sum
(
vec_k_val
[
i
]
*
vec_k_val
[
i
])
np
.
testing
.
assert_allclose
(
result
,
expected_vecdot
)
np
.
testing
.
assert_allclose
(
result
,
expected_vecdot
,
atol
=
atol
)
# Test 2: matvec - matrix-vector product
# Test 2: matvec - matrix-vector product
matvec_out
=
matvec
(
mat_mk
,
vec_k
)
matvec_out
=
matvec
(
mat_mk
,
vec_k
)
...
@@ -2141,7 +2140,7 @@ def test_matrix_vector_ops():
...
@@ -2141,7 +2140,7 @@ def test_matrix_vector_ops():
expected_matvec
=
np
.
zeros
((
batch_size
,
dim_m
),
dtype
=
config
.
floatX
)
expected_matvec
=
np
.
zeros
((
batch_size
,
dim_m
),
dtype
=
config
.
floatX
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
expected_matvec
[
i
]
=
np
.
dot
(
mat_mk_val
[
i
],
vec_k_val
[
i
])
expected_matvec
[
i
]
=
np
.
dot
(
mat_mk_val
[
i
],
vec_k_val
[
i
])
np
.
testing
.
assert_allclose
(
result_matvec
,
expected_matvec
)
np
.
testing
.
assert_allclose
(
result_matvec
,
expected_matvec
,
atol
=
atol
)
# Test 3: vecmat - vector-matrix product
# Test 3: vecmat - vector-matrix product
vecmat_out
=
vecmat
(
vec_k
,
mat_kn
)
vecmat_out
=
vecmat
(
vec_k
,
mat_kn
)
...
@@ -2152,7 +2151,7 @@ def test_matrix_vector_ops():
...
@@ -2152,7 +2151,7 @@ def test_matrix_vector_ops():
expected_vecmat
=
np
.
zeros
((
batch_size
,
dim_n
),
dtype
=
config
.
floatX
)
expected_vecmat
=
np
.
zeros
((
batch_size
,
dim_n
),
dtype
=
config
.
floatX
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
expected_vecmat
[
i
]
=
np
.
dot
(
vec_k_val
[
i
],
mat_kn_val
[
i
])
expected_vecmat
[
i
]
=
np
.
dot
(
vec_k_val
[
i
],
mat_kn_val
[
i
])
np
.
testing
.
assert_allclose
(
result_vecmat
,
expected_vecmat
)
np
.
testing
.
assert_allclose
(
result_vecmat
,
expected_vecmat
,
atol
=
atol
)
class
TestTensordot
:
class
TestTensordot
:
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论