提交 071eadd8 作者: Ricardo Vieira 提交者: Ricardo Vieira

Remove Matmul Operator in favor of Blockwise Dot

上级 7c58661b
...@@ -25,11 +25,11 @@ from pytensor.tensor.basic import ( ...@@ -25,11 +25,11 @@ from pytensor.tensor.basic import (
stack, stack,
switch, switch,
) )
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
from pytensor.tensor.shape import shape, specify_broadcastable from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.type import ( from pytensor.tensor.type import (
DenseTensorType, DenseTensorType,
TensorType,
complex_dtypes, complex_dtypes,
continuous_dtypes, continuous_dtypes,
discrete_dtypes, discrete_dtypes,
...@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False): ...@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims)) return log(sum(exp(x), axis=axis, keepdims=keepdims))
class MatMul(Op): _matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
__props__ = ("dtype",)
def __init__(self, dtype=None):
self.dtype = dtype
@classmethod
def _get_output_shape(cls, x1, x2, shapes, validate=False):
x1_shape, x2_shape = shapes
if x1.ndim == 1 and x2.ndim == 1:
if validate and x1_shape[0] != x2_shape[0]:
raise ValueError("1d inputs must have the same length.")
return ()
elif x1.ndim == 1 and x2.ndim > 1:
if validate and x1_shape[0] != x2_shape[-2]:
raise ValueError(
"length of input 1 must be equal the length "
"of the 2nd-last dimension of input 2"
)
return x2_shape[:-2] + x2_shape[-1:]
elif x1.ndim > 1 and x2.ndim == 1:
if validate and x1_shape[-1] != x2_shape[0]:
raise ValueError(
"length of input 2 must be equal the length "
"of the last dimension of input 1"
)
return x1_shape[:-1]
elif x1.ndim == 2 and x2.ndim == 2:
if validate and x1_shape[-1] != x2_shape[0]:
raise ValueError(
"number of columns of input 1 must be equal to "
"the number of rows of input 2"
)
return x1_shape[:-1] + x2_shape[-1:]
elif x1.ndim > 2 and x2.ndim == 2:
if validate and x1_shape[-1] != x2_shape[0]:
raise ValueError(
"number of rows of input 2 must be equal to "
"the length of the last dimension of input 1"
)
return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
elif x1.ndim == 2 and x2.ndim > 2:
if validate and x1_shape[-1] != x2_shape[-2]:
raise ValueError(
"number of columns of input 1 must be equal "
"the length of the 2nd-last dimension of input 2"
)
return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
else:
if validate:
from pytensor.tensor.random.basic import broadcast_shapes
bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2])
if x1_shape[-1] != x2_shape[-2]:
raise ValueError(
"length of the last dimension of input 1 must be equal "
"to the length of the 2nd-last dimension of input 2"
)
else:
from pytensor.tensor.extra_ops import broadcast_shape
bshape = broadcast_shape(
x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True
)
return bshape + x1_shape[-2:-1] + x2_shape[-1:]
def make_node(self, a, b):
a = as_tensor_variable(a)
b = as_tensor_variable(b)
if 0 in {a.ndim, b.ndim}:
raise ValueError("inputs to `matmul` cannot be scalar.")
out_shape = self._get_output_shape(
a, b, (a.type.shape, b.type.shape), validate=True
)
out = TensorType(dtype=self.dtype, shape=out_shape)()
return Apply(self, [a, b], [out])
def perform(self, node, inputs, outputs):
x1, x2 = inputs
outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)
def infer_shape(self, fgraph, node, shapes):
x1, x2 = node.inputs
return [self._get_output_shape(x1, x2, shapes)]
def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
...@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None ...@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
- Stacks of matrices are broadcast together as if the matrices were elements, - Stacks of matrices are broadcast together as if the matrices were elements,
respecting the signature ``(n, k), (k, m) -> (n, m)``: respecting the signature ``(n, k), (k, m) -> (n, m)``:
""" """
return MatMul(dtype=dtype)(x1, x2) x1 = as_tensor_variable(x1)
x2 = as_tensor_variable(x2)
if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("matmul operand cannot be scalar")
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
elif x1.type.ndim == 1:
out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
elif x2.type.ndim == 1:
out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
else:
out = _matrix_matrix_matmul(x1, x2)
if dtype is not None:
out = out.astype(dtype)
return out
__all__ = [ __all__ = [
......
...@@ -3,6 +3,8 @@ from pytensor.graph import node_rewriter ...@@ -3,6 +3,8 @@ from pytensor.graph import node_rewriter
from pytensor.graph.replace import vectorize_node from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import _matrix_matrix_matmul
from pytensor.tensor.rewriting.basic import register_canonicalize
@node_rewriter([Blockwise]) @node_rewriter([Blockwise])
...@@ -40,3 +42,10 @@ optdb.register( ...@@ -40,3 +42,10 @@ optdb.register(
"blockwise", "blockwise",
position=49, position=49,
) )
# Registered under canonicalize so that redundant cases are removed early on
# for Ops whose default form is not Blockwised.
@register_canonicalize
@node_rewriter(tracks=[_matrix_matrix_matmul])
def local_eager_useless_unbatched_blockwise(fgraph, node):
    """Eagerly run ``local_useless_unbatched_blockwise`` on tracked matmul nodes.

    This rewriter only tracks ``_matrix_matrix_matmul`` and contains no logic
    of its own: it forwards ``fgraph`` and ``node`` to the wrapped function of
    ``local_useless_unbatched_blockwise`` and returns whatever that produces.
    """
    # NOTE(review): the delegate is defined outside this view; presumably it
    # unwraps a Blockwise whose inputs carry no batch dimensions -- confirm.
    delegate = local_useless_unbatched_blockwise.fn
    return delegate(fgraph, node)
...@@ -647,8 +647,12 @@ class _tensor_py_operators: ...@@ -647,8 +647,12 @@ class _tensor_py_operators:
return at.math.dense_dot(left, right) return at.math.dense_dot(left, right)
dot = __dot__ dot = __dot__
__matmul__ = __dot__
__rmatmul__ = __rdot__ def __matmul__(left, right):
return at.math.matmul(left, right)
def __rmatmul__(right, left):
    """Reflected matrix multiplication: build the graph for ``left @ right``.

    Python calls this method on the right-hand operand when the left-hand
    operand does not handle ``@`` itself, so the instance is bound to
    ``right`` and the other operand arrives as ``left``.

    Parameters
    ----------
    right
        The instance, i.e. the right-hand operand of ``@``.
    left
        The left-hand operand of ``@``.

    Returns
    -------
    The result of ``at.math.matmul(left, right)``.
    """
    # Matrix multiplication is not commutative: the operands must be passed in
    # their mathematical order, not in the order of this method's parameters.
    return at.math.matmul(left, right)
def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None): def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See :func:`pytensor.tensor.math.sum`.""" """See :func:`pytensor.tensor.math.sum`."""
...@@ -797,7 +801,7 @@ class _tensor_py_operators: ...@@ -797,7 +801,7 @@ class _tensor_py_operators:
""" """
return at.basic.choose(self, choices, mode="raise") return at.basic.choose(self, choices, mode="raise")
def squeeze(self): def squeeze(self, axis=None):
""" """
Remove broadcastable dimensions from the shape of an array. Remove broadcastable dimensions from the shape of an array.
...@@ -805,7 +809,7 @@ class _tensor_py_operators: ...@@ -805,7 +809,7 @@ class _tensor_py_operators:
removed. This is always `x` itself or a view into `x`. removed. This is always `x` itself or a view into `x`.
""" """
return at.extra_ops.squeeze(self) return at.extra_ops.squeeze(self, axis=axis)
def compress(self, a, axis=None): def compress(self, a, axis=None):
"""Return selected slices only.""" """Return selected slices only."""
......
...@@ -30,11 +30,11 @@ from pytensor.tensor.basic import ( ...@@ -30,11 +30,11 @@ from pytensor.tensor.basic import (
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
switch, switch,
) )
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import CAReduce, Elemwise from pytensor.tensor.elemwise import CAReduce, Elemwise
from pytensor.tensor.math import ( from pytensor.tensor.math import (
Argmax, Argmax,
Dot, Dot,
MatMul,
MaxAndArgmax, MaxAndArgmax,
Mean, Mean,
Prod, Prod,
...@@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim(): ...@@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim():
assert grad_x_fn(-1e-308) != -np.inf assert grad_x_fn(-1e-308) != -np.inf
class TestMatMul(utt.InferShapeTester): class TestMatMul:
def setup_method(self): def setup_method(self):
super().setup_method()
self.rng = np.random.default_rng(utt.fetch_seed()) self.rng = np.random.default_rng(utt.fetch_seed())
self.op = matmul self.op = matmul
self.op_class = MatMul
def _validate_output(self, a, b): def _validate_output(self, a, b):
pytensor_sol = self.op(a, b).eval() pytensor_sol = self.op(a, b).eval()
...@@ -3467,85 +3465,8 @@ class TestMatMul(utt.InferShapeTester): ...@@ -3467,85 +3465,8 @@ class TestMatMul(utt.InferShapeTester):
sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype) sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
assert sol.eval().dtype == dtype assert sol.eval().dtype == dtype
@pytest.mark.parametrize( def test_dot22_opt(self):
"x1_shape,x2_shape,exp_res,error_regex", x, y = matrices("xy")
[ fn = function([x, y], x @ y, mode="FAST_RUN")
((1,), (3,), None, "inputs must have the same length"), [node] = fn.maker.fgraph.apply_nodes
((2,), (3, 1), None, "length of input 1.*2nd-last dimension of input 2"), assert isinstance(node.op, Dot22)
((2, 5), (3,), None, "length of input 2.*of the last dimension of input 1"),
(
(2, 5),
(3, 4),
None,
"number of columns of input 1 .* number of rows of input 2",
),
(
(2, 1, 3),
(5, 4),
None,
"number of rows of input 2 .* last dimension of input 1",
),
(
(2, 5),
(2, 4, 3),
None,
"number of columns of input 1 .* 2nd-last dimension of input 2",
),
(
(3, 2, 4, 5),
(1, 6, 7),
None,
"length of the last dimension of input 1 .* 2nd-last dimension of input 2",
),
(
(4, 5, 4),
(3, 2, 2),
None,
"cannot be broadcast to a single shape",
),
(
(4, None, 2),
(4, 2, None),
(4, None, None),
None,
),
],
)
def test_get_output_shape(self, x1_shape, x2_shape, exp_res, error_regex):
x1 = tensor(dtype=np.float64, shape=x1_shape)
x2 = tensor(dtype=np.float64, shape=x2_shape)
if error_regex is not None:
with pytest.raises(ValueError, match=error_regex):
self.op_class._get_output_shape(
x1, x2, (x1_shape, x2_shape), validate=True
)
else:
assert (
self.op_class._get_output_shape(
x1, x2, (x1_shape, x2_shape), validate=True
)
== exp_res
)
def test_infer_shape(self):
for shape_x1, shape_x2 in [
((5,), (5,)),
((5,), (2, 5, 3)),
((2, 5, 3), (3,)),
((2, 5), (5, 4)),
((2, 5), (2, 5, 3)),
((2, 1, 3), (3, 4)),
((3, 2, 4, 5), (1, 5, 7)),
]:
a = tensor(dtype=config.floatX, shape=shape_x1)
b = tensor(dtype=config.floatX, shape=shape_x2)
x1 = self.rng.random(shape_x1).astype(config.floatX)
x2 = self.rng.random(shape_x2).astype(config.floatX)
self._compile_and_check(
[a, b],
[self.op(a, b)],
[x1, x2],
self.op_class,
)
...@@ -10,9 +10,9 @@ from pytensor.compile import DeepCopyOp ...@@ -10,9 +10,9 @@ from pytensor.compile import DeepCopyOp
from pytensor.compile.mode import get_default_mode from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import constant from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.type import ( from pytensor.tensor.type import (
...@@ -79,16 +79,30 @@ def test_infix_dot_method(): ...@@ -79,16 +79,30 @@ def test_infix_dot_method():
X = dmatrix("X") X = dmatrix("X")
y = dvector("y") y = dvector("y")
res = X @ y res = X.dot(y)
exp_res = X.dot(y) exp_res = dot(X, y)
assert equal_computations([res], [exp_res]) assert equal_computations([res], [exp_res])
X_val = np.arange(2 * 3).reshape((2, 3)) X_val = np.arange(2 * 3).reshape((2, 3))
res = X_val @ y res = as_tensor(X_val).dot(y)
exp_res = dot(X_val, y) exp_res = dot(X_val, y)
assert equal_computations([res], [exp_res]) assert equal_computations([res], [exp_res])
def test_infix_matmul_method():
    """Check that the ``@`` operator builds the same graph as ``matmul``."""
    mat = dmatrix("X")
    vec = dvector("y")

    # Symbolic matrix on the left-hand side of the operator.
    assert equal_computations([mat @ vec], [matmul(mat, vec)])

    # NumPy values: the infix form goes through ``as_tensor`` first, while
    # ``matmul`` is handed the raw array.
    mat_val = np.arange(2 * 3).reshape((2, 3))
    infix_graph = as_tensor(mat_val) @ vec
    direct_graph = matmul(mat_val, vec)
    assert equal_computations([infix_graph], [direct_graph])
def test_empty_list_indexing(): def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []] ynp = np.zeros((2, 2))[:, []]
znp = np.zeros((2, 2))[:, ()] znp = np.zeros((2, 2))[:, ()]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论