提交 efd9f491 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate BLAS batch helper functions

上级 b3da2a4b
......@@ -79,10 +79,14 @@ import functools
import logging
import os
import shlex
import warnings
from pathlib import Path
import numpy as np
from pytensor.graph import vectorize_graph
from pytensor.npy_2_compat import normalize_axis_tuple
try:
import numpy.__config__
......@@ -100,9 +104,9 @@ from pytensor.link.c.params_type import ParamsType
from pytensor.printing import FunctionPrinter, pprint
from pytensor.scalar import bool as bool_t
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.math import dot, tensordot
from pytensor.tensor.shape import specify_broadcastable
from pytensor.tensor.type import DenseTensorType, tensor
......@@ -1604,8 +1608,8 @@ class BatchedDot(COp):
x, y = inp
(gz,) = grads
xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz)
xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1))
ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz)
# If x or y contains broadcastable dimensions but only one of
# them knows that a matching dimension is broadcastable, the
......@@ -1729,31 +1733,22 @@ def batched_dot(a, b):
dot products in terms of batched matrix-matrix dot products, so
it may be possible to further optimize for performance.
"""
warnings.warn(
"batched_dot is deprecated. "
"Use `dot` in conjution with `tensor.vectorize` or `graph.replace.vectorize_graph`",
FutureWarning,
)
a, b = ptb.as_tensor_variable(a), ptb.as_tensor_variable(b)
if a.ndim == 0:
raise TypeError("a must have at least one (batch) axis")
elif b.ndim == 0:
raise TypeError("b must have at least one (batch) axis")
elif a.ndim == 1:
return shape_padright(a, (b.ndim - 1)) * b
elif b.ndim == 1:
return a * shape_padright(b, (a.ndim - 1))
elif a.ndim > 3 or b.ndim > 3:
return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]])
else:
# If either a or b is a batched vector, expand dims and later squeeze them
expanded_axis = []
if a.ndim == 2:
a = expand_dims(a, axis=1)
expanded_axis.append(1)
if b.ndim == 2:
b = expand_dims(b, axis=2)
expanded_axis.append(2)
out = _batched_dot(a, b)
if expanded_axis:
out = out.squeeze(axis=expanded_axis)
return out
core_a = a[0].type()
core_b = b[0].type()
core_dot = dot(core_a, core_b)
return vectorize_graph(core_dot, replace={core_a: a, core_b: b})
def batched_tensordot(x, y, axes=2):
......@@ -1791,6 +1786,22 @@ def batched_tensordot(x, y, axes=2):
reshapes to reduce the tensor dot product to a matrix or vector
dot product. Finally, it calls batched_dot to compute the result.
"""
from pytensor.tensor.math import _tensordot_as_dot
warnings.warn(
"batched_tensordot is deprecated. "
"Use `tensordot` in conjuction with `tensor.vectorize` or `graph.replace.vectorize_graph`",
FutureWarning,
)
if isinstance(axes, int):
core_axes = axes
else:
# Convert batched axes to core axes
core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)]
core_axes = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)]
core_axes = [core_axes_a, core_axes]
core_x = x[0].type()
core_y = y[0].type()
core_tensordot = tensordot(core_x, core_y, axes=core_axes)
return _tensordot_as_dot(x, y, axes, dot=batched_dot, batched=True)
return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y})
......@@ -50,7 +50,7 @@ from pytensor.tensor.type import (
tensor,
uint_dtypes,
)
from pytensor.tensor.utils import as_list, normalize_reduce_axis
from pytensor.tensor.utils import normalize_reduce_axis
from pytensor.tensor.variable import (
TensorVariable,
_tensor_py_operators,
......@@ -3208,133 +3208,6 @@ def dense_dot(a, b):
return _dot(a, b)
def _tensordot_as_dot(a, b, axes, dot, batched):
    """Reduce a tensor dot product to a matrix or vector dot product.

    Based on code from Tijmen Tieleman's gnumpy
    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).

    Please see the documentation of tensordot for the meaning of the a, b
    and axes arguments.

    Parameters
    ----------
    a, b
        The operands; both are passed through ``as_tensor_variable``.
    axes
        Either a scalar number of axes to sum over (trailing axes of ``a``,
        leading non-batch axes of ``b``) or a pair of axis collections, one
        for ``a`` and one for ``b``.
    dot : function
        A function that accepts two symbolic variables and computes the
        appropriate dot product (e.g. dot, batched_dot).
    batched : bool
        Whether to treat the first axis of ``a`` and ``b`` as a batch axis.
        If so, this axis is preserved in the output, allowing this function
        to be used also for batched tensor dot products.

    Returns
    -------
    symbolic tensor
        A tensor with shape equal to the concatenation of a's shape (less
        any dimensions that were summed over) and b's shape (less the first
        dimension, when batched, and any dimensions that were summed over).

    Raises
    ------
    ValueError
        If ``axes`` is neither a scalar nor of length 2, if it refers to more
        axes than an operand has, if the two axis collections differ in
        length, or if ``batched`` is true and the batch axis would be summed.
    """
    a, b = as_tensor_variable(a), as_tensor_variable(b)

    if not np.isscalar(axes) and len(axes) != 2:
        raise ValueError(
            "Axes should be an integer or a "
            f"list/tuple of len 2 ({axes} was provided)"
        )

    # if 'axes' is a number of axes to multiply and sum over (trailing axes
    # of a, leading axes of b), we can just reshape and use dot.
    elif np.isscalar(axes):
        axes = int(axes)

        for operand_name, operand in (("a", a), ("b", b)):
            if axes > operand.ndim:
                raise ValueError(
                    f"axes can not be larger than the dimension of {operand_name} "
                    f"({operand_name}.ndim={operand.ndim}, axes={axes})"
                )
            if batched and axes == operand.ndim:
                raise ValueError(
                    "axes to sum over must not include the batch axis "
                    f"of {operand_name} ({operand_name}.ndim={operand.ndim}, axes={axes})"
                )

        batch_axes = 1 if batched else 0
        a_outaxes = slice(0, a.ndim - axes)
        b_outaxes = slice(batch_axes + axes, b.ndim)
        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
        outndim = len(outbcast)

        # Each operand is collapsed to 2 dims: (kept axes, summed axes) for a
        # and (summed axes, kept axes) for b.
        a_shape = [1] * 2
        b_shape = [1] * 2

        # compute total size of summed axes
        for i in range(axes):
            a_shape[1] *= a.shape[-(i + 1)]
            b_shape[0] *= b.shape[batch_axes + i]
        # compute total size of other axes
        for i in range(a.ndim - axes - batch_axes):
            a_shape[0] *= a.shape[batch_axes + i]
        for i in range(b.ndim - axes - batch_axes):
            b_shape[1] *= b.shape[-(i + 1)]

        if batched:
            # Keep the batch axis as a separate leading dimension.
            a_shape.insert(0, a.shape[0])
            b_shape.insert(0, b.shape[0])

        a_reshaped = a.reshape(a_shape)
        b_reshaped = b.reshape(b_shape)

        out_reshaped = dot(a_reshaped, b_reshaped)
        out = out_reshaped.reshape(outshape, ndim=outndim)
        # Make sure the broadcastable pattern of the result is correct,
        # since some shape information can be lost in the reshapes.
        if out.type.broadcastable != outbcast:
            out = specify_broadcastable(
                out, *(ax for (ax, is_bcast) in enumerate(outbcast) if is_bcast)
            )
        return out

    # if 'axes' is a list, transpose a and b such that the summed axes of a
    # are last and the summed axes of b are first.
    else:
        axes = [as_list(axes_) for axes_ in axes]

        if len(axes[0]) != len(axes[1]):
            raise ValueError("Axes elements must have the same length.")

        for i, (operand_name, operand) in enumerate((("a", a), ("b", b))):
            if len(axes[i]) > operand.ndim:
                # Report the index that was actually checked, not always 0.
                raise ValueError(
                    f"axes[{i}] should be array_like with length less than "
                    f"the dimensions of {operand_name} ({operand_name}.ndim={operand.ndim}, len(axes[{i}])={len(axes[i])})."
                )
            if len(axes[i]) > 0 and np.max(axes[i]) >= operand.ndim:
                raise ValueError(
                    f"axes[{i}] contains dimensions greater than or equal "
                    f"to {operand_name}.ndim ({operand_name}.ndim={operand.ndim}, max(axes[{i}])={np.max(axes[i])})."
                )
            if batched and 0 in axes[i]:
                raise ValueError(
                    "axes to sum over must not contain the batch axis "
                    f"(axes[{i}]={axes[i]})"
                )

        batch_axes = [0] if batched else []
        other_axes = [
            [x for x in range(operand.ndim) if x not in axes[i] and x not in batch_axes]
            for i, operand in enumerate((a, b))
        ]

        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])

        # now that a and b are in the right order, recur with integer axes
        return _tensordot_as_dot(
            a_shuffled, b_shuffled, len(axes[0]), dot=dot, batched=batched
        )
