提交 20ff202e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Define all batched dot operations as matmul

A new rewrite is added to convert unpaired batched row/column matvec or vec products into equivalent matmul products.
上级 e265debc
...@@ -3921,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False): ...@@ -3921,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims)) return log(sum(exp(x), axis=axis, keepdims=keepdims))
# Predefine all batched variations of Dot _matmul = Blockwise(
_inner_prod = Blockwise(
_dot,
signature="(n),(n)->()",
)
_matrix_vec_prod = Blockwise(
_dot,
signature="(m,k),(k)->(m)",
)
_vec_matrix_prod = Blockwise(
_dot,
signature="(k),(k,n)->(n)",
)
_matrix_matrix_matmul = Blockwise(
_dot, _dot,
signature="(m,k),(k,n)->(m,n)", signature="(m,k),(k,n)->(m,n)",
gufunc_spec=("numpy.matmul", 2, 1), gufunc_spec=("numpy.matmul", 2, 1),
...@@ -3993,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None ...@@ -3993,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if x1.type.ndim == 1 and x2.type.ndim == 1: if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2) out = _dot(x1, x2)
elif x1.type.ndim == 1: elif x1.type.ndim == 1:
out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2) out = vecmat(x1, x2)
elif x2.type.ndim == 1: elif x2.type.ndim == 1:
out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1) out = matvec(x1, x2)
else: else:
out = _matrix_matrix_matmul(x1, x2) out = _matmul(x1, x2)
if dtype is not None: if dtype is not None:
out = out.astype(dtype) out = out.astype(dtype)
...@@ -4047,7 +4031,7 @@ def vecdot( ...@@ -4047,7 +4031,7 @@ def vecdot(
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,) >>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch) >>> # Equivalent to numpy.vecdot(x_batch, y_batch)
""" """
out = _inner_prod(x1, x2) out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1))
if dtype is not None: if dtype is not None:
out = out.astype(dtype) out = out.astype(dtype)
...@@ -4096,7 +4080,7 @@ def matvec( ...@@ -4096,7 +4080,7 @@ def matvec(
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3) >>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> # Equivalent to numpy.matvec(batched_A, batched_v) >>> # Equivalent to numpy.matvec(batched_A, batched_v)
""" """
out = _matrix_vec_prod(x1, x2) out = matmul(x1, x2[..., None]).squeeze(-1)
if dtype is not None: if dtype is not None:
out = out.astype(dtype) out = out.astype(dtype)
...@@ -4134,18 +4118,18 @@ def vecmat( ...@@ -4134,18 +4118,18 @@ def vecmat(
-------- --------
>>> import pytensor.tensor as pt >>> import pytensor.tensor as pt
>>> # Vector-matrix product >>> # Vector-matrix product
>>> v = pt.vector("v", shape=(3,)) # shape (3,) >>> v = pt.vector("v", shape=(3,))
>>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4) >>> A = pt.matrix("A", shape=(3, 4))
>>> result = pt.vecmat(v, A) # shape (4,) >>> result = pt.vecmat(v, A) # shape (4,)
>>> # Equivalent to numpy.vecmat(v, A) >>> # Equivalent to numpy.vecmat(v, A)
>>> >>>
>>> # Batched vector-matrix product >>> # Batched vector-matrix product
>>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3) >>> batched_v = pt.matrix("v", shape=(2, 3))
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4) >>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4) >>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A) >>> # Equivalent to numpy.vecmat(batched_v, batched_A)
""" """
out = _vec_matrix_prod(x1, x2) out = matmul(x2.mT, x1[..., None]).squeeze(-1)
if dtype is not None: if dtype is not None:
out = out.astype(dtype) out = out.astype(dtype)
...@@ -4160,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y): ...@@ -4160,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
old_y_ndim = old_y.type.ndim old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim): match (old_x_ndim, old_y_ndim):
case (1, 1): case (1, 1):
batch_op = _inner_prod batch_fn = vecdot
case (2, 1): case (2, 1):
batch_op = _matrix_vec_prod batch_fn = matvec
case (1, 2): case (1, 2):
batch_op = _vec_matrix_prod batch_fn = vecmat
case (2, 2): case (2, 2):
batch_op = _matrix_matrix_matmul batch_fn = matmul
case _: case _:
raise ValueError( raise ValueError(
f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D." f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
) )
return batch_op(batched_x, batched_y).owner return batch_fn(batched_x, batched_y).owner
def nan_to_num(x, nan=0.0, posinf=None, neginf=None): def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
......
...@@ -98,7 +98,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise ...@@ -98,7 +98,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import ( from pytensor.tensor.math import (
Dot, Dot,
_matrix_matrix_matmul, _matmul,
add, add,
mul, mul,
neg, neg,
...@@ -908,7 +908,7 @@ blas_optdb.register( ...@@ -908,7 +908,7 @@ blas_optdb.register(
@register_specialize @register_specialize
@node_rewriter([_matrix_matrix_matmul]) @node_rewriter([_matmul])
def specialize_matmul_to_batched_dot(fgraph, node): def specialize_matmul_to_batched_dot(fgraph, node):
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot. """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
......
...@@ -39,6 +39,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -39,6 +39,7 @@ from pytensor.tensor.rewriting.basic import (
broadcasted_by, broadcasted_by,
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
register_stabilize,
) )
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input): ...@@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):
@register_canonicalize @register_canonicalize
@register_stabilize
@register_specialize @register_specialize
@node_rewriter([DimShuffle]) @node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node): def local_dimshuffle_lift(fgraph, node):
......
...@@ -26,7 +26,7 @@ from pytensor.tensor.basic import ( ...@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blas import Dot22 from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
from pytensor.tensor.nlinalg import ( from pytensor.tensor.nlinalg import (
SVD, SVD,
KroneckerProduct, KroneckerProduct,
...@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node): ...@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
# This rewrite only applies to matrix Dot # This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2 and A.owner.inputs[0].type.ndim == 2
) )
or (A.owner.op == _matrix_matrix_matmul) or (A.owner.op == _matmul)
) )
): ):
return return
......
...@@ -28,6 +28,7 @@ from pytensor.tensor.basic import ( ...@@ -28,6 +28,7 @@ from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
cast, cast,
constant, constant,
expand_dims,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
moveaxis, moveaxis,
ones_like, ones_like,
...@@ -35,7 +36,6 @@ from pytensor.tensor.basic import ( ...@@ -35,7 +36,6 @@ from pytensor.tensor.basic import (
switch, switch,
zeros_like, zeros_like,
) )
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.extra_ops import broadcast_arrays
...@@ -45,10 +45,7 @@ from pytensor.tensor.math import ( ...@@ -45,10 +45,7 @@ from pytensor.tensor.math import (
Sum, Sum,
_conj, _conj,
_dot, _dot,
_inner_prod, _matmul,
_matrix_matrix_matmul,
_matrix_vec_prod,
_vec_matrix_prod,
add, add,
digamma, digamma,
dot, dot,
...@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node): ...@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node):
if not ( if not (
is_matrix_transpose(node.outputs[0]) is_matrix_transpose(node.outputs[0])
and node.inputs[0].owner and node.inputs[0].owner
and ((dot_op := node.inputs[0].owner.op) in (_dot, _matrix_matrix_matmul)) and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
): ):
return False return False
...@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node): ...@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node):
return ret return ret
@register_stabilize def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
@register_specialize """Move batch dimensions of matmul operands to core matmul
@node_rewriter(tracks=[Blockwise])
def local_batched_matmul_to_core_matmul(fgraph, node):
"""Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
Example, if x has batch dimensions, but y not: Example, if x has batch dimensions that don't overlap with batch dimensions of y
x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1]) x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
It also works when y has batch dimensions, but x not. It also works for batch dimensions of y that don't overlap with batch dimensions of x
"""
# Check whether we have a matmul operation in this node The rewrite only uses reshape when mixing dimensions, and it can refuse to apply if `allow_reshape=False`
if not ( """
isinstance(node.op.core_op, Dot)
and len(node.op.inputs_sig[0]) == 2
and len(node.op.inputs_sig[1]) == 2
):
return None
x, y = node.inputs x, y = node.inputs
batch_ndim = node.op.batch_ndim(node) batch_ndim = node.op.batch_ndim(node)
# Check if x has batch dimensions, but y not (or only broadcastable dimensions) x_axis_to_merge = [
if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all( i
y.type.broadcastable[:-2] for i, (bcast_x, bcast_y) in enumerate(
): zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
x_stacked = x.reshape((-1, x.shape[-1])) )
out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim))) if bcast_y and not bcast_x
out = out_stacked.reshape((*x.shape[:-1], y.shape[-1])) ]
return [out]
# Otherwise, check if y has batch dimension, but x not y_axis_to_merge = [
elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all( i
x.type.broadcastable[:-2] for i, (bcast_x, bcast_y) in enumerate(
): zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
# For the y batch case we need to first move the batch axes and then reshape )
# y.shape == (*b, k, n) if bcast_x and not bcast_y
y_tr = moveaxis(y, -2, 0) # (k, *b, n) ]
y_stacked = y_tr.reshape((y.shape[-2], -1)) # (k, *b * n)
out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked # (m, *b * n) if not (x_axis_to_merge or y_axis_to_merge):
out_stacked_tr = out_stacked.reshape( return None
(x.shape[-2], *y.shape[:-2], y.shape[-1])
) # (m, *b, n)
out = moveaxis(out_stacked_tr, 0, -2) # (*b, m, n)
return [out]
# Both x and y have batch dimensions, nothing to do here x_shape = tuple(x.shape)
y_shape = tuple(y.shape)
x_is_row = x.type.broadcastable[-2]
y_is_col = y.type.broadcastable[-1]
n_x_axis_to_merge = len(x_axis_to_merge)
n_y_axis_to_merge = len(y_axis_to_merge)
n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge
x_stacked, y_stacked = x, y
dims_were_merged = False
if n_x_axis_to_merge:
# ravel batch dimensions of x on the core (m) axis
x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2))
x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination)
if x_is_row:
# x was a row matrix, squeeze it to clean up the graph
x_stacked = x_stacked.squeeze(-2)
if n_x_axis_to_merge > 1 or not x_is_row:
if not allow_reshape:
# TODO: We could allow the y rewrite to go on
# Or just move one axis (the largest) if x is row
return None return None
# Ravel moved batch dims together with (m) if needed
x_stacked_shape = tuple(x_stacked.shape)
x_stacked = x_stacked.reshape(
(*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1])
)
dims_were_merged = True
if n_y_axis_to_merge:
# ravel batch dimensions of y on the core (n) axis
y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1))
y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination)
if y_is_col:
# y was a column matrix, squeeze it to clean up the graph
y_stacked = y_stacked.squeeze(-1)
if n_y_axis_to_merge > 1 or not y_is_col:
if not allow_reshape:
# TODO: We could allow the x rewrite to go on
# Or just move one axis (the largest) if y is col
return None
# Ravel moved batch dims together with (n) if needed
y_stacked_shape = tuple(y_stacked.shape)
y_stacked = y_stacked.reshape(
(*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1)
)
dims_were_merged = True
# Squeeze x_dims corresponding to merged dimensions of y
x_axis_to_squeeze = np.array(y_axis_to_merge)
for i in reversed(x_axis_to_merge):
# The corresponding dimensions of y may have shifted when we merged dimensions of x
x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1
x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze))
# Same for y
y_axis_to_squeeze = np.array(x_axis_to_merge)
for i in reversed(y_axis_to_merge):
y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1
y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze))
out_stacked = x_stacked @ y_stacked
# Split back any merged dimensions
if dims_were_merged:
x_merged_shapes = [x_shape[i] for i in x_axis_to_merge]
if not x_is_row:
# Otherwise we handle that later with expand_dims, which is cleaner
x_merged_shapes.append(x_shape[-2])
y_merged_shapes = [y_shape[i] for i in y_axis_to_merge]
if not y_is_col:
# Otherwise we handle that later with expand_dims, which is cleaner
y_merged_shapes.append(y_shape[-1])
out_stacked_shape = tuple(out_stacked.shape)
out_unstacked = out_stacked.reshape(
(
*out_stacked_shape[: batch_ndim - n_axis_to_merge],
*x_merged_shapes,
*y_merged_shapes,
)
)
else:
out_unstacked = out_stacked
# Add back dummy row, col axis
# We do this separately to avoid the reshape as much as we can
if y_is_col and (n_y_axis_to_merge or dims_were_merged):
out_unstacked = expand_dims(out_unstacked, -1)
if x_is_row and (n_x_axis_to_merge or dims_were_merged):
out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2)
# Move batch axis back to their original location
source = range(-n_axis_to_merge - 2, 0)
destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1)
out = moveaxis(out_unstacked, source, destination)
return [out]
@register_canonicalize
@node_rewriter(tracks=[_matmul])
def local_batched_matmul_to_core_matmul(fgraph, node):
# Allow passing batch dimensions of matmul to core vector / column matrices
return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False)
@register_specialize
@node_rewriter(tracks=[_matmul])
def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
# Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
# We only apply this in specialize, because graphs with reshape are hard to work with
return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True)
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul]) @node_rewriter([_matmul])
def local_blockwise_dot_to_mul(fgraph, node): def local_blockwise_dot_to_mul(fgraph, node):
"""Rewrite blockwise dots that correspond to multiplication without summation. """Rewrite blockwise dots that correspond to multiplication without summation.
......
import numpy as np import numpy as np
import pytest import pytest
from pytensor import function from pytensor import config, function
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile import get_default_mode from pytensor.compile import get_default_mode
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph, ancestors
from pytensor.tensor import ( from pytensor.tensor import (
col, col,
dscalar, dscalar,
...@@ -21,7 +21,6 @@ from pytensor.tensor import ( ...@@ -21,7 +21,6 @@ from pytensor.tensor import (
vectorize, vectorize,
) )
from pytensor.tensor.blas import BatchedDot from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.blas import ( from pytensor.tensor.rewriting.blas import (
_as_scalar, _as_scalar,
...@@ -37,8 +36,11 @@ def XYZab(): ...@@ -37,8 +36,11 @@ def XYZab():
return matrix(), matrix(), matrix(), scalar(), scalar() return matrix(), matrix(), matrix(), scalar(), scalar()
@pytest.mark.parametrize("valid_case", (True, False)) @pytest.mark.skipif(
def test_specialize_matmul_to_batched_dot(valid_case): config.mode == "FAST_COMPILE", reason="Test requires specialization rewrites"
)
@pytest.mark.parametrize("aligned", (True, False))
def test_specialize_matmul_to_batched_dot(aligned):
signature = BatchedDot.gufunc_signature signature = BatchedDot.gufunc_signature
rewrite = specialize_matmul_to_batched_dot.__name__ rewrite = specialize_matmul_to_batched_dot.__name__
...@@ -49,23 +51,36 @@ def test_specialize_matmul_to_batched_dot(valid_case): ...@@ -49,23 +51,36 @@ def test_specialize_matmul_to_batched_dot(valid_case):
return np.matmul(x, y) return np.matmul(x, y)
x = tensor(shape=(7, 5, 3, 3)) x = tensor(shape=(7, 5, 3, 3))
if valid_case: if aligned:
y = tensor(shape=(7, 5, 3, 3)) y = tensor(shape=(7, 5, 3, 3))
else: else:
y = tensor(shape=(5, 3, 3)) y = tensor(shape=(5, 3, 3))
out = vectorize(core_pt, signature=signature)(x, y)
assert (
sum(
isinstance(var.owner.op, BatchedDot)
for var in ancestors([out])
if var.owner
)
== 0
)
vectorize_pt = function( vectorize_pt = function(
[x, y], [x, y],
vectorize(core_pt, signature=signature)(x, y), out,
mode=get_default_mode().including(rewrite), mode=get_default_mode().including(rewrite),
) )
blocwkise_node = any(
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes assert (
sum(
isinstance(var.owner.op, BatchedDot)
for var in ancestors(vectorize_pt.maker.fgraph.outputs)
if var.owner
)
== 1
) )
if valid_case:
assert not blocwkise_node
else:
assert blocwkise_node
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype) y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
......
...@@ -42,6 +42,7 @@ from pytensor.tensor.math import ( ...@@ -42,6 +42,7 @@ from pytensor.tensor.math import (
Prod, Prod,
Sum, Sum,
_conj, _conj,
_matmul,
add, add,
arccosh, arccosh,
arcsinh, arcsinh,
...@@ -4566,6 +4567,88 @@ def test_local_batched_matmul_to_core_matmul(): ...@@ -4566,6 +4567,88 @@ def test_local_batched_matmul_to_core_matmul():
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test) np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
@pytest.mark.parametrize(
"mat_shape, vec_shape",
[
[(1, 2, 2), (5, 2)],
[(5, 2, 2), (1, 2)],
[(1, 1, 2, 2), (7, 5, 2)],
[(7, 5, 2, 2), (1, 1, 5, 2)],
[(1, 5, 1, 2, 2), (7, 5, 7, 2)],
[(7, 5, 7, 2, 2), (1, 5, 1, 2)],
[(5, 1, 3, 1, 2, 2), (1, 7, 3, 7, 2)],
[(1, 7, 3, 7, 2, 2), (5, 1, 3, 1, 2)],
],
ids=str,
)
@pytest.mark.parametrize("func", ("matvec", "vecmat", "vecdot"))
def test_batch_matvec_to_matmul(func, mat_shape, vec_shape):
def count_matvec_nodes(graph):
# Counts how many matmul nodes actually correspond to matvec or vecmat
return len(
[
var
for var in ancestors([graph])
if (
var.owner is not None
and var.owner.op == _matmul
and (
(var.owner.inputs[0].type.shape[-2] == 1)
or (var.owner.inputs[1].type.shape[-1] == 1)
)
)
]
)
mat = pt.tensor("mat", shape=mat_shape, dtype="float64")
vec = pt.tensor("vec", shape=vec_shape, dtype="float64")
if func == "matvec":
out = pt.matvec(mat, vec)
elif func == "vecmat":
out = pt.vecmat(vec, mat)
elif func == "vecdot":
out = pt.vecdot(mat[..., 0], vec)
else:
raise NotImplementedError(func)
assert count_matvec_nodes(out) == 1
rewritten_out = rewrite_graph(
out,
include=(
"canonicalize",
"specialize",
),
exclude=(
"local_eager_useless_unbatched_blockwise",
"specialize_matmul_to_batched_dot",
),
)
# No `matvec` in the rewritten out if one of the vectors can be treated as a matrix
expected = not any(
mat_dim == 1 and vec_dim != 1
for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2])
)
if not expected and func == "vecdot":
# In this case there are two vectors, so we may still end up with a `matvec` unless the second vec can also be treated as matrix
expected = not any(
mat_dim != 1 and vec_dim == 1
for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2])
)
assert count_matvec_nodes(rewritten_out) == expected
rng = np.random.default_rng(mat_shape + vec_shape)
eval_dict = {mat: rng.random(mat.type.shape), vec: rng.random(vec.type.shape)}
# Evaluate results are correct without further rewrites
no_optimization = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
rewritten_out.eval(eval_dict, mode=no_optimization),
out.eval(eval_dict, mode=no_optimization),
)
def test_log_kv_stabilization(): def test_log_kv_stabilization():
x = pt.scalar("x") x = pt.scalar("x")
out = log(kv(4.5, x)) out = log(kv(4.5, x))
...@@ -4616,8 +4699,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape): ...@@ -4616,8 +4699,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
out = dot(a, b) out = dot(a, b)
if batched: if batched:
batch_a = tensor("batch_a", shape=(1, 5, *a_shape)) batch_a = tensor("batch_a", shape=(2, 1, 5, *a_shape))
batch_b = tensor("batch_b", shape=(7, 1, *b_shape)) batch_b = tensor("batch_b", shape=(2, 7, 1, *b_shape))
out = vectorize_graph(out, {a: batch_a, b: batch_b}) out = vectorize_graph(out, {a: batch_a, b: batch_b})
a = batch_a a = batch_a
b = batch_b b = batch_b
......
...@@ -2092,9 +2092,9 @@ class TestDot: ...@@ -2092,9 +2092,9 @@ class TestDot:
def test_matrix_vector_ops(): def test_matrix_vector_ops():
"""Test vecdot, matvec, and vecmat helper functions.""" """Test vecdot, matvec, and vecmat helper functions."""
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(2089)
# Create test data with batch dimension (2) atol = 1e-7 if config.floatX == "float32" else 1e-15
batch_size = 2 batch_size = 2
dim_k = 4 # Common dimension dim_k = 4 # Common dimension
dim_m = 3 # Matrix rows dim_m = 3 # Matrix rows
...@@ -2109,7 +2109,6 @@ def test_matrix_vector_ops(): ...@@ -2109,7 +2109,6 @@ def test_matrix_vector_ops():
mat_kn_val = random(batch_size, dim_k, dim_n, rng=rng).astype(config.floatX) mat_kn_val = random(batch_size, dim_k, dim_n, rng=rng).astype(config.floatX)
vec_k_val = random(batch_size, dim_k, rng=rng).astype(config.floatX) vec_k_val = random(batch_size, dim_k, rng=rng).astype(config.floatX)
# Create tensor variables with matching dtype
mat_mk = tensor( mat_mk = tensor(
name="mat_mk", shape=(batch_size, dim_m, dim_k), dtype=config.floatX name="mat_mk", shape=(batch_size, dim_m, dim_k), dtype=config.floatX
) )
...@@ -2130,7 +2129,7 @@ def test_matrix_vector_ops(): ...@@ -2130,7 +2129,7 @@ def test_matrix_vector_ops():
expected_vecdot = np.zeros((batch_size,), dtype=np.int32) expected_vecdot = np.zeros((batch_size,), dtype=np.int32)
for i in range(batch_size): for i in range(batch_size):
expected_vecdot[i] = np.sum(vec_k_val[i] * vec_k_val[i]) expected_vecdot[i] = np.sum(vec_k_val[i] * vec_k_val[i])
np.testing.assert_allclose(result, expected_vecdot) np.testing.assert_allclose(result, expected_vecdot, atol=atol)
# Test 2: matvec - matrix-vector product # Test 2: matvec - matrix-vector product
matvec_out = matvec(mat_mk, vec_k) matvec_out = matvec(mat_mk, vec_k)
...@@ -2141,7 +2140,7 @@ def test_matrix_vector_ops(): ...@@ -2141,7 +2140,7 @@ def test_matrix_vector_ops():
expected_matvec = np.zeros((batch_size, dim_m), dtype=config.floatX) expected_matvec = np.zeros((batch_size, dim_m), dtype=config.floatX)
for i in range(batch_size): for i in range(batch_size):
expected_matvec[i] = np.dot(mat_mk_val[i], vec_k_val[i]) expected_matvec[i] = np.dot(mat_mk_val[i], vec_k_val[i])
np.testing.assert_allclose(result_matvec, expected_matvec) np.testing.assert_allclose(result_matvec, expected_matvec, atol=atol)
# Test 3: vecmat - vector-matrix product # Test 3: vecmat - vector-matrix product
vecmat_out = vecmat(vec_k, mat_kn) vecmat_out = vecmat(vec_k, mat_kn)
...@@ -2152,7 +2151,7 @@ def test_matrix_vector_ops(): ...@@ -2152,7 +2151,7 @@ def test_matrix_vector_ops():
expected_vecmat = np.zeros((batch_size, dim_n), dtype=config.floatX) expected_vecmat = np.zeros((batch_size, dim_n), dtype=config.floatX)
for i in range(batch_size): for i in range(batch_size):
expected_vecmat[i] = np.dot(vec_k_val[i], mat_kn_val[i]) expected_vecmat[i] = np.dot(vec_k_val[i], mat_kn_val[i])
np.testing.assert_allclose(result_vecmat, expected_vecmat) np.testing.assert_allclose(result_vecmat, expected_vecmat, atol=atol)
class TestTensordot: class TestTensordot:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论