Unverified commit 6f8bb555 authored by Jesse Grabowski, committed by GitHub

Optimize matmuls involving block diagonal matrices (#1493)

* Add `concat_with_broadcast` helper function; use the new helper in xt.concat
* block_diag dot rewrite

Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
Parent d4e8f736
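For context, a minimal NumPy sketch (illustrative only, not part of the commit; `scipy.linalg.block_diag` stands in for `pt.linalg.block_diag`) of the identity the new rewrite exploits: multiplying by a block-diagonal matrix is equivalent to multiplying each block against the matching row-slice of the other operand and concatenating the results.

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 2))  # first block
B = rng.normal(size=(2, 4))  # second block
C = rng.normal(size=(6, 5))  # 6 rows = 2 + 4 columns of block_diag(A, B)

full = block_diag(A, B) @ C                             # one large dot
split = np.concatenate([A @ C[:2], B @ C[2:]], axis=0)  # two small dots
np.testing.assert_allclose(full, split)
```

The rewrite below applies the same factoring symbolically, which also avoids ever materializing the zero blocks of the block-diagonal matrix.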
@@ -27,7 +27,7 @@ from pytensor.scalar import int64 as int_t
 from pytensor.scalar import upcast
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor import basic as ptb
-from pytensor.tensor.basic import alloc, second
+from pytensor.tensor.basic import alloc, join, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -2018,6 +2018,42 @@ def broadcast_arrays(*args: TensorVariable) -> tuple[TensorVariable, ...]:
     return brodacasted_vars
+
+
+def concat_with_broadcast(tensor_list, axis=0):
+    """
+    Concatenate a list of tensors, broadcasting the non-concatenated dimensions to align.
+    """
+    if not tensor_list:
+        raise ValueError("Cannot concatenate an empty list of tensors.")
+
+    ndim = tensor_list[0].ndim
+    if not all(t.ndim == ndim for t in tensor_list):
+        raise TypeError(
+            "Only tensors with the same number of dimensions can be concatenated. "
+            f"Input ndims were: {[x.ndim for x in tensor_list]}"
+        )
+
+    axis = normalize_axis_index(axis=axis, ndim=ndim)
+    non_concat_shape = [1 if i != axis else None for i in range(ndim)]
+    for tensor_inp in tensor_list:
+        for i, (bcast, sh) in enumerate(
+            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
+        ):
+            if bcast or i == axis:
+                continue
+            non_concat_shape[i] = sh
+
+    assert non_concat_shape.count(None) == 1
+
+    bcast_tensor_inputs = []
+    for tensor_inp in tensor_list:
+        # We modify the concat_axis in place, as we don't need the list anywhere else
+        non_concat_shape[axis] = tensor_inp.shape[axis]
+        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
+
+    return join(axis, *bcast_tensor_inputs)
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
@@ -2035,6 +2071,7 @@ __all__ = [
     "ravel_multi_index",
     "broadcast_shape",
     "broadcast_to",
+    "concat_with_broadcast",
     "geomspace",
     "logspace",
     "linspace",
...
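The new helper behaves like `join`, except that broadcastable non-concatenated dimensions are first broadcast to a common shape (plain `join` would generally fail on mismatched non-concat dimensions). A hedged usage sketch, mirroring the test added at the end of this diff:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.tensor("a", shape=(1, 3, 5))  # axis 0 is broadcastable
b = pt.tensor("b", shape=(5, 3, 10))

# a is broadcast to (5, 3, 5) before being joined with b along axis 2
c = pt.concat_with_broadcast([a, b], axis=2)
assert c.type.shape == (5, 3, 15)

fn = pytensor.function([a, b], c)
out = fn(np.ones((1, 3, 5), dtype=c.dtype), np.ones((5, 3, 10), dtype=c.dtype))
assert out.shape == (5, 3, 15)
```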
@@ -32,18 +32,21 @@ from pytensor.tensor.basic import (
     moveaxis,
     ones_like,
     register_infer_shape,
+    split,
     switch,
     zeros,
     zeros_like,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_arrays
+from pytensor.tensor.extra_ops import broadcast_arrays, concat_with_broadcast
 from pytensor.tensor.math import (
     Dot,
     Prod,
     Sum,
     _conj,
+    _dot,
     _matmul,
     add,
     digamma,
@@ -96,6 +99,7 @@ from pytensor.tensor.rewriting.basic import (
 from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.rewriting.linalg import is_matrix_transpose
 from pytensor.tensor.shape import Shape, Shape_i
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
     complex_dtypes,
@@ -146,6 +150,68 @@ def local_0_dot_x(fgraph, node):
     return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
+
+
+@register_stabilize
+@node_rewriter([Blockwise])
+def local_block_diag_dot_to_dot_block_diag(fgraph, node):
+    r"""
+    Perform the rewrite ``dot(block_diag(A, B), C) -> concat(dot(A, C), dot(B, C))``
+
+    BlockDiag results in the creation of a matrix of shape ``(n1 + n2, m1 + m2)``. Because dot has complexity
+    of approximately O(n^3), it's always better to perform two dot products on the smaller matrices rather than
+    a single dot on the larger matrix.
+    """
+    if not isinstance(node.op.core_op, BlockDiagonal):
+        return
+
+    # Check that the BlockDiagonal is an input to a Dot node:
+    for client in itertools.chain.from_iterable(
+        get_clients_at_depth(fgraph, node, depth=i) for i in [1, 2]
+    ):
+        if client.op not in (_dot, _matmul):
+            continue
+
+        [blockdiag_result] = node.outputs
+        blockdiag_inputs = node.inputs
+        dot_op = client.op
+
+        try:
+            client_idx = client.inputs.index(blockdiag_result)
+        except ValueError:
+            # If the blockdiag result is not an input to the dot, there is at least one Op between them (usually a
+            # DimShuffle). In this case, we need to figure out which of the inputs of the dot eventually leads to the
+            # blockdiag result.
+            for ancestor in client.inputs:
+                if ancestor.owner and blockdiag_result in ancestor.owner.inputs:
+                    client_idx = client.inputs.index(ancestor)
+                    break
+
+        other_input = client.inputs[1 - client_idx]
+
+        split_axis = -2 if client_idx == 0 else -1
+        split_size_axis = -1 if client_idx == 0 else -2
+        other_dot_input_split = split(
+            other_input,
+            splits_size=[
+                component.shape[split_size_axis] for component in blockdiag_inputs
+            ],
+            n_splits=len(blockdiag_inputs),
+            axis=split_axis,
+        )
+
+        split_dot_results = [
+            dot_op(component, other_split)
+            if client_idx == 0
+            else dot_op(other_split, component)
+            for component, other_split in zip(blockdiag_inputs, other_dot_input_split)
+        ]
+
+        new_output = concat_with_broadcast(split_dot_results, axis=split_axis)
+        copy_stack_trace(node.outputs[0], new_output)
+
+        return {client.outputs[0]: new_output}
+
+
 @register_canonicalize
 @node_rewriter([Dot, _matmul])
 def local_lift_transpose_through_dot(fgraph, node):
@@ -2582,7 +2648,6 @@ add_canonizer = in2out(
     name="add_canonizer_group",
 )
 register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer")
@@ -3720,7 +3785,6 @@ logdiffexp_to_log1mexpdiff = PatternNodeRewriter(
 )
 register_stabilize(logdiffexp_to_log1mexpdiff, name="logdiffexp_to_log1mexpdiff")
 # log(sigmoid(x) / (1 - sigmoid(x))) -> x
 # i.e logit(sigmoid(x)) -> x
 local_logit_sigmoid = PatternNodeRewriter(
@@ -3734,7 +3798,6 @@ local_logit_sigmoid = PatternNodeRewriter(
 register_canonicalize(local_logit_sigmoid)
 register_specialize(local_logit_sigmoid)
 # sigmoid(log(x / (1-x)) -> x
 # i.e., sigmoid(logit(x)) -> x
 local_sigmoid_logit = PatternNodeRewriter(
@@ -3775,7 +3838,6 @@ local_polygamma_to_tri_gamma = PatternNodeRewriter(
 register_specialize(local_polygamma_to_tri_gamma)
 local_log_kv = PatternNodeRewriter(
     # Rewrite log(kv(v, x)) = log(kve(v, x) * exp(-x)) -> log(kve(v, x)) - x
     # During stabilize -x is converted to -1.0 * x
...
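To make the docstring's complexity argument concrete, here is a back-of-the-envelope multiply-add count (the block sizes are arbitrary, chosen only for illustration):

```python
# Cost of dot(block_diag(A, B), C) with square n1 x n1 and n2 x n2 blocks
# and C of shape (n1 + n2, k): the dense product needs n^2 * k multiply-adds,
# the split form only (n1^2 + n2^2) * k, and it never materializes the zeros.
n1, n2, k = 400, 600, 100
n = n1 + n2

dense_cost = n * n * k                   # (n, n) @ (n, k)
split_cost = n1 * n1 * k + n2 * n2 * k   # (n1, n1) @ (n1, k) and (n2, n2) @ (n2, k)

print(dense_cost / split_cost)  # ~1.9x fewer multiply-adds in this example
```

The saving is largest when the block-diagonal matrix is built from many comparably sized blocks.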
@@ -2,8 +2,8 @@ import pytensor.tensor as pt
 from pytensor.graph import node_rewriter
 from pytensor.tensor import (
     broadcast_to,
+    concat_with_broadcast,
     expand_dims,
-    join,
     moveaxis,
     specify_shape,
     squeeze,
@@ -74,28 +74,7 @@ def lower_concat(fgraph, node):
     # Convert input XTensors to Tensors and align batch dimensions
     tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

-    # Broadcast non-concatenated dimensions of each input
-    non_concat_shape = [None] * len(out_dims)
-    for tensor_inp in tensor_inputs:
-        # TODO: This is assuming the graph is correct and every non-concat dimension matches in shape at runtime
-        # I'm running this as "shape_unsafe" to simplify the logic / returned graph
-        for i, (bcast, sh) in enumerate(
-            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
-        ):
-            if bcast or i == concat_axis or non_concat_shape[i] is not None:
-                continue
-            non_concat_shape[i] = sh
-
-    assert non_concat_shape.count(None) == 1
-
-    bcast_tensor_inputs = []
-    for tensor_inp in tensor_inputs:
-        # We modify the concat_axis in place, as we don't need the list anywhere else
-        non_concat_shape[concat_axis] = tensor_inp.shape[concat_axis]
-        bcast_tensor_inputs.append(broadcast_to(tensor_inp, non_concat_shape))
-
-    joined_tensor = join(concat_axis, *bcast_tensor_inputs)
+    joined_tensor = concat_with_broadcast(tensor_inputs, axis=concat_axis)
     new_out = xtensor_from_tensor(joined_tensor, dims=out_dims)
     return [new_out]
...
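The xtensor lowering now delegates to the helper instead of broadcasting and joining by hand. A quick equivalence sketch of the old and new formulations, using plain tensor ops and hypothetical shapes:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor import broadcast_to, join

a = pt.tensor("a", shape=(1, 4))
b = pt.tensor("b", shape=(3, 2))

old = join(1, broadcast_to(a, (3, 4)), b)       # manual broadcast + join
new = pt.concat_with_broadcast([a, b], axis=1)  # the new helper

fn = pytensor.function([a, b], [old, new])
rng = np.random.default_rng(0)
a_val = rng.normal(size=(1, 4)).astype(a.dtype)
b_val = rng.normal(size=(3, 2)).astype(b.dtype)
out_old, out_new = fn(a_val, b_val)
np.testing.assert_allclose(out_old, out_new)
```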
@@ -115,6 +115,7 @@ from pytensor.tensor.rewriting.math import (
     simplify_mul,
 )
 from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
+from pytensor.tensor.slinalg import BlockDiagonal
 from pytensor.tensor.type import (
     TensorType,
     cmatrix,
@@ -4745,3 +4746,121 @@ def test_local_dot_to_mul(batched, a_shape, b_shape):
         out.eval({a: a_test, b: b_test}, mode=test_mode),
         rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
     )
+
+
+@pytest.mark.parametrize("left_multiply", [True, False], ids=["left", "right"])
+@pytest.mark.parametrize(
+    "batch_blockdiag", [True, False], ids=["batch_blockdiag", "unbatched_blockdiag"]
+)
+@pytest.mark.parametrize(
+    "batch_other", [True, False], ids=["batched_other", "unbatched_other"]
+)
+def test_local_block_diag_dot_to_dot_block_diag(
+    left_multiply, batch_blockdiag, batch_other
+):
+    """
+    Test that dot(block_diag(x, y), z) is rewritten to concat(dot(x, z[:n]), dot(y, z[n:]))
+    """
+
+    def has_blockdiag(graph):
+        return any(
+            (
+                var.owner
+                and (
+                    isinstance(var.owner.op, BlockDiagonal)
+                    or (
+                        isinstance(var.owner.op, Blockwise)
+                        and isinstance(var.owner.op.core_op, BlockDiagonal)
+                    )
+                )
+            )
+            for var in ancestors([graph])
+        )
+
+    a = tensor("a", shape=(4, 2))
+    b = tensor("b", shape=(2, 4) if not batch_blockdiag else (3, 2, 4))
+    c = tensor("c", shape=(4, 4))
+    x = pt.linalg.block_diag(a, b, c)
+
+    d = tensor("d", shape=(10, 10) if not batch_other else (3, 1, 10, 10))
+
+    # Test multiple clients are all rewritten
+    if left_multiply:
+        out = x @ d
+    else:
+        out = d @ x
+
+    assert has_blockdiag(out)
+    fn = pytensor.function([a, b, c, d], out, mode=rewrite_mode)
+    assert not has_blockdiag(fn.maker.fgraph.outputs[0])
+
+    n_dots_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn.maker.fgraph.apply_nodes
+    )
+    assert n_dots_rewrite == 3
+
+    fn_expected = pytensor.function(
+        [a, b, c, d],
+        out,
+        mode=Mode(linker="py", optimizer=None),
+    )
+    assert has_blockdiag(fn_expected.maker.fgraph.outputs[0])
+
+    n_dots_no_rewrite = sum(
+        isinstance(node.op, Dot | Dot22)
+        or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Dot | Dot22))
+        for node in fn_expected.maker.fgraph.apply_nodes
+    )
+    assert n_dots_no_rewrite == 1
+
+    rng = np.random.default_rng()
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    rewrite_out = fn(a_val, b_val, c_val, d_val)
+    expected_out = fn_expected(a_val, b_val, c_val, d_val)
+    np.testing.assert_allclose(
+        rewrite_out,
+        expected_out,
+        atol=1e-6 if config.floatX == "float32" else 1e-12,
+        rtol=1e-6 if config.floatX == "float32" else 1e-12,
+    )
+
+
+@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
+@pytest.mark.parametrize("size", [10, 100, 1000], ids=["small", "medium", "large"])
+def test_block_diag_dot_to_dot_concat_benchmark(benchmark, size, rewrite):
+    rng = np.random.default_rng()
+    a_size = int(rng.uniform(1, int(0.8 * size)))
+    b_size = int(rng.uniform(1, int(0.8 * (size - a_size))))
+    c_size = size - a_size - b_size
+
+    a = tensor("a", shape=(a_size, a_size))
+    b = tensor("b", shape=(b_size, b_size))
+    c = tensor("c", shape=(c_size, c_size))
+    d = tensor("d", shape=(size,))
+
+    x = pt.linalg.block_diag(a, b, c)
+    out = x @ d
+
+    mode = get_default_mode()
+    if not rewrite:
+        mode = mode.excluding("local_block_diag_dot_to_dot_block_diag")
+    fn = pytensor.function([a, b, c, d], out, mode=mode)
+
+    a_val = rng.normal(size=a.type.shape).astype(a.type.dtype)
+    b_val = rng.normal(size=b.type.shape).astype(b.type.dtype)
+    c_val = rng.normal(size=c.type.shape).astype(c.type.dtype)
+    d_val = rng.normal(size=d.type.shape).astype(d.type.dtype)
+
+    benchmark(
+        fn,
+        a_val,
+        b_val,
+        c_val,
+        d_val,
+    )
...
@@ -1333,3 +1333,48 @@ def test_space_ops(op, dtype, start, stop, num_samples, endpoint, axis):
         atol=1e-6 if config.floatX.endswith("64") else 1e-4,
         rtol=1e-6 if config.floatX.endswith("64") else 1e-4,
     )
+
+
+def test_concat_with_broadcast():
+    rng = np.random.default_rng()
+    a = pt.tensor("a", shape=(1, 3, 5))
+    b = pt.tensor("b", shape=(5, 3, 10))
+    c = pt.concat_with_broadcast([a, b], axis=2)
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    assert c.type.shape == (5, 3, 15)
+
+    a_val = rng.normal(size=(1, 3, 5)).astype(config.floatX)
+    b_val = rng.normal(size=(5, 3, 10)).astype(config.floatX)
+    c_val = fn(a_val, b_val)
+
+    # The result should be a tile + concat
+    np.testing.assert_allclose(c_val[:, :, :5], np.tile(a_val, (5, 1, 1)))
+    np.testing.assert_allclose(c_val[:, :, 5:], b_val)
+
+    # If a and b already conform, the result should be the same as a concatenation
+    a = pt.tensor("a", shape=(1, 1, 3, 5, 10))
+    b = pt.tensor("b", shape=(1, 1, 3, 2, 10))
+    c = pt.concat_with_broadcast([a, b], axis=-2)
+    assert c.type.shape == (1, 1, 3, 7, 10)
+
+    fn = function([a, b], c, mode="FAST_COMPILE")
+    a_val = rng.normal(size=(1, 1, 3, 5, 10)).astype(config.floatX)
+    b_val = rng.normal(size=(1, 1, 3, 2, 10)).astype(config.floatX)
+    c_val = fn(a_val, b_val)
+    np.testing.assert_allclose(c_val, np.concatenate([a_val, b_val], axis=-2))
+
+    c = pt.concat_with_broadcast([a], axis=0)
+    fn = function([a], c, mode="FAST_COMPILE")
+    np.testing.assert_allclose(fn(a_val), a_val)
+
+    with pytest.raises(ValueError, match="Cannot concatenate an empty list of tensors"):
+        pt.concat_with_broadcast([], axis=0)
+
+    with pytest.raises(
+        TypeError,
+        match="Only tensors with the same number of dimensions can be concatenated.",
+    ):
+        a = pt.tensor("a", shape=(1, 3, 5))
+        b = pt.tensor("b", shape=(3, 5))
+        pt.concat_with_broadcast([a, b], axis=1)