提交 911c6a33 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rewrite batched dots that do not reduce as multiplication

上级 f86a0dc1
......@@ -29,7 +29,7 @@ from pytensor.tensor.basic import (
stack,
switch,
)
from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import (
CAReduce,
Elemwise,
......@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))
# Predefine all batched variations of Dot
_inner_prod = Blockwise(
_dot,
signature="(n),(n)->()",
)
_matrix_vec_prod = Blockwise(
_dot,
signature="(m,k),(k)->(m)",
)
_vec_matrix_prod = Blockwise(
_dot,
signature="(k),(k,n)->(n)",
)
_matrix_matrix_matmul = Blockwise(
_dot,
signature="(m,k),(k,n)->(m,n)",
......@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
@_vectorize_node.register(Dot)
def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
def vectorize_node_dot(op, node, batched_x, batched_y):
old_x, old_y = node.inputs
if old_x.type.ndim == 2 and old_y.type.ndim == 2:
# If original input is equivalent to a matrix-matrix product,
# return specialized Matmul Op to avoid unnecessary new Ops.
return matmul(batched_x, batched_y).owner
else:
return vectorize_node_fallback(op, node, batched_x, batched_y)
old_x_ndim = old_x.type.ndim
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
case (1, 1):
batch_op = _inner_prod
case (2, 1):
batch_op = _matrix_vec_prod
case (1, 2):
batch_op = _vec_matrix_prod
case (2, 2):
batch_op = _matrix_matrix_matmul
case _:
raise ValueError(
f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
)
return batch_op(batched_x, batched_y).owner
def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
......
......@@ -44,6 +44,10 @@ from pytensor.tensor.math import (
Prod,
Sum,
_conj,
_inner_prod,
_matrix_matrix_matmul,
_matrix_vec_prod,
_vec_matrix_prod,
add,
digamma,
dot,
......@@ -242,6 +246,62 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
return None
@register_canonicalize
@register_specialize
@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
def local_blockwise_dot_to_mul(fgraph, node):
"""Rewrite blockwise dots that correspond to multiplication without summation.
We don't touch the regular dot, to not interfere with the BLAS optimizations.
"""
a, b = node.inputs
a_static_shape = a.type.shape
b_static_shape = b.type.shape
core_a_ndim = len(node.op.inputs_sig[0])
core_b_ndim = len(node.op.inputs_sig[1])
if core_a_ndim > 2 or core_b_ndim > 2:
# Shouldn't happen, but here just in case
return None
if core_b_ndim == 1:
if a_static_shape[-1] == 1 or b_static_shape[-1] == 1:
if core_a_ndim == 1:
# inner product: (..., 1) * (..., 1) -> (...)
# just squeeze the last dimensions of a and b
new_a = a.squeeze(-1)
new_b = b.squeeze(-1)
else:
# matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
# the last dimension of b is already aligned for the elemwise multiplication
# after we squeeze the last dimension of a
new_a = a.squeeze(-1)
new_b = b
else:
return None
else:
if a_static_shape[-1] == 1 or b_static_shape[-2] == 1:
if core_a_ndim == 1:
# vector_matrix product: (..., 1) * (..., 1, n) -> (..., n)
# the last dimension of a is already aligned for the elemwise multiplication
# after we squeeze the one to last dimension of b
new_a = a
new_b = b.squeeze(-2)
else:
# matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
# the dimensions of a and b are already aligned for the elemwise multiplication
new_a = a
new_b = b
else:
return None
new_a = copy_stack_trace(a, new_a)
new_b = copy_stack_trace(b, new_b)
new_out = copy_stack_trace(node.out, mul(new_a, new_b))
return [new_out]
def is_inverse_pair(node_op, prev_op, inv_pair):
"""
Given two consecutive operations, check if they are the
......
......@@ -16,7 +16,8 @@ from pytensor.compile.function import function
from pytensor.compile.mode import Mode, get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, equal_computations
from pytensor.graph import vectorize_graph
from pytensor.graph.basic import Apply, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import (
SequentialNodeRewriter,
......@@ -4590,3 +4591,53 @@ def test_pow_1_rewrite(shape):
x_val = np.random.random(shape).astype(config.floatX)
np.testing.assert_allclose(z.eval({x: x_val}), f(x_val))
@pytest.mark.parametrize(
"a_shape,b_shape",
[
((1,), (1,)),
((3, 1), (1,)),
((1,), (1, 3)),
((3, 1), (1, 3)),
],
ids=str,
)
@pytest.mark.parametrize("batched", (False, True))
def test_local_dot_to_mul(batched, a_shape, b_shape):
a = tensor("a", shape=a_shape)
b = tensor("b", shape=b_shape)
out = dot(a, b)
if batched:
batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
out = vectorize_graph(out, {a: batch_a, b: batch_b})
a = batch_a
b = batch_b
assert (
sum(
isinstance(var.owner.op, (Blockwise | Dot))
for var in ancestors([out])
if var.owner
)
== 1
)
# For now rewrite only applies to Batched Dots
rewritten_out = rewrite_graph(out)
assert rewritten_out.type.shape == out.type.shape
assert sum(
isinstance(var.owner.op, (Blockwise | Dot))
for var in ancestors([rewritten_out])
if var.owner
) == (0 if batched else 1)
a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
test_mode = Mode(linker="py", optimizer=None)
np.testing.assert_allclose(
out.eval({a: a_test, b: b_test}, mode=test_mode),
rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论