提交 20ff202e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Define all batched dot operations as matmul

A new rewrite is added to convert unpaired batched row/column matvec or vec products into equivalent matmul products.
上级 e265debc
......@@ -3921,23 +3921,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))
# Predefine all batched variations of Dot
_inner_prod = Blockwise(
_dot,
signature="(n),(n)->()",
)
_matrix_vec_prod = Blockwise(
_dot,
signature="(m,k),(k)->(m)",
)
_vec_matrix_prod = Blockwise(
_dot,
signature="(k),(k,n)->(n)",
)
_matrix_matrix_matmul = Blockwise(
_matmul = Blockwise(
_dot,
signature="(m,k),(k,n)->(m,n)",
gufunc_spec=("numpy.matmul", 2, 1),
......@@ -3993,11 +3977,11 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
elif x1.type.ndim == 1:
out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
out = vecmat(x1, x2)
elif x2.type.ndim == 1:
out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
out = matvec(x1, x2)
else:
out = _matrix_matrix_matmul(x1, x2)
out = _matmul(x1, x2)
if dtype is not None:
out = out.astype(dtype)
......@@ -4047,7 +4031,7 @@ def vecdot(
>>> z_batch = pt.vecdot(x_batch, y_batch) # shape (3,)
>>> # Equivalent to numpy.vecdot(x_batch, y_batch)
"""
out = _inner_prod(x1, x2)
out = matmul(x1[..., None, :], x2[..., :, None]).squeeze((-2, -1))
if dtype is not None:
out = out.astype(dtype)
......@@ -4096,7 +4080,7 @@ def matvec(
>>> result = pt.matvec(batched_A, batched_v) # shape (2, 3)
>>> # Equivalent to numpy.matvec(batched_A, batched_v)
"""
out = _matrix_vec_prod(x1, x2)
out = matmul(x1, x2[..., None]).squeeze(-1)
if dtype is not None:
out = out.astype(dtype)
......@@ -4134,18 +4118,18 @@ def vecmat(
--------
>>> import pytensor.tensor as pt
>>> # Vector-matrix product
>>> v = pt.vector("v", shape=(3,)) # shape (3,)
>>> A = pt.matrix("A", shape=(3, 4)) # shape (3, 4)
>>> v = pt.vector("v", shape=(3,))
>>> A = pt.matrix("A", shape=(3, 4))
>>> result = pt.vecmat(v, A) # shape (4,)
>>> # Equivalent to numpy.vecmat(v, A)
>>>
>>> # Batched vector-matrix product
>>> batched_v = pt.matrix("v", shape=(2, 3)) # shape (2, 3)
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4)) # shape (2, 3, 4)
>>> batched_v = pt.matrix("v", shape=(2, 3))
>>> batched_A = pt.tensor3("A", shape=(2, 3, 4))
>>> result = pt.vecmat(batched_v, batched_A) # shape (2, 4)
>>> # Equivalent to numpy.vecmat(batched_v, batched_A)
"""
out = _vec_matrix_prod(x1, x2)
out = matmul(x2.mT, x1[..., None]).squeeze(-1)
if dtype is not None:
out = out.astype(dtype)
......@@ -4160,18 +4144,18 @@ def vectorize_node_dot(op, node, batched_x, batched_y):
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
case (1, 1):
batch_op = _inner_prod
batch_fn = vecdot
case (2, 1):
batch_op = _matrix_vec_prod
batch_fn = matvec
case (1, 2):
batch_op = _vec_matrix_prod
batch_fn = vecmat
case (2, 2):
batch_op = _matrix_matrix_matmul
batch_fn = matmul
case _:
raise ValueError(
f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return batch_op(batched_x, batched_y).owner
return batch_fn(batched_x, batched_y).owner
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
......
......@@ -98,7 +98,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import (
Dot,
_matrix_matrix_matmul,
_matmul,
add,
mul,
neg,
......@@ -908,7 +908,7 @@ blas_optdb.register(
@register_specialize
@node_rewriter([_matrix_matrix_matmul])
@node_rewriter([_matmul])
def specialize_matmul_to_batched_dot(fgraph, node):
"""Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.
......
......@@ -39,6 +39,7 @@ from pytensor.tensor.rewriting.basic import (
broadcasted_by,
register_canonicalize,
register_specialize,
register_stabilize,
)
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -341,6 +342,7 @@ def is_dimshuffle_useless(new_order, input):
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([DimShuffle])
def local_dimshuffle_lift(fgraph, node):
......
......@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod
from pytensor.tensor.math import Dot, Prod, _matmul, log, outer, prod
from pytensor.tensor.nlinalg import (
SVD,
KroneckerProduct,
......@@ -284,7 +284,7 @@ def cholesky_ldotlt(fgraph, node):
# This rewrite only applies to matrix Dot
and A.owner.inputs[0].type.ndim == 2
)
or (A.owner.op == _matrix_matrix_matmul)
or (A.owner.op == _matmul)
)
):
return
......
......@@ -28,6 +28,7 @@ from pytensor.tensor.basic import (
as_tensor_variable,
cast,
constant,
expand_dims,
get_underlying_scalar_constant_value,
moveaxis,
ones_like,
......@@ -35,7 +36,6 @@ from pytensor.tensor.basic import (
switch,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
......@@ -45,10 +45,7 @@ from pytensor.tensor.math import (
Sum,
_conj,
_dot,
_inner_prod,
_matrix_matrix_matmul,
_matrix_vec_prod,
_vec_matrix_prod,
_matmul,
add,
digamma,
dot,
......@@ -182,7 +179,7 @@ def local_lift_transpose_through_dot(fgraph, node):
if not (
is_matrix_transpose(node.outputs[0])
and node.inputs[0].owner
and ((dot_op := node.inputs[0].owner.op) in (_dot, _matrix_matrix_matmul))
and ((dot_op := node.inputs[0].owner.op) in (_dot, _matmul))
):
return False
......@@ -197,60 +194,157 @@ def local_lift_transpose_through_dot(fgraph, node):
return ret
@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
def local_batched_matmul_to_core_matmul(fgraph, node):
"""Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.
def _batched_matmul_to_core_matmul(fgraph, node, allow_reshape: bool):
    """Move batch dimensions of matmul operands to core matmul.

    Example, if x has batch dimensions that don't overlap with batch dimensions of y

    x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

    It also works for batch dimensions of y that don't overlap with batch dimensions of x

    The rewrite only uses reshape when mixing dimensions, and it can refuse to apply
    if `allow_reshape=False`.

    Parameters
    ----------
    fgraph
        Function graph being rewritten (unused directly; required by the rewriter API).
    node
        A `_matmul` Blockwise node with inputs ``x`` and ``y``.
    allow_reshape : bool
        Whether batch dimensions may be raveled into the core (m)/(n) axes via
        reshape. When False, only squeeze/moveaxis-based moves are performed and
        the rewrite returns ``None`` if a reshape would be needed.

    Returns
    -------
    list or None
        A single-element list with the rewritten output, or ``None`` when the
        rewrite does not apply.
    """
    x, y = node.inputs
    batch_ndim = node.op.batch_ndim(node)

    # Batch axes where x is non-broadcastable but y is broadcastable:
    # these can be folded into x's core (m) axis.
    x_axis_to_merge = [
        i
        for i, (bcast_x, bcast_y) in enumerate(
            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
        )
        if bcast_y and not bcast_x
    ]

    # Symmetric case: batch axes where only y is non-broadcastable,
    # foldable into y's core (n) axis.
    y_axis_to_merge = [
        i
        for i, (bcast_x, bcast_y) in enumerate(
            zip(x.type.broadcastable[:-2], y.type.broadcastable[:-2])
        )
        if bcast_x and not bcast_y
    ]

    if not (x_axis_to_merge or y_axis_to_merge):
        # Batch dimensions fully overlap (or there are none): nothing to do.
        return None

    x_shape = tuple(x.shape)
    y_shape = tuple(y.shape)
    # Whether the core matrices are degenerate (row / column vectors),
    # which allows squeezing instead of reshaping below.
    x_is_row = x.type.broadcastable[-2]
    y_is_col = y.type.broadcastable[-1]
    n_x_axis_to_merge = len(x_axis_to_merge)
    n_y_axis_to_merge = len(y_axis_to_merge)
    n_axis_to_merge = n_x_axis_to_merge + n_y_axis_to_merge

    x_stacked, y_stacked = x, y
    dims_were_merged = False

    if n_x_axis_to_merge:
        # ravel batch dimensions of x on the core (m) axis
        x_axis_destination = tuple(range(-n_x_axis_to_merge - 2, -2))
        x_stacked = moveaxis(x, x_axis_to_merge, x_axis_destination)
        if x_is_row:
            # x was a row matrix, squeeze it to clean up the graph
            x_stacked = x_stacked.squeeze(-2)
        if n_x_axis_to_merge > 1 or not x_is_row:
            if not allow_reshape:
                # TODO: We could allow the y rewrite to go on
                # Or just move one axis (the largest) if x is row
                return None

            # Ravel moved batch dims together with (m) if needed
            x_stacked_shape = tuple(x_stacked.shape)
            x_stacked = x_stacked.reshape(
                (*x_stacked_shape[: batch_ndim - n_x_axis_to_merge], -1, x_shape[-1])
            )
            dims_were_merged = True

    if n_y_axis_to_merge:
        # ravel batch dimensions of y on the core (n) axis
        y_axis_destination = tuple(range(-n_y_axis_to_merge - 1, -1))
        y_stacked = moveaxis(y, y_axis_to_merge, y_axis_destination)
        if y_is_col:
            # y was a column matrix, squeeze it to clean up the graph
            y_stacked = y_stacked.squeeze(-1)
        if n_y_axis_to_merge > 1 or not y_is_col:
            if not allow_reshape:
                # TODO: We could allow the x rewrite to go on
                # Or just move one axis (the largest) if y is col
                return None

            # Ravel moved batch dims together with (n) if needed
            y_stacked_shape = tuple(y_stacked.shape)
            y_stacked = y_stacked.reshape(
                (*y_stacked_shape[: batch_ndim - n_y_axis_to_merge], y_shape[-2], -1)
            )
            dims_were_merged = True

    # Squeeze x_dims corresponding to merged dimensions of y
    x_axis_to_squeeze = np.array(y_axis_to_merge)
    for i in reversed(x_axis_to_merge):
        # The corresponding dimensions of y may have shifted when we merged dimensions of x
        x_axis_to_squeeze[x_axis_to_squeeze > i] -= 1
    x_stacked = x_stacked.squeeze(tuple(x_axis_to_squeeze))

    # Same for y
    y_axis_to_squeeze = np.array(x_axis_to_merge)
    for i in reversed(y_axis_to_merge):
        y_axis_to_squeeze[y_axis_to_squeeze > i] -= 1
    y_stacked = y_stacked.squeeze(tuple(y_axis_to_squeeze))

    out_stacked = x_stacked @ y_stacked

    # Split back any merged dimensions
    if dims_were_merged:
        x_merged_shapes = [x_shape[i] for i in x_axis_to_merge]
        if not x_is_row:
            # Otherwise we handle that later with expand_dims, which is cleaner
            x_merged_shapes.append(x_shape[-2])
        y_merged_shapes = [y_shape[i] for i in y_axis_to_merge]
        if not y_is_col:
            # Otherwise we handle that later with expand_dims, which is cleaner
            y_merged_shapes.append(y_shape[-1])
        out_stacked_shape = tuple(out_stacked.shape)
        out_unstacked = out_stacked.reshape(
            (
                *out_stacked_shape[: batch_ndim - n_axis_to_merge],
                *x_merged_shapes,
                *y_merged_shapes,
            )
        )
    else:
        out_unstacked = out_stacked

    # Add back dummy row, col axis
    # We do this separately to avoid the reshape as much as we can
    if y_is_col and (n_y_axis_to_merge or dims_were_merged):
        out_unstacked = expand_dims(out_unstacked, -1)
    if x_is_row and (n_x_axis_to_merge or dims_were_merged):
        out_unstacked = expand_dims(out_unstacked, -n_y_axis_to_merge - 2)

    # Move batch axis back to their original location
    source = range(-n_axis_to_merge - 2, 0)
    destination = (*x_axis_to_merge, -2, *y_axis_to_merge, -1)
    out = moveaxis(out_unstacked, source, destination)

    return [out]
@register_canonicalize
@node_rewriter(tracks=[_matmul])
def local_batched_matmul_to_core_matmul(fgraph, node):
    """Canonicalize-stage variant: move matmul batch dims without reshape.

    Allow passing batch dimensions of matmul to core vector / column matrices.
    Refuses cases that would require a reshape (``allow_reshape=False``).
    """
    # Allow passing batch dimensions of matmul to core vector / column matrices
    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=False)
@register_specialize
@node_rewriter(tracks=[_matmul])
def local_batched_matmul_to_core_matmul_with_reshape(fgraph, node):
    """Specialize-stage variant: move matmul batch dims, reshape permitted.

    Allow stacking batch dimensions of matmul with core dimensions, with a
    reshape operation.
    """
    # Allow stacking batch dimensions of matmul with core dimensions, with a reshape operation
    # We only apply this in specialize, because graphs with reshape are hard to work with
    return _batched_matmul_to_core_matmul(fgraph, node, allow_reshape=True)
@register_canonicalize
@register_specialize
@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
@node_rewriter([_matmul])
def local_blockwise_dot_to_mul(fgraph, node):
"""Rewrite blockwise dots that correspond to multiplication without summation.
......
import numpy as np
import pytest
from pytensor import function
from pytensor import config, function
from pytensor import tensor as pt
from pytensor.compile import get_default_mode
from pytensor.graph import FunctionGraph
from pytensor.graph import FunctionGraph, ancestors
from pytensor.tensor import (
col,
dscalar,
......@@ -21,7 +21,6 @@ from pytensor.tensor import (
vectorize,
)
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.rewriting.blas import (
_as_scalar,
......@@ -37,8 +36,11 @@ def XYZab():
return matrix(), matrix(), matrix(), scalar(), scalar()
@pytest.mark.parametrize("valid_case", (True, False))
def test_specialize_matmul_to_batched_dot(valid_case):
@pytest.mark.skipif(
config.mode == "FAST_COMPILE", reason="Test requires specialization rewrites"
)
@pytest.mark.parametrize("aligned", (True, False))
def test_specialize_matmul_to_batched_dot(aligned):
signature = BatchedDot.gufunc_signature
rewrite = specialize_matmul_to_batched_dot.__name__
......@@ -49,23 +51,36 @@ def test_specialize_matmul_to_batched_dot(valid_case):
return np.matmul(x, y)
x = tensor(shape=(7, 5, 3, 3))
if valid_case:
if aligned:
y = tensor(shape=(7, 5, 3, 3))
else:
y = tensor(shape=(5, 3, 3))
out = vectorize(core_pt, signature=signature)(x, y)
assert (
sum(
isinstance(var.owner.op, BatchedDot)
for var in ancestors([out])
if var.owner
)
== 0
)
vectorize_pt = function(
[x, y],
vectorize(core_pt, signature=signature)(x, y),
out,
mode=get_default_mode().including(rewrite),
)
blocwkise_node = any(
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
assert (
sum(
isinstance(var.owner.op, BatchedDot)
for var in ancestors(vectorize_pt.maker.fgraph.outputs)
if var.owner
)
== 1
)
if valid_case:
assert not blocwkise_node
else:
assert blocwkise_node
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
......
......@@ -42,6 +42,7 @@ from pytensor.tensor.math import (
Prod,
Sum,
_conj,
_matmul,
add,
arccosh,
arcsinh,
......@@ -4566,6 +4567,88 @@ def test_local_batched_matmul_to_core_matmul():
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
@pytest.mark.parametrize(
    "mat_shape, vec_shape",
    [
        [(1, 2, 2), (5, 2)],
        [(5, 2, 2), (1, 2)],
        [(1, 1, 2, 2), (7, 5, 2)],
        [(7, 5, 2, 2), (1, 1, 5, 2)],
        [(1, 5, 1, 2, 2), (7, 5, 7, 2)],
        [(7, 5, 7, 2, 2), (1, 5, 1, 2)],
        [(5, 1, 3, 1, 2, 2), (1, 7, 3, 7, 2)],
        [(1, 7, 3, 7, 2, 2), (5, 1, 3, 1, 2)],
    ],
    ids=str,
)
@pytest.mark.parametrize("func", ("matvec", "vecmat", "vecdot"))
def test_batch_matvec_to_matmul(func, mat_shape, vec_shape):
    """Check that batched matvec/vecmat/vecdot are rewritten to plain matmul
    when their batch dimensions don't overlap, and that results are unchanged."""

    def count_matvec_nodes(graph):
        # Counts how many matmul nodes actually correspond to matvec or vecmat
        # (i.e. one operand has a degenerate row/column core dimension).
        return len(
            [
                var
                for var in ancestors([graph])
                if (
                    var.owner is not None
                    and var.owner.op == _matmul
                    and (
                        (var.owner.inputs[0].type.shape[-2] == 1)
                        or (var.owner.inputs[1].type.shape[-1] == 1)
                    )
                )
            ]
        )

    mat = pt.tensor("mat", shape=mat_shape, dtype="float64")
    vec = pt.tensor("vec", shape=vec_shape, dtype="float64")
    if func == "matvec":
        out = pt.matvec(mat, vec)
    elif func == "vecmat":
        out = pt.vecmat(vec, mat)
    elif func == "vecdot":
        out = pt.vecdot(mat[..., 0], vec)
    else:
        raise NotImplementedError(func)

    # Before rewriting there is exactly one matvec-like matmul in the graph.
    assert count_matvec_nodes(out) == 1

    rewritten_out = rewrite_graph(
        out,
        include=(
            "canonicalize",
            "specialize",
        ),
        exclude=(
            "local_eager_useless_unbatched_blockwise",
            "specialize_matmul_to_batched_dot",
        ),
    )
    # No `matvec` in the rewritten out if one of the vectors can be treated as a matrix
    expected = not any(
        mat_dim == 1 and vec_dim != 1
        for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2])
    )
    if not expected and func == "vecdot":
        # In this case there are two vectors, so we may still end up with a `matvec` unless the second vec can also be treated as matrix
        expected = not any(
            mat_dim != 1 and vec_dim == 1
            for vec_dim, mat_dim in zip(vec_shape[:-1], mat_shape[:-2])
        )
    assert count_matvec_nodes(rewritten_out) == expected

    # Seed derived from the parametrized shapes so each case is deterministic.
    rng = np.random.default_rng(mat_shape + vec_shape)
    eval_dict = {mat: rng.random(mat.type.shape), vec: rng.random(vec.type.shape)}
    # Evaluate results are correct without further rewrites
    no_optimization = Mode(linker="py", optimizer=None)
    np.testing.assert_allclose(
        rewritten_out.eval(eval_dict, mode=no_optimization),
        out.eval(eval_dict, mode=no_optimization),
    )
def test_log_kv_stabilization():
x = pt.scalar("x")
out = log(kv(4.5, x))
......@@ -4616,8 +4699,8 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
out = dot(a, b)
if batched:
batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
batch_a = tensor("batch_a", shape=(2, 1, 5, *a_shape))
batch_b = tensor("batch_b", shape=(2, 7, 1, *b_shape))
out = vectorize_graph(out, {a: batch_a, b: batch_b})
a = batch_a
b = batch_b
......
......@@ -2092,9 +2092,9 @@ class TestDot:
def test_matrix_vector_ops():
"""Test vecdot, matvec, and vecmat helper functions."""
rng = np.random.default_rng(seed=utt.fetch_seed())
rng = np.random.default_rng(2089)
# Create test data with batch dimension (2)
atol = 1e-7 if config.floatX == "float32" else 1e-15
batch_size = 2
dim_k = 4 # Common dimension
dim_m = 3 # Matrix rows
......@@ -2109,7 +2109,6 @@ def test_matrix_vector_ops():
mat_kn_val = random(batch_size, dim_k, dim_n, rng=rng).astype(config.floatX)
vec_k_val = random(batch_size, dim_k, rng=rng).astype(config.floatX)
# Create tensor variables with matching dtype
mat_mk = tensor(
name="mat_mk", shape=(batch_size, dim_m, dim_k), dtype=config.floatX
)
......@@ -2130,7 +2129,7 @@ def test_matrix_vector_ops():
expected_vecdot = np.zeros((batch_size,), dtype=np.int32)
for i in range(batch_size):
expected_vecdot[i] = np.sum(vec_k_val[i] * vec_k_val[i])
np.testing.assert_allclose(result, expected_vecdot)
np.testing.assert_allclose(result, expected_vecdot, atol=atol)
# Test 2: matvec - matrix-vector product
matvec_out = matvec(mat_mk, vec_k)
......@@ -2141,7 +2140,7 @@ def test_matrix_vector_ops():
expected_matvec = np.zeros((batch_size, dim_m), dtype=config.floatX)
for i in range(batch_size):
expected_matvec[i] = np.dot(mat_mk_val[i], vec_k_val[i])
np.testing.assert_allclose(result_matvec, expected_matvec)
np.testing.assert_allclose(result_matvec, expected_matvec, atol=atol)
# Test 3: vecmat - vector-matrix product
vecmat_out = vecmat(vec_k, mat_kn)
......@@ -2152,7 +2151,7 @@ def test_matrix_vector_ops():
expected_vecmat = np.zeros((batch_size, dim_n), dtype=config.floatX)
for i in range(batch_size):
expected_vecmat[i] = np.dot(vec_k_val[i], mat_kn_val[i])
np.testing.assert_allclose(result_vecmat, expected_vecmat)
np.testing.assert_allclose(result_vecmat, expected_vecmat, atol=atol)
class TestTensordot:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论