Unverified commit 6f8bb555 authored by Jesse Grabowski, committed by GitHub

Optimize matmuls involving block diagonal matrices (#1493)

* Add `concat_with_broadcast` helper function; use new helper in xt.concat (Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>)
* block_diag dot rewrite (Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>)

---------

Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
Parent d4e8f736
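The idea behind the rewrite, as a minimal NumPy sketch (shapes below are illustrative, not taken from the PR): multiplying by a block-diagonal matrix is equivalent to multiplying each block by the matching slice of the other operand and concatenating the results.

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 2))   # first block
B = rng.normal(size=(3, 5))   # second block
C = rng.normal(size=(7, 6))   # 7 rows == 2 + 5 columns of block_diag(A, B)

# One big dot on the (7, 7) block-diagonal matrix...
full = block_diag(A, B) @ C
# ...equals two small dots on the blocks, concatenated along the rows.
split = np.concatenate([A @ C[:2], B @ C[2:]], axis=0)
np.testing.assert_allclose(full, split)
```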
......@@ -27,7 +27,7 @@ from pytensor.scalar import int64 as int_t
from pytensor.scalar import upcast
from pytensor.tensor import TensorLike, as_tensor_variable
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import alloc, second
from pytensor.tensor.basic import alloc, join, second
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
......@@ -2018,6 +2018,42 @@ def broadcast_arrays(*args: TensorVariable) -> tuple[TensorVariable, ...]:
    return brodacasted_vars
def concat_with_broadcast(tensor_list, axis=0):
    """
    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
    """
    if not tensor_list:
        raise ValueError("Cannot concatenate an empty list of tensors.")

    ndim = tensor_list[0].ndim
    if not all(t.ndim == ndim for t in tensor_list):
        raise TypeError(
            "Only tensors with the same number of dimensions can be concatenated. "
            f"Input ndims were: {[x.ndim for x in tensor_list]}"
        )

    axis = normalize_axis_index(axis=axis, ndim=ndim)
    non_concat_shape = [1 if i != axis else None for i in range(ndim)]
    for tensor_inp in tensor_list:
        for i, (bcast, sh) in enumerate(
            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
        ):
            if bcast or i == axis:
                continue
            non_concat_shape[i] = sh

    assert non_concat_shape.count(None) == 1

    bcast_tensor_inputs = []
    for tensor_inp in tensor_list:
        # We modify the concat_axis in place, as we don't need the list anywhere else
        non_concat_shape[axis] = tensor_inp.shape[axis]
        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

    return join(axis, *bcast_tensor_inputs)

__all__ = [
    "searchsorted",
    "cumsum",
......@@ -2035,6 +2071,7 @@ __all__ = [
    "ravel_multi_index",
    "broadcast_shape",
    "broadcast_to",
    "concat_with_broadcast",
    "geomspace",
    "logspace",
    "linspace",
......
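For orientation, a small usage sketch of the new helper (shapes are illustrative and mirror the behaviour exercised in the tests below): inputs that are broadcastable along a non-concatenated dimension are broadcast to match before joining.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.tensor("a", shape=(1, 3, 5))    # broadcastable along the first axis
b = pt.tensor("b", shape=(5, 3, 10))
c = pt.concat_with_broadcast([a, b], axis=2)  # a is broadcast to (5, 3, 5) before the join

fn = pytensor.function([a, b], c)
out = fn(np.ones((1, 3, 5), dtype=a.dtype), np.ones((5, 3, 10), dtype=b.dtype))
assert out.shape == (5, 3, 15)
```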
......@@ -32,18 +32,21 @@ from pytensor.tensor.basic import (
    moveaxis,
    ones_like,
    register_infer_shape,
    split,
    switch,
    zeros,
    zeros_like,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
from pytensor.tensor.math import (
    Dot,
    Prod,
    Sum,
    _conj,
    _dot,
    _matmul,
    add,
    digamma,
......@@ -96,6 +99,7 @@ from pytensor.tensor.rewriting.basic import (
from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import (
    complex_dtypes,
......@@ -146,6 +150,68 @@ def local_0_dot_x(fgraph, node):
return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
@register_stabilize
@node_rewriter([Blockwise])
def local_block_diag_dot_to_dot_block_diag(fgraph, node):
    r"""
    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``

    BlockDiag results in the creation of a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity
    of approximately O(n^3), it's always better to perform two dot products on the smaller matrices, rather than
    a single dot on the larger matrix.
    """
    if not isinstance(node.op.core_op, BlockDiagonal):
        return

    # Check that the BlockDiagonal is an input to a Dot node:
    for client in itertools.chain.from_iterable(
        get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
    ):
        if client.op not in (_dot, _matmul):
            continue

        [blockdiag_result] = node.outputs
        blockdiag_inputs = node.inputs
        dot_op = client.op

        try:
            client_idx = client.inputs.index(blockdiag_result)
        except ValueError:
            # If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
            # DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
            # blockdiag result.
            for ancestor in client.inputs:
                if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
                    client_idx = client.inputs.index(ancestor)
                    break

        other_input = client.inputs[1 - client_idx]

        split_axis = -2 if client_idx == 0 else -1
        split_size_axis = -1 if client_idx == 0 else -2
        other_dot_input_split = split(
            other_input,
            splits_size=[
                component.shape[split_size_axis] for component in blockdiag_inputs
            ],
            n_splits=len(blockdiag_inputs),
            axis=split_axis,
        )

        split_dot_results = [
            dot_op(component, other_split)
            if client_idx == 0
            else dot_op(other_split, component)
            for component, other_split in zip(blockdiag_inputs, other_dot_input_split)
        ]
        new_output = concat_with_broadcast(split_dot_results, axis=split_axis)
        copy_stack_trace(node.outputs[0], new_output)

        return {client.outputs[0]: new_output}

@register_canonicalize
@node_rewriter([Dot, _matmul])
def local_lift_transpose_through_dot(fgraph, node):
......@@ -2582,7 +2648,6 @@ add_canonizer = in2out(
name="add_canonizer_group",
)
register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
......@@ -3720,7 +3785,6 @@ logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
)
register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
# log(sigmoid(x) / (1 - sigmoid(x))) -> x
# i.e logit(sigmoid(x)) -> x
local_logit_sigmoid = PatternNodeRewriter(
......@@ -3734,7 +3798,6 @@ local_logit_sigmoid = PatternNodeRewriter(
register_canonicalize(local_logit_sigmoid)
register_specialize(local_logit_sigmoid)
# sigmoid(log(x / (1-x)) -> x
# i.e., sigmoid(logit(x)) -> x
local_sigmoid_logit = PatternNodeRewriter(
......@@ -3775,7 +3838,6 @@ local_polygamma_to_tri_gamma = PatternNodeRewriter(
register_specialize(local_polygamma_to_tri_gamma)
local_log_kv = PatternNodeRewriter(
    # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
    # During stabilize -x is converted to -1.0 * x
......
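A hedged sketch of the rewrite in action (variable names and shapes are made up for illustration): with the rewrite enabled, the compiled graph should contain two small per-block dots instead of one dot on the assembled block-diagonal matrix, which can be checked by inspecting the graph with `pytensor.dprint`.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.tensor("A", shape=(3, 3))
B = pt.tensor("B", shape=(5, 5))
C = pt.tensor("C", shape=(8, 4))      # 8 == 3 + 5

out = pt.linalg.block_diag(A, B) @ C
fn = pytensor.function([A, B, C], out)

# Inspect the rewritten graph with pytensor.dprint(fn); the (8, 8)
# block-diagonal matrix should no longer be assembled before the dot.
rng = np.random.default_rng(0)
vals = [rng.normal(size=v.type.shape).astype(v.dtype) for v in (A, B, C)]
assert fn(*vals).shape == (8, 4)
```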
......@@ -2,8 +2,8 @@ import pytensor.tensor as pt
from pytensor.graph import node_rewriter
from pytensor.tensor import (
    broadcast_to,
    concat_with_broadcast,
    expand_dims,
    join,
    moveaxis,
    specify_shape,
    squeeze,
......@@ -74,28 +74,7 @@ def lower_concat(fgraph, node):
    # Convert input XTensors to Tensors and align batch dimensions
    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

    # Broadcast non-concatenated dimensions of each input
    non_concat_shape = [None] * len(out_dims)
    for tensor_inp in tensor_inputs:
        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
        for i, (bcast, sh) in enumerate(
            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
        ):
            if bcast or i == concat_axis or non_concat_shape[i] is not None:
                continue
            non_concat_shape[i] = sh

    assert non_concat_shape.count(None) == 1

    bcast_tensor_inputs = []
    for tensor_inp in tensor_inputs:
        # We modify the concat_axis in place, as we don't need the list anywhere else
        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))

    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
    joined_tensor = concat_with_broadcast(tensor_inputs, axis=concat_axis)
    new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
    return [new_out]
......
......@@ -115,6 +115,7 @@ from pytensor.tensor.rewriting.math import (
    simplify_mul,
)
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
from pytensor.tensor.slinalg import BlockDiagonal
from pytensor.tensor.type import (
    TensorType,
    cmatrix,
......@@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
        out.eval({a: a_test, b: b_test}, mode=test_mode),
        rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
    )

@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
@pytest.mark.parametrize(
"batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
)
@pytest.mark.parametrize(
"batch_other", [True, False], ids=["batched_other", "unbatched_other"]
)
def test_local_block_diag_dot_to_dot_block_diag(
left_multiply, batch_blockdiag, batch_other
):
"""
Test that dot(block_diag(x, y,), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
"""
def has_blockdiag(graph):
return any(
(
var.owner
and (
isinstance(var.owner.op, BlockDiagonal)
or (
isinstance(var.owner.op, Blockwise)
and isinstance(var.owner.op.core_op, BlockDiagonal)
)
)
)
for var in ancestors([graph])
)
a = tensor("a", shape=(4, 2))
b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
c = tensor("c", shape=(4, 4))
x = pt.linalg.block_diag(a, b, c)
d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
# Test multiple clients are all rewritten
if left_multiply:
out = x @ d
else:
out = d @ x
assert has_blockdiag(out)
fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
assert not has_blockdiag(fn.maker.fgraph.outputs[0])
n_dots_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
for node in fn.maker.fgraph.apply_nodes
)
assert n_dots_rewrite == 3
fn_expected = pytensor.function(
[a, b, c, d],
out,
mode=Mode(linker="py", optimizer=None),
)
assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
n_dots_no_rewrite = sum(
isinstance(node.op, Dot | Dot22)
or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
for node in fn_expected.maker.fgraph.apply_nodes
)
assert n_dots_no_rewrite == 1
rng = np.random.default_rng()
a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
rewrite_out = fn(a_val, b_val, c_val, d_val)
expected_out = fn_expected(a_val, b_val, c_val, d_val)
np.testing.assert_allclose(
rewrite_out,
expected_out,
atol=1e-6 if config.floatX == "float32" else 1e-12,
rtol=1e-6 if config.floatX == "float32" else 1e-12,
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
rng = np.random.default_rng()
a_size = int(rng.uniform(1, int(0.8 * size)))
b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
c_size = size - a_size - b_size
a = tensor("a", shape=(a_size, a_size))
b = tensor("b", shape=(b_size, b_size))
c = tensor("c", shape=(c_size, c_size))
d = tensor("d", shape=(size,))
x = pt.linalg.block_diag(a, b, c)
out = x @ d
mode = get_default_mode()
if not rewrite:
mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
fn = pytensor.function([a, b, c, d], out, mode=mode)
a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
benchmark(
fn,
a_val,
b_val,
c_val,
d_val,
)
......@@ -1333,3 +1333,48 @@ def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
        atol=1e-6 if config.floatX.endswith("64") else 1e-4,
        rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
    )

def test_concat_with_broadcast():
    rng = np.random.default_rng()
    a = pt.tensor("a", shape=(1, 3, 5))
    b = pt.tensor("b", shape=(5, 3, 10))
    c = pt.concat_with_broadcast([a, b], axis=2)
    fn = function([a, b], c, mode="FAST_COMPILE")

    assert c.type.shape == (5, 3, 15)

    a_val = rng.normal(size=(1, 3, 5)).astype(config.floatX)
    b_val = rng.normal(size=(5, 3, 10)).astype(config.floatX)
    c_val = fn(a_val, b_val)

    # The result should be a tile + concat
    np.testing.assert_allclose(c_val[:, :, :5], np.tile(a_val, (5, 1, 1)))
    np.testing.assert_allclose(c_val[:, :, 5:], b_val)

    # If a and b already conform, the result should be the same as a concatenation
    a = pt.tensor("a", shape=(1, 1, 3, 5, 10))
    b = pt.tensor("b", shape=(1, 1, 3, 2, 10))
    c = pt.concat_with_broadcast([a, b], axis=-2)
    assert c.type.shape == (1, 1, 3, 7, 10)

    fn = function([a, b], c, mode="FAST_COMPILE")
    a_val = rng.normal(size=(1, 1, 3, 5, 10)).astype(config.floatX)
    b_val = rng.normal(size=(1, 1, 3, 2, 10)).astype(config.floatX)
    c_val = fn(a_val, b_val)
    np.testing.assert_allclose(c_val, np.concatenate([a_val, b_val], axis=-2))

    c = pt.concat_with_broadcast([a], axis=0)
    fn = function([a], c, mode="FAST_COMPILE")
    np.testing.assert_allclose(fn(a_val), a_val)

    with pytest.raises(ValueError, match="Cannot concatenate an empty list of tensors"):
        pt.concat_with_broadcast([], axis=0)

    with pytest.raises(
        TypeError,
        match="Only tensors with the same number of dimensions can be concatenated.",
    ):
        a = pt.tensor("a", shape=(1, 3, 5))
        b = pt.tensor("b", shape=(3, 5))
        pt.concat_with_broadcast([a, b], axis=1)