Commit 89c7544a authored by Ricardo Vieira, committed by Ricardo Vieira

Specialize matmul to batched dot

Parent 75485c13
......@@ -59,6 +59,8 @@ import time
import numpy as np
from pytensor.tensor.rewriting.basic import register_specialize
try:
import numpy.__config__ # noqa
......@@ -79,12 +81,12 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError
from pytensor.printing import debugprint
from pytensor.tensor import basic as at
from pytensor.tensor.blas import (
Dot22,
_dot22,
_dot22scalar,
batched_dot,
gemm_inplace,
gemm_no_inplace,
gemv_inplace,
......@@ -94,7 +96,7 @@ from pytensor.tensor.blas import (
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, add, mul, neg, sub
from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.type import (
DenseTensorType,
......@@ -899,9 +901,32 @@ blas_optdb.register(
)
# from opt import register_specialize, register_canonicalize
# @register_specialize
@node_rewriter([sub, add])
def local_print_as_we_go_along(fgraph, node):
    """Debugging rewriter: pretty-print every ``sub``/``add`` node visited.

    Purely a side-effect helper — it never returns a replacement (implicitly
    returns ``None``), so the graph is left untouched.  Left unregistered
    (``register_specialize`` decorator commented out) and only enabled by hand
    when tracing rewrites.
    """
    if node.op not in (sub, add):
        return
    debugprint(node)
@register_specialize
@node_rewriter([_matrix_matrix_matmul])
def specialize_matmul_to_batched_dot(fgraph, node):
    """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot.

    Returns ``[new_output]`` on success, or ``None`` to signal that the
    rewrite does not apply (the node_rewriter protocol).

    TODO: Do the same for Blockwise BatchedDot
    """
    # Inputs of the blockwise matmul: (batched) matrices with matching
    # core shapes (..., m, k) @ (..., k, n) — assumed from the
    # _matrix_matrix_matmul gufunc signature; TODO confirm rank >= 3 here.
    x, y = node.inputs

    # BatchedDot does not allow implicit broadcasting of the batch dimensions
    # We do not want to explicitly broadcast as it may result in huge arrays
    if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
        return None

    # Symbolic (not static) shapes — safe even when dims are unknown at
    # compile time.
    x_shape = tuple(x.shape)
    y_shape = tuple(y.shape)
    if len(x_shape) > 3:
        # If we have more than one batch dim, ravel it
        # (BatchedDot only handles a single leading batch dimension).
        x = x.reshape((-1, x_shape[-2], x_shape[-1]))
        y = y.reshape((-1, y_shape[-2], y_shape[-1]))
    new_out = batched_dot(x, y)
    if len(x_shape) > 3:
        # And then unravel it
        new_out = new_out.reshape((*x_shape[:-2], x_shape[-2], y_shape[-1]))
    # Preserve debugging info from the replaced output on the new graph.
    copy_stack_trace(node.outputs, [new_out])
    return [new_out]
import numpy as np
import pytest
from pytensor import function
from pytensor.compile import get_default_mode
from pytensor.tensor import matmul, tensor, vectorize
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
@pytest.mark.parametrize("valid_case", (True, False))
def test_specialize_matmul_to_batched_dot(valid_case):
    """Check that matmul is rewritten to BatchedDot exactly when batch dims match.

    valid_case=True: both operands have identical batch dimensions (7, 5), so
    the rewrite should fire and no Blockwise node should remain in the
    compiled graph.  valid_case=False: the batch dims would need implicit
    broadcasting ((7, 5) vs (5,)), so the rewrite must refuse and a Blockwise
    node must survive.  In both cases the compiled function must agree
    numerically with NumPy.
    """
    signature = BatchedDot.gufunc_signature
    rewrite = specialize_matmul_to_batched_dot.__name__

    def core_pt(x, y):
        return matmul(x, y)

    def core_np(x, y):
        return np.matmul(x, y)

    x = tensor(shape=(7, 5, 3, 3))
    if valid_case:
        y = tensor(shape=(7, 5, 3, 3))
    else:
        y = tensor(shape=(5, 3, 3))

    vectorize_pt = function(
        [x, y],
        vectorize(core_pt, signature=signature)(x, y),
        mode=get_default_mode().including(rewrite),
    )
    # Fixed typo: was "blocwkise_node".
    has_blockwise_node = any(
        isinstance(node.op, Blockwise)
        for node in vectorize_pt.maker.fgraph.apply_nodes
    )
    if valid_case:
        assert not has_blockwise_node
    else:
        assert has_blockwise_node

    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
    y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
    vectorize_np = np.vectorize(core_np, signature=signature)
    np.testing.assert_allclose(
        vectorize_pt(x_test, y_test),
        vectorize_np(x_test, y_test),
    )
......@@ -6,6 +6,7 @@ import pytest
import pytensor
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.gradient import grad
from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node
......@@ -13,6 +14,7 @@ from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.utils import _parse_gufunc_signature
......@@ -45,7 +47,11 @@ def check_blockwise_runtime_broadcasting(mode):
b = tensor("b", shape=(None, 5, 3))
out = a @ b
fn = function([a, b], out, mode=mode)
fn = function(
[a, b],
out,
mode=get_mode(mode).excluding(specialize_matmul_to_batched_dot.__name__),
)
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
for valid_test_values in [
......
Markdown formatting supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment