提交 2751bcc6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Faster perform method for matmul

Also return matmul for respective vectorize of dot, to avoid creating redundant Blockwise Ops
上级 10c36d2a
......@@ -58,6 +58,7 @@ class Blockwise(Op):
core_op: Op,
signature: Optional[str] = None,
name: Optional[str] = None,
gufunc_spec: Optional[tuple[str, int, int]] = None,
**kwargs,
):
"""
......@@ -69,7 +70,12 @@ class Blockwise(Op):
signature
Generalized universal function signature,
e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication
gufunc_spec: tuple, Optional
    Tuple containing:
    1. String import path for a numpy/scipy function (e.g., "numpy.matmul", "scipy.special.softmax")
    that implements the blockwise operation of the core op.
    2. Number of inputs of the function
    3. Number of outputs of the function
"""
if isinstance(core_op, Blockwise):
raise TypeError("Core Op is already a Blockwise")
......@@ -85,6 +91,7 @@ class Blockwise(Op):
self.signature = signature
self.name = name
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
super().__init__(**kwargs)
......@@ -297,10 +304,14 @@ class Blockwise(Op):
return rval
def _create_gufunc(self, node):
if hasattr(self.core_op, "gufunc_spec"):
self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0])
gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
if gufunc_spec is not None:
self._gufunc = import_func_from_string(gufunc_spec[0])
if self._gufunc:
return self._gufunc
else:
raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
n_outs = len(self.outputs_sig)
core_node = self._create_dummy_core_node(node.inputs)
......
......@@ -9,6 +9,7 @@ from pytensor import scalar as aes
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import Generic
......@@ -25,7 +26,7 @@ from pytensor.tensor.basic import (
stack,
switch,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.type import (
......@@ -2873,7 +2874,11 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))
_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
# Batched matrix-matrix product: Blockwise over the core 2-D dot, with a
# gufunc_spec so the perform method can dispatch straight to numpy.matmul.
_matrix_matrix_matmul = Blockwise(
    core_op=_dot,
    signature="(m,k),(k,n)->(m,n)",
    gufunc_spec=("numpy.matmul", 2, 1),
)
def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
......@@ -2937,6 +2942,15 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
return out
@_vectorize_node.register(Dot)
def vectorize_node_to_matmul(op, node, batched_x, batched_y):
    """Vectorize a ``Dot`` node.

    When the original node is a plain matrix-matrix product (both inputs
    2-D), return a ``matmul`` node instead of wrapping ``Dot`` in a new
    Blockwise; otherwise defer to the generic Blockwise fallback.
    """
    x, y = node.inputs
    # Guard: anything other than 2-D x 2-D goes through the generic path.
    if x.type.ndim != 2 or y.type.ndim != 2:
        return vectorize_node_fallback(op, node, batched_x, batched_y)
    return matmul(batched_x, batched_y).owner
__all__ = [
"max_and_argmax",
"max",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论