def tensordot(
a: TensorLike, b: TensorLike, axes: int | Sequence[Sequence[int]] = 2
) -> TensorVariable:
......
......@@ -84,9 +84,9 @@ from pytensor.graph.utils import InconsistencyError
from pytensor.tensor import basic as ptb
from pytensor.tensor.blas import (
Dot22,
_batched_dot,
_dot22,
_dot22scalar,
batched_dot,
gemm_inplace,
gemm_no_inplace,
gemv_inplace,
......@@ -928,7 +928,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
x = x.reshape((-1, x_shape[-2], x_shape[-1]))
y = y.reshape((-1, y_shape[-2], y_shape[-1]))
new_out = batched_dot(x, y)
new_out = _batched_dot(x, y)
if len(x_shape) > 3:
# And then unravel it
......
......@@ -107,14 +107,6 @@ def shape_of_variables(
return l
def as_list(x):
    """Convert x to a list if it is an iterable; otherwise, wrap it in a list.

    Iterables are expanded element by element via ``list(x)``, so a string
    becomes a list of its characters. Any object for which ``list(x)``
    raises ``TypeError`` is returned as the single-element list ``[x]``.
    """
    try:
        return list(x)
    except TypeError:
        # Not iterable (e.g. a plain int): treat it as a single item.
        return [x]
def import_func_from_string(func_string: str): # -> Optional[Callable]:
func = getattr(np, func_string, None)
if func is not None:
......
......@@ -27,6 +27,7 @@ from pytensor.tensor.blas import (
Gemm,
Gemv,
Ger,
_batched_dot,
_dot22,
_dot22scalar,
batched_dot,
......@@ -2446,7 +2447,7 @@ class TestInferShape(unittest_tools.InferShapeTester):
rng = np.random.default_rng(unittest_tools.fetch_seed())
TestBatchedDot = makeTester(
name="BatchedDotTester",
op=batched_dot,
op=_batched_dot,
expected=(
lambda xs, ys: np.asarray(
[
......@@ -2460,34 +2461,10 @@ TestBatchedDot = makeTester(
grad=dict(
correct1=(random(3, 5, 7, rng=rng), random(3, 7, 5, rng=rng)),
correct2=(random(3, 5, 7, rng=rng), random(3, 7, 9, rng=rng)),
correct3=(random(3, 5, 7, rng=rng), random(3, 7, rng=rng)),
correct4=(random(3, 5), random(3, 5, 7, rng=rng)),
correct5=(random(3, rng=rng), random(3, 5, 7, rng=rng)),
correct6=(random(3, 5, rng=rng), random(3, rng=rng)),
correct7=(random(3, 5, rng=rng), random(3, 5, rng=rng)),
correct8=(random(3, rng=rng), random(3, rng=rng)),
correct9=(random(3, 5, 7, 11, rng=rng), random(3, rng=rng)),
correct10=(random(3, 2, 6, 5, rng=rng), random(3, 5, rng=rng)),
correct11=(random(3, 2, 6, 5, rng=rng), random(3, 5, 7, rng=rng)),
correct12=(random(3, 2, 6, 5, rng=rng), random(3, 7, 5, 8, rng=rng)),
mixed1=(random(3, 5, rng=rng).astype("float32"), random(3, 5, 7, rng=rng)),
mixed2=(random(3, 5, rng=rng).astype("float64"), random(3, 5, 7, rng=rng)),
),
good=dict(
correct1=(random(3, 5, 7, rng=rng), random(3, 7, 5, rng=rng)),
correct2=(random(3, 5, 7, rng=rng), random(3, 7, 9, rng=rng)),
correct3=(random(3, 5, 7, rng=rng), random(3, 7, rng=rng)),
correct4=(random(3, 5, rng=rng), random(3, 5, 7, rng=rng)),
correct5=(random(3, rng=rng), random(3, 5, 7, rng=rng)),
correct6=(random(3, 5, rng=rng), random(3, rng=rng)),
correct7=(random(3, 5, rng=rng), random(3, 5, rng=rng)),
correct8=(random(3, rng=rng), random(3, rng=rng)),
correct9=(random(3, 5, 7, 11, rng=rng), random(3, rng=rng)),
correct10=(random(3, 7, 11, 5, rng=rng), random(3, 5, rng=rng)),
correct11=(random(3, 7, 11, 5, rng=rng), random(3, 5, 13, rng=rng)),
correct12=(random(3, 7, 11, 5, rng=rng), random(3, 13, 5, 17, rng=rng)),
mixed1=(random(3, 5, rng=rng).astype("float32"), random(3, 5, 7, rng=rng)),
mixed2=(random(3, 5, rng=rng).astype("float64"), random(3, 5, 7, rng=rng)),
),
bad_build=dict(
no_batch_axis2=(random(rng=rng), random(3, 5, rng=rng)),
......@@ -2496,13 +2473,8 @@ TestBatchedDot = makeTester(
bad_runtime=dict(
batch_dim_mismatch1=(random(2, 5, 7, rng=rng), random(3, 7, 9, rng=rng)),
batch_dim_mismatch2=(random(3, 5, 7, rng=rng), random(2, 7, 9, rng=rng)),
batch_dim_mismatch3=(random(3, rng=rng), random(5, rng=rng)),
bad_dim1=(random(3, 5, 7, rng=rng), random(3, 5, 7, rng=rng)),
bad_dim2=(random(3, 5, 7, rng=rng), random(3, 8, 3, rng=rng)),
bad_dim3=(random(3, 5, rng=rng), random(3, 7, rng=rng)),
bad_dim4=(random(3, 5, 7, 11, rng=rng), random(3, 5, rng=rng)),
bad_dim5=(random(3, 5, 7, 11, rng=rng), random(3, 5, 13, rng=rng)),
bad_dim6=(random(3, 5, 7, 11, rng=rng), random(3, 13, 5, 17, rng=rng)),
),
)
......@@ -2511,7 +2483,8 @@ def test_batched_dot():
rng = np.random.default_rng(unittest_tools.fetch_seed())
first = tensor3("first")
second = tensor3("second")
output = batched_dot(first, second)
with pytest.warns(FutureWarning):
output = batched_dot(first, second)
first_val = rng.random((10, 10, 20)).astype(config.floatX)
second_val = rng.random((10, 20, 5)).astype(config.floatX)
result_fn = function([first, second], output)
......@@ -2522,7 +2495,8 @@ def test_batched_dot():
first_mat = dmatrix("first")
second_mat = dmatrix("second")
output = batched_dot(first_mat, second_mat)
with pytest.warns(FutureWarning):
output = batched_dot(first_mat, second_mat)
first_mat_val = rng.random((10, 10)).astype(config.floatX)
second_mat_val = rng.random((10, 10)).astype(config.floatX)
result_fn = function([first_mat, second_mat], output)
......@@ -2540,7 +2514,7 @@ def test_batched_dot_not_contiguous():
X = tensor3()
W = tensor3()
Z = batched_dot(X, W)
Z = _batched_dot(X, W)
f = function([X, W], Z)
w = np_genarray(30, 10, 5)
......@@ -2568,7 +2542,7 @@ def test_batched_dot_blas_flags():
x = tensor("x", shape=(2, 5, 3))
y = tensor("y", shape=(2, 3, 1))
out = batched_dot(x, y)
out = _batched_dot(x, y)
assert isinstance(out.owner.op, BatchedDot)
x_test = rng.normal(size=x.type.shape).astype(x.type.dtype)
y_test = rng.normal(size=y.type.shape).astype(y.type.dtype)
......@@ -2590,7 +2564,8 @@ def test_batched_tensordot():
first = tensor4("first")
second = tensor4("second")
axes = [[1, 2], [3, 1]]
output = batched_tensordot(first, second, axes)
with pytest.warns(FutureWarning):
output = batched_tensordot(first, second, axes)
first_val = rng.random((8, 10, 20, 3)).astype(config.floatX)
second_val = rng.random((8, 20, 5, 10)).astype(config.floatX)
result_fn = function([first, second], output)
......@@ -2602,7 +2577,8 @@ def test_batched_tensordot():
first_mat = dmatrix("first")
second_mat = dmatrix("second")
axes = 1
output = batched_tensordot(first_mat, second_mat, axes)
with pytest.warns(FutureWarning):
output = batched_tensordot(first_mat, second_mat, axes)
first_mat_val = rng.random((10, 4)).astype(config.floatX)
second_mat_val = rng.random((10, 4)).astype(config.floatX)
result_fn = function([first_mat, second_mat], output)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论