Commit c52154d3 authored by Ricardo Vieira, committed by Ricardo Vieira

Add rewrite for matmul when only one of the inputs has batched dimensions

This rewrite replaces a batched matmul with a core matmul by raveling the batch dimensions of the batched input, performing a 2D matmul, and then unraveling the batch dimensions of the output. This sidesteps the Blockwise Dot operation and allows specialization into BLAS routines. The idea was taken from these two discussions: https://github.com/numpy/numpy/issues/7569 https://github.com/numpy/numpy/issues/8957
Parent 0fc2cd8e
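The identity the rewrite relies on can be checked directly in NumPy. A minimal sketch (illustration only, not part of the commit) for the case where x is batched and y is not:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4, 3, 2))  # batch dims (5, 4), core shape (3, 2)
y = rng.normal(size=(2, 6))        # plain 2D matrix

# Ravel the batch dimensions into the leading axis, do one 2D matmul,
# then unravel the batch dimensions of the result.
core = (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])
np.testing.assert_allclose(core, x @ y)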
@@ -31,11 +31,13 @@ from pytensor.tensor.basic import (
    constant,
    extract_constant,
    get_underlying_scalar_constant_value,
    moveaxis,
    ones_like,
    register_infer_shape,
    switch,
    zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
@@ -217,6 +219,57 @@ def local_lift_transpose_through_dot(fgraph, node):
    return ret
@register_stabilize
@register_specialize
@node_rewriter(tracks=[Blockwise])
def local_batched_matmul_to_core_matmul(fgraph, node):
    """Rewrite a matmul where only one of the inputs has batch dimensions to a reshaped core matmul.

    For example, if x has batch dimensions but y does not:
    x @ y -> (x.reshape(-1, x.shape[-1]) @ y).reshape(*x.shape[:-1], y.shape[-1])

    The rewrite also applies when y has batch dimensions but x does not.
    """

    # Check whether we have a matmul operation in this node
    if not (
        isinstance(node.op.core_op, Dot)
        and len(node.op.inputs_sig[0]) == 2
        and len(node.op.inputs_sig[1]) == 2
    ):
        return None

    x, y = node.inputs
    batch_ndim = node.op.batch_ndim(node)

    # Check if x has batch dimensions, but y does not (or only broadcastable ones)
    if any(not b_dim for b_dim in x.type.broadcastable[:-2]) and all(
        y.type.broadcastable[:-2]
    ):
        x_stacked = x.reshape((-1, x.shape[-1]))
        out_stacked = x_stacked @ y.squeeze(tuple(range(batch_ndim)))
        out = out_stacked.reshape((*x.shape[:-1], y.shape[-1]))
        return [out]

    # Otherwise, check if y has batch dimensions, but x does not
    elif any(not b_dim for b_dim in y.type.broadcastable[:-2]) and all(
        x.type.broadcastable[:-2]
    ):
        # For the y batch case we need to first move the batch axes and then reshape
        # y.shape == (*b, k, n)
        y_tr = moveaxis(y, -2, 0)  # (k, *b, n)
        y_stacked = y_tr.reshape((y.shape[-2], -1))  # (k, prod(b) * n)
        out_stacked = x.squeeze(tuple(range(batch_ndim))) @ y_stacked  # (m, prod(b) * n)
        out_stacked_tr = out_stacked.reshape(
            (x.shape[-2], *y.shape[:-2], y.shape[-1])
        )  # (m, *b, n)
        out = moveaxis(out_stacked_tr, 0, -2)  # (*b, m, n)
        return [out]

    # Both x and y have batch dimensions; nothing to do here
    return None
def is_inverse_pair(node_op, prev_op, inv_pair):
    """
    Given two consecutive operations, check if they are the
...
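The y-batched branch above is the subtler one, since the batch axes are not contiguous with the reduced axis. An illustrative NumPy mirror of those reshapes (a sketch, not part of the commit; shapes follow the comments in the code):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2))        # core 2D: (m, k)
y = rng.normal(size=(5, 4, 2, 6))  # batched: (*b, k, n) with b = (5, 4)

y_stacked = np.moveaxis(y, -2, 0).reshape(y.shape[-2], -1)  # (k, prod(b) * n)
out_stacked = x @ y_stacked                                 # (m, prod(b) * n)
out = np.moveaxis(
    out_stacked.reshape(x.shape[-2], *y.shape[:-2], y.shape[-1]), 0, -2
)                                                           # (*b, m, n)
np.testing.assert_allclose(out, x @ y)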
@@ -34,6 +34,7 @@ from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, constant, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Dot, MaxAndArgmax, Prod, Sum, _conj
from pytensor.tensor.math import abs as pt_abs
@@ -4427,3 +4428,51 @@ def test_polygamma_specialization():
    assert isinstance(fn_outs[0].owner.op.scalar_op, Psi)
    assert isinstance(fn_outs[1].owner.op.scalar_op, TriGamma)
    assert isinstance(fn_outs[2].owner.op.scalar_op, PolyGamma)
@pytest.mark.skipif(
    config.mode == "FAST_COMPILE",
    reason="Rewrite is only relevant in FAST_RUN",
)
def test_local_batched_matmul_to_core_matmul():
    rng = np.random.default_rng(seed=4433)

    # x is batched but not y
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    x_test = rng.normal(size=(5, 3, 2))
    y_test = rng.normal(size=(2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)

    # y is batched but not x
    x = pt.tensor("x", shape=(1, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y
    assert isinstance(out.owner.op, Blockwise)

    fn = pytensor.function([x, y], out)
    assert not any(
        isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
    )

    x_test = rng.normal(size=(1, 3, 2))
    y_test = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)

    # Both x and y are batched, rewrite does not apply
    x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
    y = pt.tensor("y", shape=(5, 2, 2), dtype="float64")
    out = x @ y
    fn = pytensor.function([x, y], out)
    x_test = rng.normal(size=(5, 3, 2))
    y_test = rng.normal(size=(5, 2, 2))
    np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
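Outside the test suite, one way to see the rewrite fire is to print the optimized graph. A small sketch (assuming the default FAST_RUN mode; pytensor.dprint is the standard debug printer):

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 3, 2), dtype="float64")
y = pt.tensor("y", shape=(2, 2), dtype="float64")
fn = pytensor.function([x, y], x @ y)
# After the rewrite, the graph should contain reshapes around a core dot
# (or a BLAS specialization such as Dot22) instead of a Blockwise node.
pytensor.dprint(fn)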