Commit 89c7544a authored by Ricardo Vieira, committed by Ricardo Vieira

Specialize matmul to batched dot

Parent 75485c13
...@@ -59,6 +59,8 @@ import time ...@@ -59,6 +59,8 @@ import time
import numpy as np import numpy as np
from pytensor.tensor.rewriting.basic import register_specialize
try: try:
import numpy.__config__ # noqa import numpy.__config__ # noqa
...@@ -79,12 +81,12 @@ from pytensor.graph.rewriting.basic import ( ...@@ -79,12 +81,12 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError from pytensor.graph.utils import InconsistencyError
from pytensor.printing import debugprint
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor.blas import ( from pytensor.tensor.blas import (
Dot22, Dot22,
_dot22, _dot22,
_dot22scalar, _dot22scalar,
batched_dot,
gemm_inplace, gemm_inplace,
gemm_no_inplace, gemm_no_inplace,
gemv_inplace, gemv_inplace,
...@@ -94,7 +96,7 @@ from pytensor.tensor.blas import ( ...@@ -94,7 +96,7 @@ from pytensor.tensor.blas import (
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, add, mul, neg, sub from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
from pytensor.tensor.type import ( from pytensor.tensor.type import (
DenseTensorType, DenseTensorType,
...@@ -899,9 +901,32 @@ blas_optdb.register( ...@@ -899,9 +901,32 @@ blas_optdb.register(
) )
@register_specialize
@node_rewriter([_matrix_matrix_matmul])
def specialize_matmul_to_batched_dot(fgraph, node):
    """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batch dimension as BatchedDot.

    BatchedDot is a specialized (BLAS-backed) Op for 3D x 3D matrix
    multiplication, so inputs with more than one batch dimension are
    raveled to a single batch dimension before the product and unraveled
    afterwards.

    TODO: Do the same for Blockwise BatchedDot

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph being rewritten (unused here, required by the
        node_rewriter protocol).
    node : Apply
        A matmul (Blockwise matrix-matrix) node with inputs ``x`` and ``y``.

    Returns
    -------
    list of Variable or None
        Single-element list with the BatchedDot-based replacement, or
        ``None`` when the rewrite does not apply.
    """
    x, y = node.inputs

    # BatchedDot does not allow implicit broadcasting of the batch dimensions
    # We do not want to explicitly broadcast as it may result in huge arrays
    if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
        return None

    # Symbolic shapes: used to ravel/unravel multiple batch dimensions
    x_shape = tuple(x.shape)
    y_shape = tuple(y.shape)
    if len(x_shape) > 3:
        # If we have more than one batch dim, ravel it
        x = x.reshape((-1, x_shape[-2], x_shape[-1]))
        y = y.reshape((-1, y_shape[-2], y_shape[-1]))
    new_out = batched_dot(x, y)
    if len(x_shape) > 3:
        # And then unravel it
        new_out = new_out.reshape((*x_shape[:-2], x_shape[-2], y_shape[-1]))
    copy_stack_trace(node.outputs, [new_out])
    return [new_out]
import numpy as np
import pytest
from pytensor import function
from pytensor.compile import get_default_mode
from pytensor.tensor import matmul, tensor, vectorize
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
@pytest.mark.parametrize("valid_case", (True, False))
def test_specialize_matmul_to_batched_dot(valid_case):
    """Check matmul is rewritten to BatchedDot only when batch dims match.

    In the valid case both inputs share the batch shape (7, 5), so the
    rewrite applies and no Blockwise node survives in the compiled graph.
    In the invalid case the batch dims require implicit broadcasting,
    so the Blockwise matmul must remain. Either way, the compiled
    function must agree numerically with NumPy's matmul.
    """
    signature = BatchedDot.gufunc_signature
    rewrite = specialize_matmul_to_batched_dot.__name__

    def core_pt(x, y):
        return matmul(x, y)

    def core_np(x, y):
        return np.matmul(x, y)

    x = tensor(shape=(7, 5, 3, 3))
    if valid_case:
        y = tensor(shape=(7, 5, 3, 3))
    else:
        # Mismatched batch dims (implicit broadcast): rewrite must not apply
        y = tensor(shape=(5, 3, 3))

    vectorize_pt = function(
        [x, y],
        vectorize(core_pt, signature=signature)(x, y),
        mode=get_default_mode().including(rewrite),
    )
    # Fixed typo: was `blocwkise_node`
    has_blockwise_node = any(
        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
    )
    if valid_case:
        assert not has_blockwise_node
    else:
        assert has_blockwise_node

    x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
    y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype)
    vectorize_np = np.vectorize(core_np, signature=signature)
    np.testing.assert_allclose(
        vectorize_pt(x_test, y_test),
        vectorize_np(x_test, y_test),
    )
...@@ -6,6 +6,7 @@ import pytest ...@@ -6,6 +6,7 @@ import pytest
import pytensor import pytensor
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
...@@ -13,6 +14,7 @@ from pytensor.raise_op import assert_op ...@@ -13,6 +14,7 @@ from pytensor.raise_op import assert_op
from pytensor.tensor import diagonal, log, tensor from pytensor.tensor import diagonal, log, tensor
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.nlinalg import MatrixInverse
from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
...@@ -45,7 +47,11 @@ def check_blockwise_runtime_broadcasting(mode): ...@@ -45,7 +47,11 @@ def check_blockwise_runtime_broadcasting(mode):
b = tensor("b", shape=(None, 5, 3)) b = tensor("b", shape=(None, 5, 3))
out = a @ b out = a @ b
fn = function([a, b], out, mode=mode) fn = function(
[a, b],
out,
mode=get_mode(mode).excluding(specialize_matmul_to_batched_dot.__name__),
)
assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
for valid_test_values in [ for valid_test_values in [
......
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment