提交 e9219159 authored 作者: Allen Downey's avatar Allen Downey 提交者: Ricardo Vieira

Implement broadcast for XTensorVariables

上级 e1ce1c35
......@@ -3,7 +3,7 @@ import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, random
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import concat
from pytensor.xtensor.shape import broadcast, concat
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
......
import pytensor.tensor as pt
from pytensor.graph import node_rewriter
from pytensor.tensor import (
broadcast_to,
......@@ -11,6 +12,7 @@ from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.rewriting.utils import lower_aligned
from pytensor.xtensor.shape import (
Broadcast,
Concat,
ExpandDims,
Squeeze,
......@@ -157,3 +159,61 @@ def lower_expand_dims(fgraph, node):
# Convert result back to xtensor
result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
return [result]
@register_lower_xtensor
@node_rewriter(tracks=[Broadcast])
def lower_broadcast(fgraph, node):
    """Rewrite a Broadcast node using plain tensor operations.

    Each input is first aligned to its own output's dims. When no dims are
    excluded, all inputs are broadcast together with `pt.broadcast_arrays`.
    Otherwise, a zeros "template" tensor with the shared broadcast dims is
    built, and each input is broadcast against it with `pt.second`.
    """
    excluded_dims = node.op.exclude
    tensor_inputs = [
        lower_aligned(inp, out.type.dims)
        for inp, out in zip(node.inputs, node.outputs, strict=True)
    ]

    if not excluded_dims:
        # Simple case: all dimensions are broadcasted
        tensor_outputs = pt.broadcast_arrays(*tensor_inputs)
    else:
        # Complex case: some dimensions are excluded from broadcasting.
        # Pick the first available dim_length for each broadcast dim.
        broadcast_dim_lengths = {
            d: None for d in node.outputs[0].type.dims if d not in excluded_dims
        }
        for xtensor_inp in node.inputs:
            for dim, dim_length in xtensor_inp.sizes.items():
                if dim in broadcast_dim_lengths and broadcast_dim_lengths[dim] is None:
                    # If the dimension is not excluded, set its length
                    broadcast_dim_lengths[dim] = dim_length
        assert not any(
            value is None for value in broadcast_dim_lengths.values()
        ), "All dimensions must have a length"

        # Create zeros with the broadcast dimensions, to then broadcast each input
        # against. PyTensor will rewrite this into using only the shapes of the
        # zeros tensor (the values are never needed).
        broadcast_template = pt.zeros(
            tuple(broadcast_dim_lengths.values()),
            dtype=node.outputs[0].type.dtype,
        )
        n_broadcast_dims = broadcast_template.ndim
        tensor_outputs = []
        for tensor_inp in tensor_inputs:
            n_excluded_dims = tensor_inp.type.ndim - n_broadcast_dims
            # Excluded dimensions are on the right side of the output tensor, so we
            # padright the template before broadcasting against it.
            # `pt.second(x, y)` is equivalent to `np.broadcast_arrays(x, y)[1]`
            tensor_outputs.append(
                pt.second(
                    pt.shape_padright(broadcast_template, n_excluded_dims),
                    tensor_inp,
                )
            )

    new_outs = [
        xtensor_from_tensor(out_tensor, dims=out.type.dims)
        for out_tensor, out in zip(tensor_outputs, node.outputs, strict=True)
    ]
    return new_outs
......@@ -13,7 +13,8 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape
class Stack(XOp):
......@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
x = Transpose(dims=tuple(target_dims))(x)
return x
class Broadcast(XOp):
    """Broadcast multiple XTensorVariables against each other."""

    __props__ = ("exclude",)

    def __init__(self, exclude: Sequence[str] = ()):
        self.exclude = tuple(exclude)

    def make_node(self, *inputs):
        xtensor_inputs = [as_xtensor(x) for x in inputs]

        # Dims shared across all inputs (minus excluded ones) with their
        # combined static lengths.
        dims_and_shape = combine_dims_and_shape(xtensor_inputs, exclude=self.exclude)
        shared_dims = tuple(dims_and_shape)
        shared_shape = tuple(dims_and_shape.values())

        # All outputs share a common dtype, following the usual upcast rules.
        out_dtype = upcast(*[x.type.dtype for x in xtensor_inputs])

        outputs = []
        for x in xtensor_inputs:
            # Each output keeps its own excluded dims (with their per-input
            # lengths), appended in the order they appear in the op argument.
            kept = [
                (d, x.type.shape[x.type.dims.index(d)])
                for d in self.exclude
                if d in x.type.dims
            ]
            outputs.append(
                xtensor(
                    dtype=out_dtype,
                    shape=shared_shape + tuple(length for _, length in kept),
                    dims=shared_dims + tuple(d for d, _ in kept),
                )
            )
        return Apply(self, xtensor_inputs, outputs)
def broadcast(
    *args, exclude: str | Sequence[str] | None = None
) -> tuple[XTensorVariable, ...]:
    """Broadcast any number of XTensorVariables against each other.

    Dimensions are aligned by name: each output carries the union of the
    non-excluded dimensions of all inputs, followed by its own excluded
    dimensions (in the order given in `exclude`).

    Parameters
    ----------
    *args : XTensorVariable
        The tensors to broadcast against each other.
    exclude : str or Sequence[str] or None, optional
        Names of dimensions excluded from broadcasting. Excluded dimensions
        keep their per-input lengths and need not match across inputs.

    Returns
    -------
    tuple of XTensorVariable
        One broadcast tensor per input, in the same order as `args`.

    Raises
    ------
    TypeError
        If `exclude` is neither None, a str, nor a Sequence.
    """
    if not args:
        # Mirror xarray: broadcasting nothing yields an empty tuple
        return ()

    if exclude is None:
        exclude = ()
    elif isinstance(exclude, str):
        # A single dim name is treated as a one-element sequence
        exclude = (exclude,)
    elif not isinstance(exclude, Sequence):
        raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")

    # xarray broadcast always returns a tuple, even if there's only one tensor
    return tuple(Broadcast(exclude=exclude)(*args, return_list=True))  # type: ignore
......@@ -736,6 +736,15 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return px.math.dot(self, other, dim=dim)
def broadcast(self, *others, exclude=None):
"""Broadcast this tensor against other XTensorVariables."""
return px.shape.broadcast(self, *others, exclude=exclude)
def broadcast_like(self, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable."""
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
return self_bcast
class XTensorConstantSignature(TensorConstantSignature):
pass
......
from collections.abc import Sequence
from itertools import chain
import numpy as np
......@@ -13,13 +14,22 @@ from pytensor.tensor.utils import (
get_static_shape_from_size_variables,
)
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
def combine_dims_and_shape(inputs):
def combine_dims_and_shape(
inputs: Sequence[XTensorVariable], exclude: Sequence[str] | None = None
) -> dict[str, int | None]:
"""Combine information of static dimensions and shapes from multiple xtensor inputs.
Exclude
"""
exclude_set: set[str] = set() if exclude is None else set(exclude)
dims_and_shape: dict[str, int | None] = {}
for inp in inputs:
for dim, dim_length in zip(inp.type.dims, inp.type.shape):
if dim in exclude_set:
continue
if dim not in dims_and_shape:
dims_and_shape[dim] = dim_length
elif dim_length is not None:
......
......@@ -9,10 +9,12 @@ from itertools import chain, combinations
import numpy as np
from xarray import DataArray
from xarray import broadcast as xr_broadcast
from xarray import concat as xr_concat
from pytensor.tensor import scalar
from pytensor.xtensor.shape import (
broadcast,
concat,
stack,
unstack,
......@@ -466,3 +468,168 @@ def test_expand_dims_errors():
# Test with a numpy array as dim (not supported)
with pytest.raises(TypeError, match="unhashable type"):
y.expand_dims(np.array([1, 2]))
class TestBroadcast:
    """Tests for `broadcast` / `broadcast_like`, checked for parity against xarray."""

    @pytest.mark.parametrize(
        "exclude",
        [
            None,
            [],
            ["b"],
            ["b", "d"],
            ["a", "d"],
            ["b", "c", "d"],
            ["a", "b", "c", "d"],
        ],
    )
    def test_compatible_excluded_shapes(self, exclude):
        """Three tensors whose shared dims match; results must equal xarray's."""
        # Create test data
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)

        # Test with excluded dims
        x2_expected, y2_expected, z2_expected = xr_broadcast(
            x_test, y_test, z_test, exclude=exclude
        )
        x2, y2, z2 = broadcast(x, y, z, exclude=exclude)
        fn = xr_function([x, y, z], [x2, y2, z2])
        x2_result, y2_result, z2_result = fn(x_test, y_test, z_test)

        xr_assert_allclose(x2_result, x2_expected)
        xr_assert_allclose(y2_result, y2_expected)
        xr_assert_allclose(z2_result, z2_expected)

    def test_incompatible_excluded_shapes(self):
        """Excluded dims may legitimately differ in length across inputs."""
        # Test that excluded dims are allowed to be different sizes
        # (here "d" is 6 in y but 7 in z)
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
        out = broadcast(x, y, z, exclude=["d"])
        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)
        fn = xr_function([x, y, z], out)
        results = fn(x_test, y_test, z_test)
        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=["d"])
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    @pytest.mark.parametrize("exclude", [[], ["b"], ["b", "c"], ["a", "b", "d"]])
    def test_runtime_shapes(self, exclude):
        """Broadcast works when some static shapes are unknown (None) at build time."""
        x = xtensor("x", dims=("a", "b"), shape=(None, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, None))
        z = xtensor("z", dims=("b", "d"), shape=(None, None))
        out = broadcast(x, y, z, exclude=exclude)
        # Concrete test values supply the lengths left unknown above
        x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4)))
        y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6)))
        z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6)))
        fn = xr_function([x, y, z], out)
        results = fn(x_test, y_test, z_test)
        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=exclude)
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

        # Test invalid shape raises an error
        # Note: We might decide not to raise an error in the lowered graphs for performance reasons
        if "d" not in exclude:
            z_test_bad = xr_arange_like(xtensor(dims=z.dims, shape=(4, 7)))
            with pytest.raises(Exception):
                fn(x_test, y_test, z_test_bad)

    def test_broadcast_excluded_dims_in_different_order(self):
        """Test broadcasting excluded dims are aligned with user input."""
        x = xtensor("x", dims=("a", "c", "b"), shape=(3, 4, 5))
        y = xtensor("y", dims=("a", "b", "c"), shape=(3, 5, 4))
        # Double assignment: keep the tuple for xr_function and the unpacked
        # outputs for the dim-order assertions below
        out = (out_x, out_y) = broadcast(x, y, exclude=["c", "b"])
        assert out_x.type.dims == ("a", "c", "b")
        assert out_y.type.dims == ("a", "c", "b")

        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        fn = xr_function([x, y], out)
        results = fn(x_test, y_test)
        expected_results = xr_broadcast(x_test, y_test, exclude=["c", "b"])
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    def test_broadcast_errors(self):
        """Test error handling in broadcast."""
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
        with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
            broadcast(x, y, z, exclude=1)

        # Test with conflicting shapes ("d" is 6 in y but 7 in z, not excluded)
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
        with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
            broadcast(x, y, z)

    def test_broadcast_no_input(self):
        """Broadcasting nothing returns an empty tuple, same as xarray."""
        assert broadcast() == xr_broadcast()
        assert broadcast(exclude=("a",)) == xr_broadcast(exclude=("a",))

    def test_broadcast_single_input(self):
        """Test broadcasting a single input."""
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        # Broadcast with a single input can still imply a transpose via the exclude parameter
        outs = [
            *broadcast(x),
            *broadcast(x, exclude=("a", "b")),
            *broadcast(x, exclude=("b", "a")),
            *broadcast(x, exclude=("b",)),
        ]
        fn = xr_function([x], outs)
        x_test = xr_arange_like(x)
        results = fn(x_test)
        expected_results = [
            *xr_broadcast(x_test),
            *xr_broadcast(x_test, exclude=("a", "b")),
            *xr_broadcast(x_test, exclude=("b", "a")),
            *xr_broadcast(x_test, exclude=("b",)),
        ]
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    @pytest.mark.parametrize("exclude", [None, ["b"], ["b", "c"]])
    def test_broadcast_like(self, exclude):
        """Test broadcast_like method"""
        # Create test data
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
        # Order matters so we test both orders
        outs = [
            x.broadcast_like(y, exclude=exclude),
            y.broadcast_like(x, exclude=exclude),
            y.broadcast_like(z, exclude=exclude),
            z.broadcast_like(y, exclude=exclude),
        ]
        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)
        fn = xr_function([x, y, z], outs)
        results = fn(x_test, y_test, z_test)
        expected_results = [
            x_test.broadcast_like(y_test, exclude=exclude),
            y_test.broadcast_like(x_test, exclude=exclude),
            y_test.broadcast_like(z_test, exclude=exclude),
            z_test.broadcast_like(y_test, exclude=exclude),
        ]
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论