Commit c52154d3, authored by Ricardo Vieira, committed by Ricardo Vieira

Add rewrite for matmul when only one of the inputs has batched dimensions

This rewrite replaces a batched matmul with a core matmul by raveling the batched dimensions of the batched input, doing a 2D matmul, and then unraveling the batched dimensions of the output. This sidesteps the Blockwise Dot operation and allows specialization into BLAS routines. The idea was taken from these two discussions: https://github.com/numpy/numpy/issues/7569 https://github.com/numpy/numpy/issues/8957
Parent 0fc2cd8e
......@@ -31,11 +31,13 @@ from pytensor.tensor.basic import (
constant,
extract_constant,
get_underlying_scalar_constant_value,
moveaxis,
ones_like,
register_infer_shape,
switch,
zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
......@@ -217,6 +219,57 @@ def local_lift_transpose_through_dot(fgraph, node):
return ret
@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
def local_batched_matmul_to_core_matmul(fgraph, node):
    """Rewrite matmul where only one of the inputs has batch dimensions to a reshaped core matmul.

    Example, if x has batch dimensions, but y not:
    x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

    It also works when y has batch dimensions, but x not.
    """
    op = node.op

    # Only act on a Blockwise whose core operation is a 2D x 2D Dot (a matmul).
    if not isinstance(op.core_op, Dot):
        return None
    sig_x, sig_y = op.inputs_sig
    if len(sig_x) != 2 or len(sig_y) != 2:
        return None

    x, y = node.inputs
    batch_ndim = op.batch_ndim(node)

    # An input counts as "batched" when at least one of its leading (batch)
    # dimensions is not broadcastable.
    x_is_batched = not all(x.type.broadcastable[:-2])
    y_is_batched = not all(y.type.broadcastable[:-2])

    if x_is_batched and not y_is_batched:
        # Collapse x's batch dimensions into its row dimension, do a plain 2D
        # matmul, then restore the batch dimensions on the result.
        x_2d = x.reshape((-1, x.shape[-1]))
        core_out = x_2d @ y.squeeze(tuple(range(batch_ndim)))
        return [core_out.reshape((*x.shape[:-1], y.shape[-1]))]

    if y_is_batched and not x_is_batched:
        # For the y batch case we need to first move the batch axes and then reshape
        # y.shape == (*b, k, n)
        y_core_first = moveaxis(y, -2, 0)  # (k, *b, n)
        y_2d = y_core_first.reshape((y.shape[-2], -1))  # (k, *b * n)
        core_out = x.squeeze(tuple(range(batch_ndim))) @ y_2d  # (m, *b * n)
        out_core_first = core_out.reshape(
            (x.shape[-2], *y.shape[:-2], y.shape[-1])
        )  # (m, *b, n)
        return [moveaxis(out_core_first, 0, -2)]  # (*b, m, n)

    # Either both inputs are batched or neither is; nothing to do here.
    return None
def is_inverse_pair(node_op, prev_op, inv_pair):
"""
Given two consecutive operations, check if they are the
......
......@@ -34,6 +34,7 @@ from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from pytensor.tensor.math import abs as pt_abs
......@@ -4427,3 +4428,51 @@ def test_polygamma_specialization():
assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="Rewrite is only relevant in FAST_RUN",
)
def test_local_batched_matmul_to_core_matmul():
    """A matmul with exactly one batched input compiles to a core (non-Blockwise) graph."""
    rng = np.random.default_rng(seed=4433)

    def compiled_has_blockwise(fn):
        # True when any node in the compiled graph is still a Blockwise op.
        return any(
            isinstance(apply.op, Blockwise) for apply in fn.maker.fgraph.apply_nodes
        )

    # x is batched but not y
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not compiled_has_blockwise(fn)

    x_val = rng.normal(size=(5, 3, 2))
    y_val = rng.normal(size=(2, 2))
    np.testing.assert_allclose(fn(x_val, y_val), x_val @ y_val)

    # y is batched but not x
    x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not compiled_has_blockwise(fn)

    x_val = rng.normal(size=(1, 3, 2))
    y_val = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_val, y_val), x_val @ y_val)

    # Both x and y are batched, rewrite does not apply
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y

    fn = pytensor.function([x, y], out)
    x_val = rng.normal(size=(5, 3, 2))
    y_val = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_val, y_val), x_val @ y_val)
Markdown format
0%
You are adding 0 people to this discussion. Proceed with caution.
Please finish editing this comment first!
Register or sign in to comment