Commit efd9f491
Authored Feb 17, 2025 by ricardoV94
Committed Mar 11, 2025 by Ricardo Vieira
Deprecate BLAS batch helper functions
Parent: b3da2a4b
Showing 5 changed files with 47 additions and 195 deletions.
pytensor/tensor/blas.py             +36  -25
pytensor/tensor/math.py              +1 -128
pytensor/tensor/rewriting/blas.py    +2   -2
pytensor/tensor/utils.py             +0   -8
tests/tensor/test_blas.py            +8  -32
pytensor/tensor/blas.py
@@ -79,10 +79,14 @@ import functools
 import logging
 import os
 import shlex
 import warnings
 from pathlib import Path

 import numpy as np

+from pytensor.graph import vectorize_graph
+from pytensor.npy_2_compat import normalize_axis_tuple
+
 try:
     import numpy.__config__
@@ -100,9 +104,9 @@ from pytensor.link.c.params_type import ParamsType
 from pytensor.printing import FunctionPrinter, pprint
 from pytensor.scalar import bool as bool_t
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
-from pytensor.tensor.shape import shape_padright, specify_broadcastable
+from pytensor.tensor.math import dot, tensordot
+from pytensor.tensor.shape import specify_broadcastable
 from pytensor.tensor.type import DenseTensorType, tensor
@@ -1604,8 +1608,8 @@ class BatchedDot(COp):
         x, y = inp
         (gz,) = grads

-        xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
-        ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
+        xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1))
+        ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz)

         # If x or y contain broadcastable dimensions but only one of
         # them know that a matching dimensions is broadcastable, the
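Note: the grad rewrite above only swaps the public batched_dot for the internal _batched_dot Op; the gradient identity itself is unchanged. For Z[i] = X[i] @ Y[i] with upstream gradient G, the gradients are G[i] @ Y[i].T and X[i].T @ G[i]. A minimal numpy check of that identity (our own sketch, not part of the commit):

import numpy as np

# dimshuffle(0, 2, 1) is the batch-wise transpose:
#   dL/dX[i] = G[i] @ Y[i].T   (batched_dot(gz, y.dimshuffle(0, 2, 1)))
#   dL/dY[i] = X[i].T @ G[i]   (batched_dot(x.dimshuffle(0, 2, 1), gz))
rng = np.random.default_rng(42)
X = rng.random((4, 3, 5))
Y = rng.random((4, 5, 2))
G = rng.random((4, 3, 2))  # same shape as X @ Y

xgrad = G @ Y.transpose(0, 2, 1)

# Directional finite-difference check on L = sum((X @ Y) * G)
def loss(X_):
    return float(((X_ @ Y) * G).sum())

dX = rng.random(X.shape)
eps = 1e-6
numeric = (loss(X + eps * dX) - loss(X - eps * dX)) / (2 * eps)
assert np.isclose(numeric, float((xgrad * dX).sum()), rtol=1e-5)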
@@ -1729,31 +1733,22 @@ def batched_dot(a, b):
     dot products in terms of batched matrix-matrix dot products, so
     it may be possible to further optimize for performance.
     """
+    warnings.warn(
+        "batched_dot is deprecated. "
+        "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )
     a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)

     if a.ndim == 0:
         raise TypeError("a must have at least one (batch) axis")
     elif b.ndim == 0:
         raise TypeError("b must have at least one (batch) axis")
-    elif a.ndim == 1:
-        return shape_padright(a, (b.ndim - 1)) * b
-    elif b.ndim == 1:
-        return a * shape_padright(b, (a.ndim - 1))
-    elif a.ndim > 3 or b.ndim > 3:
-        return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
-    else:
-        # If either a or b is a batched vector, expand dims and later squeeze them
-        expanded_axis = []
-        if a.ndim == 2:
-            a = expand_dims(a, axis=1)
-            expanded_axis.append(1)
-        if b.ndim == 2:
-            b = expand_dims(b, axis=2)
-            expanded_axis.append(2)
-        out = _batched_dot(a, b)
-        if expanded_axis:
-            out = out.squeeze(axis=expanded_axis)
-        return out
+
+    core_a = a[0].type()
+    core_b = b[0].type()
+    core_dot = dot(core_a, core_b)
+    return vectorize_graph(core_dot, replace={core_a: a, core_b: b})


 def batched_tensordot(x, y, axes=2):
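Note: the new body doubles as the migration recipe the deprecation message points at: build the core (unbatched) dot graph once, then vectorize it over the batch axis. A hedged sketch of what that looks like from downstream code, with illustrative shapes and variable names (the vectorize_graph import is the one this diff itself adds):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

# Batched operands; the leading axis is the batch axis.
a = pt.tensor3("a")  # (batch, n, k)
b = pt.tensor3("b")  # (batch, k, m)

# Core graph: one unbatched matrix-matrix dot.
core_a = pt.matrix("core_a")
core_b = pt.matrix("core_b")
core_dot = pt.dot(core_a, core_b)

# Vectorize the core graph over the batch axis.
out = vectorize_graph(core_dot, replace={core_a: a, core_b: b})

f = pytensor.function([a, b], out)
rng = np.random.default_rng(0)
av = rng.random((4, 3, 5))
bv = rng.random((4, 5, 2))
assert np.allclose(f(av, bv), av @ bv)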
@@ -1791,6 +1786,22 @@ def batched_tensordot(x, y, axes=2):
     reshapes to reduce the tensor dot product to a matrix or vector
     dot product. Finally, it calls batched_dot to compute the result.
     """
-    from pytensor.tensor.math import _tensordot_as_dot
+    warnings.warn(
+        "batched_tensordot is deprecated. "
+        "Use `tensordot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
+        FutureWarning,
+    )

-    return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
+    if isinstance(axes, int):
+        core_axes = axes
+    else:
+        # Convert batched axes to core axes
+        core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)]
+        core_axes = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)]
+        core_axes = [core_axes_a, core_axes]
+
+    core_x = x[0].type()
+    core_y = y[0].type()
+    core_tensordot = tensordot(core_x, core_y, axes=core_axes)
+    return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y})
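Note: for batched_tensordot the same pattern applies, with one wrinkle: axes given relative to the batched inputs must shift down by one once the batch axis is dropped, which is what the normalize_axis_tuple conversion above does. A sketch using the axes from this commit's own test ([[1, 2], [3, 1]] on two 4d inputs); variable names are ours:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.tensor4("x")  # (batch, 10, 20, 3)
y = pt.tensor4("y")  # (batch, 20, 5, 10)
batched_axes = [[1, 2], [3, 1]]

# Drop the batch axis: every axis index shifts down by one.
core_axes = [[a - 1 for a in batched_axes[0]], [a - 1 for a in batched_axes[1]]]

core_x = pt.tensor3("core_x")
core_y = pt.tensor3("core_y")
core_td = pt.tensordot(core_x, core_y, axes=core_axes)
out = vectorize_graph(core_td, replace={core_x: x, core_y: y})

f = pytensor.function([x, y], out)
rng = np.random.default_rng(0)
xv = rng.random((8, 10, 20, 3))
yv = rng.random((8, 20, 5, 10))
expected = np.stack(
    [np.tensordot(xv[i], yv[i], axes=core_axes) for i in range(len(xv))]
)
assert np.allclose(f(xv, yv), expected)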
pytensor/tensor/math.py
@@ -50,7 +50,7 @@ from pytensor.tensor.type import (
     tensor,
     uint_dtypes,
 )
-from pytensor.tensor.utils import as_list, normalize_reduce_axis
+from pytensor.tensor.utils import normalize_reduce_axis
 from pytensor.tensor.variable import (
     TensorVariable,
     _tensor_py_operators,
@@ -3208,133 +3208,6 @@ def dense_dot(a, b):
         return _dot(a, b)


-def _tensordot_as_dot(a, b, axes, dot, batched):
-    """
-    Reduces a tensor dot product to a matrix or vector dot product. Based
-    on code from Tijmen Tieleman's gnumpy
-    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).
-
-    Please see the documentation of tensordot for the meaning of the a, b
-    and axes arguments.
-
-    :param dot: a function that accepts two symbolic variables and computes
-                the appropriate dot product (e.g. dot, batched_dot)
-    :type dot: function
-
-    :param batched: whether to treat the first axis of a and b as a batch
-                    axis. If so, this axis will be preserved in the output,
-                    allowing this function to be used also for batched
-                    tensor dot products.
-    :type batched: boolean
-
-    :returns: a tensor with shape equal to the concatenation of a's shape
-              (less any dimensions that were summed over) and b's shape
-              (less the first dimension and any dimensions that were summed
-              over).
-    :rtype: symbolic tensor
-    """
-    a, b = as_tensor_variable(a), as_tensor_variable(b)
-
-    if not np.isscalar(axes) and len(axes) != 2:
-        raise ValueError(
-            "Axes should be an integer or a "
-            f"list/tuple of len 2 ({axes} was provided)"
-        )
-
-    # if 'axes' is a number of axes to multiply and sum over (trailing axes
-    # of a, leading axes of b), we can just reshape and use dot.
-    elif np.isscalar(axes):
-        axes = int(axes)
-
-        for operand_name, operand in (("a", a), ("b", b)):
-            if axes > operand.ndim:
-                raise ValueError(
-                    f"axes can not be larger than the dimension of {operand_name} "
-                    f"({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-            if batched and axes == operand.ndim:
-                raise ValueError(
-                    "axes to sum over must not include the batch axis "
-                    f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
-                )
-
-        batch_axes = 1 if batched else 0
-        a_outaxes = slice(0, a.ndim - axes)
-        b_outaxes = slice(batch_axes + axes, b.ndim)
-        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
-        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
-        outndim = len(outbcast)
-
-        a_shape = [1] * 2
-        b_shape = [1] * 2
-
-        # compute total size of summed axes
-        for i in range(0, axes):
-            a_shape[1] *= a.shape[-(i + 1)]
-            b_shape[0] *= b.shape[batch_axes + i]
-        # compute total size of other axes
-        for i in range(0, a.ndim - axes - batch_axes):
-            a_shape[0] *= a.shape[batch_axes + i]
-        for i in range(0, b.ndim - axes - batch_axes):
-            b_shape[1] *= b.shape[-(i + 1)]
-
-        if batched:
-            a_shape.insert(0, a.shape[0])
-            b_shape.insert(0, b.shape[0])
-
-        a_reshaped = a.reshape(a_shape)
-        b_reshaped = b.reshape(b_shape)
-
-        out_reshaped = dot(a_reshaped, b_reshaped)
-        out = out_reshaped.reshape(outshape, ndim=outndim)
-        # Make sure the broadcastable pattern of the result is correct,
-        # since some shape information can be lost in the reshapes.
-        if out.type.broadcastable != outbcast:
-            out = specify_broadcastable(
-                out, *(ax for (ax, b) in enumerate(outbcast) if b)
-            )
-        return out
-
-    # if 'axes' is a list, transpose a and b such that the summed axes of a
-    # are last and the summed axes of b are first.
-    else:
-        axes = [as_list(axes_) for axes_ in axes]
-
-        if len(axes[0]) != len(axes[1]):
-            raise ValueError("Axes elements must have the same length.")
-
-        for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
-            if len(axes[i]) > operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] should be array_like with length less than "
-                    f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[0])={len(axes[i])})."
-                )
-            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
-                raise ValueError(
-                    f"axes[{i}] contains dimensions greater than or equal "
-                    f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[0])={np.max(np.array(axes[i]))})."
-                )
-            if batched and 0 in axes[i]:
-                raise ValueError(
-                    "axes to sum over must not contain the batch axis "
-                    f"(axes[{i}]={axes[i]})"
-                )
-
-        batch_axes = [0] if batched else []
-        other_axes = [
-            [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
-            for i, operand in enumerate((a, b))
-        ]
-
-        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
-        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])
-
-        # now that a and b are in the right order, recur with integer axes
-        return _tensordot_as_dot(
-            a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
-        )
-
-
 def tensordot(
     a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
 ) -> TensorVariable:
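Note: for reference, the trick the removed helper implemented, collapsing the contracted axes into a single dimension so a tensordot becomes one ordinary matrix dot, is easy to state in numpy. A minimal sketch of the scalar-axes case (our own illustration, not the deleted code):

import numpy as np

def tensordot_as_dot(a, b, axes: int):
    # Collapse the last `axes` dims of `a` and the first `axes` dims of `b`
    # into one contracted dimension, then use a single matrix dot.
    a_kept = a.shape[: a.ndim - axes]
    b_kept = b.shape[axes:]
    k = int(np.prod(a.shape[a.ndim - axes :], dtype=int))
    out = a.reshape(-1, k) @ b.reshape(k, -1)
    return out.reshape(a_kept + b_kept)

rng = np.random.default_rng(0)
a = rng.random((2, 3, 4, 5))
b = rng.random((4, 5, 6))
assert np.allclose(tensordot_as_dot(a, b, 2), np.tensordot(a, b, axes=2))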
pytensor/tensor/rewriting/blas.py
@@ -84,9 +84,9 @@ from pytensor.graph.utils import InconsistencyError
 from pytensor.tensor import basic as ptb
 from pytensor.tensor.blas import (
     Dot22,
+    _batched_dot,
     _dot22,
     _dot22scalar,
-    batched_dot,
     gemm_inplace,
     gemm_no_inplace,
     gemv_inplace,
@@ -928,7 +928,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
         x = x.reshape((-1, x_shape[-2], x_shape[-1]))
         y = y.reshape((-1, y_shape[-2], y_shape[-1]))

-    new_out = batched_dot(x, y)
+    new_out = _batched_dot(x, y)

     if len(x_shape) > 3:
         # And then unravel it
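Note: this hunk sits inside the matmul specialization rewrite: inputs with more than one batch dimension are flattened to a single leading batch axis, multiplied with one 3d _batched_dot, and then unraveled. A numpy sketch of that shape bookkeeping (illustrative shapes; @ stands in for the BatchedDot Op):

import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 3, 4, 5))  # two leading batch dims: (2, 3)
y = rng.random((2, 3, 5, 6))

# Ravel all batch dims into one, so a single 3d batched dot suffices...
x3 = x.reshape((-1, x.shape[-2], x.shape[-1]))  # (6, 4, 5)
y3 = y.reshape((-1, y.shape[-2], y.shape[-1]))  # (6, 5, 6)
out3 = x3 @ y3                                  # stand-in for _batched_dot
# ...and then unravel the result back to the original batch dims.
out = out3.reshape(x.shape[:-2] + (x.shape[-2], y.shape[-1]))

assert np.allclose(out, x @ y)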
pytensor/tensor/utils.py
@@ -107,14 +107,6 @@ def shape_of_variables(
     return l


-def as_list(x):
-    """Convert x to a list if it is an iterable; otherwise, wrap it in a list."""
-    try:
-        return list(x)
-    except TypeError:
-        return [x]
-
-
 def import_func_from_string(func_string: str):
     # -> Optional[Callable]:
     func = getattr(np, func_string, None)
     if func is not None:
tests/tensor/test_blas.py
@@ -27,6 +27,7 @@ from pytensor.tensor.blas import (
     Gemm,
     Gemv,
     Ger,
+    _batched_dot,
     _dot22,
     _dot22scalar,
     batched_dot,
@@ -2446,7 +2447,7 @@ class TestInferShape(unittest_tools.InferShapeTester):
 rng = np.random.default_rng(unittest_tools.fetch_seed())

 TestBatchedDot = makeTester(
     name="BatchedDotTester",
-    op=batched_dot,
+    op=_batched_dot,
     expected=(
         lambda xs, ys: np.asarray(
             [
@@ -2460,34 +2461,10 @@ TestBatchedDot = makeTester(
     grad=dict(
         correct1=(random(3, 5, 7, rng=rng), random(3, 7, 5, rng=rng)),
         correct2=(random(3, 5, 7, rng=rng), random(3, 7, 9, rng=rng)),
-        correct3=(random(3, 5, 7, rng=rng), random(3, 7, rng=rng)),
-        correct4=(random(3, 5, rng=rng), random(3, 5, 7, rng=rng)),
-        correct5=(random(3, rng=rng), random(3, 5, 7, rng=rng)),
-        correct6=(random(3, 5, rng=rng), random(3, rng=rng)),
-        correct7=(random(3, 5, rng=rng), random(3, 5, rng=rng)),
-        correct8=(random(3, rng=rng), random(3, rng=rng)),
-        correct9=(random(3, 5, 7, 11, rng=rng), random(3, rng=rng)),
-        correct10=(random(3, 2, 6, 5, rng=rng), random(3, 5, rng=rng)),
-        correct11=(random(3, 2, 6, 5, rng=rng), random(3, 5, 7, rng=rng)),
-        correct12=(random(3, 2, 6, 5, rng=rng), random(3, 7, 5, 8, rng=rng)),
-        mixed1=(random(3, 5, rng=rng).astype("float32"), random(3, 5, 7, rng=rng)),
-        mixed2=(random(3, 5, rng=rng).astype("float64"), random(3, 5, 7, rng=rng)),
     ),
     good=dict(
         correct1=(random(3, 5, 7, rng=rng), random(3, 7, 5, rng=rng)),
         correct2=(random(3, 5, 7, rng=rng), random(3, 7, 9, rng=rng)),
-        correct3=(random(3, 5, 7, rng=rng), random(3, 7, rng=rng)),
-        correct4=(random(3, 5, rng=rng), random(3, 5, 7, rng=rng)),
-        correct5=(random(3, rng=rng), random(3, 5, 7, rng=rng)),
-        correct6=(random(3, 5, rng=rng), random(3, rng=rng)),
-        correct7=(random(3, 5, rng=rng), random(3, 5, rng=rng)),
-        correct8=(random(3, rng=rng), random(3, rng=rng)),
-        correct9=(random(3, 5, 7, 11, rng=rng), random(3, rng=rng)),
-        correct10=(random(3, 7, 11, 5, rng=rng), random(3, 5, rng=rng)),
-        correct11=(random(3, 7, 11, 5, rng=rng), random(3, 5, 13, rng=rng)),
-        correct12=(random(3, 7, 11, 5, rng=rng), random(3, 13, 5, 17, rng=rng)),
-        mixed1=(random(3, 5, rng=rng).astype("float32"), random(3, 5, 7, rng=rng)),
-        mixed2=(random(3, 5, rng=rng).astype("float64"), random(3, 5, 7, rng=rng)),
     ),
     bad_build=dict(
         no_batch_axis2=(random(rng=rng), random(3, 5, rng=rng)),
@@ -2496,13 +2473,8 @@ TestBatchedDot = makeTester(
     bad_runtime=dict(
         batch_dim_mismatch1=(random(2, 5, 7, rng=rng), random(3, 7, 9, rng=rng)),
         batch_dim_mismatch2=(random(3, 5, 7, rng=rng), random(2, 7, 9, rng=rng)),
-        batch_dim_mismatch3=(random(3, rng=rng), random(5, rng=rng)),
         bad_dim1=(random(3, 5, 7, rng=rng), random(3, 5, 7, rng=rng)),
         bad_dim2=(random(3, 5, 7, rng=rng), random(3, 8, 3, rng=rng)),
-        bad_dim3=(random(3, 5, rng=rng), random(3, 7, rng=rng)),
-        bad_dim4=(random(3, 5, 7, 11, rng=rng), random(3, 5, rng=rng)),
-        bad_dim5=(random(3, 5, 7, 11, rng=rng), random(3, 5, 13, rng=rng)),
-        bad_dim6=(random(3, 5, 7, 11, rng=rng), random(3, 13, 5, 17, rng=rng)),
     ),
 )
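Note: the trimmed test tables reflect the narrowed contract: _batched_dot is the raw 3d BatchedDot Op, so only rank-3 operands (and genuine runtime mismatches between them) remain; lower- and higher-rank cases now go through the deprecated wrapper's vectorization path. The tester's expected lambda reduces to one np.dot per batch entry, as in this small sketch of ours:

import numpy as np

# Reference semantics of the 3d-only _batched_dot: one matrix product
# per entry along the leading batch axis, same as the tester's lambda.
def batched_dot_reference(xs, ys):
    return np.asarray([np.dot(x, y) for x, y in zip(xs, ys)])

rng = np.random.default_rng(0)
a = rng.random((3, 5, 7))
b = rng.random((3, 7, 9))
out = batched_dot_reference(a, b)
assert out.shape == (3, 5, 9)
assert np.allclose(out, a @ b)  # equivalent to stacked matmul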
@@ -2511,6 +2483,7 @@ def test_batched_dot():
     rng = np.random.default_rng(unittest_tools.fetch_seed())
     first = tensor3("first")
     second = tensor3("second")
-    output = batched_dot(first, second)
+    with pytest.warns(FutureWarning):
+        output = batched_dot(first, second)
     first_val = rng.random((10, 10, 20)).astype(config.floatX)
     second_val = rng.random((10, 20, 5)).astype(config.floatX)
@@ -2522,6 +2495,7 @@ def test_batched_dot():
     first_mat = dmatrix("first")
     second_mat = dmatrix("second")
-    output = batched_dot(first_mat, second_mat)
+    with pytest.warns(FutureWarning):
+        output = batched_dot(first_mat, second_mat)
     first_mat_val = rng.random((10, 10)).astype(config.floatX)
     second_mat_val = rng.random((10, 10)).astype(config.floatX)
@@ -2540,7 +2514,7 @@ def test_batched_dot_not_contiguous():
     X = tensor3()
     W = tensor3()
-    Z = batched_dot(X, W)
+    Z = _batched_dot(X, W)
     f = function([X, W], Z)

     w = np_genarray(30, 10, 5)
@@ -2568,7 +2542,7 @@ def test_batched_dot_blas_flags():
     x = tensor("x", shape=(2, 5, 3))
     y = tensor("y", shape=(2, 3, 1))
-    out = batched_dot(x, y)
+    out = _batched_dot(x, y)
     assert isinstance(out.owner.op, BatchedDot)
     x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
     y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
@@ -2590,6 +2564,7 @@ def test_batched_tensordot():
     first = tensor4("first")
     second = tensor4("second")
     axes = [[1, 2], [3, 1]]
-    output = batched_tensordot(first, second, axes)
+    with pytest.warns(FutureWarning):
+        output = batched_tensordot(first, second, axes)
     first_val = rng.random((8, 10, 20, 3)).astype(config.floatX)
     second_val = rng.random((8, 20, 5, 10)).astype(config.floatX)
@@ -2602,6 +2577,7 @@
     first_mat = dmatrix("first")
     second_mat = dmatrix("second")
     axes = 1
-    output = batched_tensordot(first_mat, second_mat, axes)
+    with pytest.warns(FutureWarning):
+        output = batched_tensordot(first_mat, second_mat, axes)
     first_mat_val = rng.random((10, 4)).astype(config.floatX)
     second_mat_val = rng.random((10, 4)).astype(config.floatX)