提交 2751bcc6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Faster perform method for matmul

Also return matmul when vectorizing a matrix-matrix dot, to avoid creating redundant Blockwise Ops
上级 10c36d2a
...@@ -58,6 +58,7 @@ class Blockwise(Op): ...@@ -58,6 +58,7 @@ class Blockwise(Op):
core_op: Op, core_op: Op,
signature: Optional[str] = None, signature: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
gufunc_spec: Optional[tuple[str, int, int]] = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -69,7 +70,12 @@ class Blockwise(Op): ...@@ -69,7 +70,12 @@ class Blockwise(Op):
signature signature
Generalized universal function signature, Generalized universal function signature,
e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication
gufunc_spec: tuple, Optional
    Tuple containing:
    1. String import path for a numpy/scipy function (e.g., "numpy.matmul", "scipy.special.softmax")
       that implements the blockwise operation of the scalar op.
    2. Number of inputs of the function.
    3. Number of outputs of the function.
""" """
if isinstance(core_op, Blockwise): if isinstance(core_op, Blockwise):
raise TypeError("Core Op is already a Blockwise") raise TypeError("Core Op is already a Blockwise")
...@@ -85,6 +91,7 @@ class Blockwise(Op): ...@@ -85,6 +91,7 @@ class Blockwise(Op):
self.signature = signature self.signature = signature
self.name = name self.name = name
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None self._gufunc = None
super().__init__(**kwargs) super().__init__(**kwargs)
...@@ -297,10 +304,14 @@ class Blockwise(Op): ...@@ -297,10 +304,14 @@ class Blockwise(Op):
return rval return rval
def _create_gufunc(self, node): def _create_gufunc(self, node):
if hasattr(self.core_op, "gufunc_spec"): gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0])
if gufunc_spec is not None:
self._gufunc = import_func_from_string(gufunc_spec[0])
if self._gufunc: if self._gufunc:
return self._gufunc return self._gufunc
else:
raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
n_outs = len(self.outputs_sig) n_outs = len(self.outputs_sig)
core_node = self._create_dummy_core_node(node.inputs) core_node = self._create_dummy_core_node(node.inputs)
......
...@@ -9,6 +9,7 @@ from pytensor import scalar as aes ...@@ -9,6 +9,7 @@ from pytensor import scalar as aes
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.link.c.type import Generic from pytensor.link.c.type import Generic
...@@ -25,7 +26,7 @@ from pytensor.tensor.basic import ( ...@@ -25,7 +26,7 @@ from pytensor.tensor.basic import (
stack, stack,
switch, switch,
) )
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
from pytensor.tensor.shape import shape, specify_broadcastable from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.type import ( from pytensor.tensor.type import (
...@@ -2873,7 +2874,11 @@ def logsumexp(x, axis=None, keepdims=False): ...@@ -2873,7 +2874,11 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims)) return log(sum(exp(x), axis=axis, keepdims=keepdims))
# Batched matrix-matrix multiplication: a Blockwise-wrapped Dot whose
# gufunc_spec lets perform dispatch straight to numpy.matmul instead of
# looping over batch dimensions in Python.
_matrix_matrix_matmul = Blockwise(
    _dot,
    signature="(m,k),(k,n)->(m,n)",
    gufunc_spec=("numpy.matmul", 2, 1),
)
def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
...@@ -2937,6 +2942,15 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None ...@@ -2937,6 +2942,15 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
return out return out
@_vectorize_node.register(Dot)
def vectorize_node_to_matmul(op, node, batched_x, batched_y):
    """Vectorize a matrix-matrix ``Dot`` node directly as ``matmul``.

    When both original inputs are rank-2, build a ``matmul`` graph instead
    of wrapping the core ``Dot`` in a redundant ``Blockwise``; any other
    rank combination defers to the generic Blockwise fallback.
    """
    x, y = node.inputs
    # Guard clause: anything other than 2-D x 2-D takes the generic path.
    if x.type.ndim != 2 or y.type.ndim != 2:
        return vectorize_node_fallback(op, node, batched_x, batched_y)
    return matmul(batched_x, batched_y).owner
__all__ = [ __all__ = [
"max_and_argmax", "max_and_argmax",
"max", "max",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论