提交 e9219159 authored 作者: Allen Downey's avatar Allen Downey 提交者: Ricardo Vieira

Implement broadcast for XTensorVariables

上级 e1ce1c35
...@@ -3,7 +3,7 @@ import warnings ...@@ -3,7 +3,7 @@ import warnings
import pytensor.xtensor.rewriting import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg, random from pytensor.xtensor import linalg, random
from pytensor.xtensor.math import dot from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import concat from pytensor.xtensor.shape import broadcast, concat
from pytensor.xtensor.type import ( from pytensor.xtensor.type import (
as_xtensor, as_xtensor,
xtensor, xtensor,
......
import pytensor.tensor as pt
from pytensor.graph import node_rewriter from pytensor.graph import node_rewriter
from pytensor.tensor import ( from pytensor.tensor import (
broadcast_to, broadcast_to,
...@@ -11,6 +12,7 @@ from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor ...@@ -11,6 +12,7 @@ from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.rewriting.utils import lower_aligned from pytensor.xtensor.rewriting.utils import lower_aligned
from pytensor.xtensor.shape import ( from pytensor.xtensor.shape import (
Broadcast,
Concat, Concat,
ExpandDims, ExpandDims,
Squeeze, Squeeze,
...@@ -157,3 +159,61 @@ def lower_expand_dims(fgraph, node): ...@@ -157,3 +159,61 @@ def lower_expand_dims(fgraph, node):
# Convert result back to xtensor # Convert result back to xtensor
result = xtensor_from_tensor(result_tensor, dims=out.type.dims) result = xtensor_from_tensor(result_tensor, dims=out.type.dims)
return [result] return [result]
@register_lower_xtensor
@node_rewriter(tracks=[Broadcast])
def lower_broadcast(fgraph, node):
    """Rewrite XBroadcast using tensor operations.

    Aligns each input to its output dims, then either broadcasts all
    inputs against each other (no excluded dims) or broadcasts each
    input only over the shared (non-excluded) dims via a zeros template.
    """
    excluded_dims = node.op.exclude
    tensor_inputs = [
        lower_aligned(inp, out.type.dims)
        for inp, out in zip(node.inputs, node.outputs, strict=True)
    ]

    if not excluded_dims:
        # Simple case: all dimensions are broadcasted
        tensor_outputs = pt.broadcast_arrays(*tensor_inputs)
    else:
        # Complex case: some dimensions are excluded from broadcasting.
        # Pick the first dimension_length for each broadcast dim.
        broadcast_dim_lengths = {
            d: None for d in node.outputs[0].type.dims if d not in excluded_dims
        }
        for xtensor_inp in node.inputs:
            for dim, dim_length in xtensor_inp.sizes.items():
                # Only record a length the first time a broadcast dim is seen
                if dim in broadcast_dim_lengths and broadcast_dim_lengths[dim] is None:
                    broadcast_dim_lengths[dim] = dim_length
        assert not any(
            value is None for value in broadcast_dim_lengths.values()
        ), "All dimensions must have a length"

        # Create zeros with the broadcast dimensions, to then broadcast each input against.
        # PyTensor will rewrite into using only the shapes of the zeros tensor.
        broadcast_template = pt.zeros(
            tuple(broadcast_dim_lengths.values()),
            dtype=node.outputs[0].type.dtype,
        )
        n_broadcast_dims = broadcast_template.ndim
        tensor_outputs = []
        for tensor_inp in tensor_inputs:
            n_excluded_dims = tensor_inp.type.ndim - n_broadcast_dims
            # Excluded dimensions are on the right side of the output tensor,
            # so we padright the template. `pt.second(x, y)` is equivalent to
            # `np.broadcast_arrays(x, y)[1]` in PyTensor.
            tensor_outputs.append(
                pt.second(
                    pt.shape_padright(broadcast_template, n_excluded_dims),
                    tensor_inp,
                )
            )

    new_outs = [
        xtensor_from_tensor(out_tensor, dims=out.type.dims)
        for out_tensor, out in zip(tensor_outputs, node.outputs, strict=True)
    ]
    return new_outs
...@@ -13,7 +13,8 @@ from pytensor.tensor.exceptions import NotScalarConstantError ...@@ -13,7 +13,8 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.utils import get_static_shape_from_size_variables from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.xtensor.basic import XOp from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
from pytensor.xtensor.vectorization import combine_dims_and_shape
class Stack(XOp): class Stack(XOp):
...@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa ...@@ -504,3 +505,63 @@ def expand_dims(x, dim=None, create_index_for_new_dim=None, axis=None, **dim_kwa
x = Transpose(dims=tuple(target_dims))(x) x = Transpose(dims=tuple(target_dims))(x)
return x return x
class Broadcast(XOp):
    """Broadcast multiple XTensorVariables against each other."""

    __props__ = ("exclude",)

    def __init__(self, exclude: Sequence[str] = ()):
        self.exclude = tuple(exclude)

    def make_node(self, *inputs):
        inputs = [as_xtensor(inp) for inp in inputs]

        # Shared (broadcast) dims/shape combined across all inputs,
        # skipping the excluded dims entirely.
        dims_and_shape = combine_dims_and_shape(inputs, exclude=self.exclude)
        shared_dims = tuple(dims_and_shape)
        shared_shape = tuple(dims_and_shape.values())
        out_dtype = upcast(*(inp.type.dtype for inp in inputs))

        outputs = []
        for inp in inputs:
            dim_sizes = dict(zip(inp.type.dims, inp.type.shape, strict=True))
            # Excluded dims keep each input's own length and appear in the
            # order given in the op argument, after the shared dims.
            kept_excluded = tuple(d for d in self.exclude if d in dim_sizes)
            outputs.append(
                xtensor(
                    dtype=out_dtype,
                    shape=shared_shape + tuple(dim_sizes[d] for d in kept_excluded),
                    dims=shared_dims + kept_excluded,
                )
            )
        return Apply(self, inputs, outputs)
def broadcast(
    *args, exclude: str | Sequence[str] | None = None
) -> tuple[XTensorVariable, ...]:
    """Broadcast any number of XTensorVariables against each other.

    Parameters
    ----------
    *args : XTensorVariable
        The tensors to broadcast against each other.
    exclude : str or Sequence[str] or None, optional
        Dimensions to keep per-input instead of broadcasting.
    """
    if not args:
        # Nothing to broadcast; mirror xarray's empty result
        return ()

    # Normalize `exclude` to a sequence of dim names
    if exclude is None:
        normalized_exclude: Sequence[str] = ()
    elif isinstance(exclude, str):
        normalized_exclude = (exclude,)
    elif isinstance(exclude, Sequence):
        normalized_exclude = exclude
    else:
        raise TypeError(f"exclude must be None, str, or Sequence, got {type(exclude)}")

    # xarray broadcast always returns a tuple, even if there's only one tensor
    return tuple(Broadcast(exclude=normalized_exclude)(*args, return_list=True))  # type: ignore
...@@ -736,6 +736,15 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -736,6 +736,15 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims.""" """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
return px.math.dot(self, other, dim=dim) return px.math.dot(self, other, dim=dim)
    def broadcast(self, *others, exclude=None):
        """Broadcast this tensor against other XTensorVariables.

        Thin wrapper around ``px.shape.broadcast`` with ``self`` prepended
        to the inputs; ``exclude`` is forwarded unchanged.
        """
        return px.shape.broadcast(self, *others, exclude=exclude)
def broadcast_like(self, other, exclude=None):
"""Broadcast this tensor against another XTensorVariable."""
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
return self_bcast
class XTensorConstantSignature(TensorConstantSignature): class XTensorConstantSignature(TensorConstantSignature):
pass pass
......
from collections.abc import Sequence
from itertools import chain from itertools import chain
import numpy as np import numpy as np
...@@ -13,13 +14,22 @@ from pytensor.tensor.utils import ( ...@@ -13,13 +14,22 @@ from pytensor.tensor.utils import (
get_static_shape_from_size_variables, get_static_shape_from_size_variables,
) )
from pytensor.xtensor.basic import XOp from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
def combine_dims_and_shape(inputs): def combine_dims_and_shape(
inputs: Sequence[XTensorVariable], exclude: Sequence[str] | None = None
) -> dict[str, int | None]:
"""Combine information of static dimensions and shapes from multiple xtensor inputs.
    Dims listed in `exclude` are skipped and do not contribute to the combined result.
    """
exclude_set: set[str] = set() if exclude is None else set(exclude)
dims_and_shape: dict[str, int | None] = {} dims_and_shape: dict[str, int | None] = {}
for inp in inputs: for inp in inputs:
for dim, dim_length in zip(inp.type.dims, inp.type.shape): for dim, dim_length in zip(inp.type.dims, inp.type.shape):
if dim in exclude_set:
continue
if dim not in dims_and_shape: if dim not in dims_and_shape:
dims_and_shape[dim] = dim_length dims_and_shape[dim] = dim_length
elif dim_length is not None: elif dim_length is not None:
......
...@@ -9,10 +9,12 @@ from itertools import chain, combinations ...@@ -9,10 +9,12 @@ from itertools import chain, combinations
import numpy as np import numpy as np
from xarray import DataArray from xarray import DataArray
from xarray import broadcast as xr_broadcast
from xarray import concat as xr_concat from xarray import concat as xr_concat
from pytensor.tensor import scalar from pytensor.tensor import scalar
from pytensor.xtensor.shape import ( from pytensor.xtensor.shape import (
broadcast,
concat, concat,
stack, stack,
unstack, unstack,
...@@ -466,3 +468,168 @@ def test_expand_dims_errors(): ...@@ -466,3 +468,168 @@ def test_expand_dims_errors():
# Test with a numpy array as dim (not supported) # Test with a numpy array as dim (not supported)
with pytest.raises(TypeError, match="unhashable type"): with pytest.raises(TypeError, match="unhashable type"):
y.expand_dims(np.array([1, 2])) y.expand_dims(np.array([1, 2]))
class TestBroadcast:
    """Tests for xtensor ``broadcast`` / ``broadcast_like`` against xarray's behavior."""

    @pytest.mark.parametrize(
        "exclude",
        [
            None,
            [],
            ["b"],
            ["b", "d"],
            ["a", "d"],
            ["b", "c", "d"],
            ["a", "b", "c", "d"],
        ],
    )
    def test_compatible_excluded_shapes(self, exclude):
        # Create test data: all shared dims ("b", "d") have matching sizes
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))
        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)

        # Test with excluded dims: results must match xarray's broadcast exactly
        x2_expected, y2_expected, z2_expected = xr_broadcast(
            x_test, y_test, z_test, exclude=exclude
        )
        x2, y2, z2 = broadcast(x, y, z, exclude=exclude)
        fn = xr_function([x, y, z], [x2, y2, z2])
        x2_result, y2_result, z2_result = fn(x_test, y_test, z_test)

        xr_assert_allclose(x2_result, x2_expected)
        xr_assert_allclose(y2_result, y2_expected)
        xr_assert_allclose(z2_result, z2_expected)

    def test_incompatible_excluded_shapes(self):
        # Test that excluded dims are allowed to be different sizes
        # ("d" is 6 in y but 7 in z — legal because "d" is excluded)
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 7))
        out = broadcast(x, y, z, exclude=["d"])

        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)
        fn = xr_function([x, y, z], out)
        results = fn(x_test, y_test, z_test)
        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=["d"])
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    @pytest.mark.parametrize("exclude", [[], ["b"], ["b", "c"], ["a", "b", "d"]])
    def test_runtime_shapes(self, exclude):
        # Static shapes are partially unknown (None); sizes only known at runtime
        x = xtensor("x", dims=("a", "b"), shape=(None, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, None))
        z = xtensor("z", dims=("b", "d"), shape=(None, None))
        out = broadcast(x, y, z, exclude=exclude)

        x_test = xr_arange_like(xtensor(dims=x.dims, shape=(3, 4)))
        y_test = xr_arange_like(xtensor(dims=y.dims, shape=(5, 6)))
        z_test = xr_arange_like(xtensor(dims=z.dims, shape=(4, 6)))
        fn = xr_function([x, y, z], out)
        results = fn(x_test, y_test, z_test)
        expected_results = xr_broadcast(x_test, y_test, z_test, exclude=exclude)
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

        # Test invalid shape raises an error
        # Note: We might decide not to raise an error in the lowered graphs for performance reasons
        if "d" not in exclude:
            z_test_bad = xr_arange_like(xtensor(dims=z.dims, shape=(4, 7)))
            with pytest.raises(Exception):
                fn(x_test, y_test, z_test_bad)

    def test_broadcast_excluded_dims_in_different_order(self):
        """Test broadcasting excluded dims are aligned with user input."""
        # Inputs carry "c"/"b" in different orders; output order must follow
        # the order given in `exclude`, not the inputs' order.
        x = xtensor("x", dims=("a", "c", "b"), shape=(3, 4, 5))
        y = xtensor("y", dims=("a", "b", "c"), shape=(3, 5, 4))
        out = (out_x, out_y) = broadcast(x, y, exclude=["c", "b"])
        assert out_x.type.dims == ("a", "c", "b")
        assert out_y.type.dims == ("a", "c", "b")

        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        fn = xr_function([x, y], out)
        results = fn(x_test, y_test)
        expected_results = xr_broadcast(x_test, y_test, exclude=["c", "b"])
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    def test_broadcast_errors(self):
        """Test error handling in broadcast."""
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))

        # Invalid exclude type (not None/str/Sequence)
        with pytest.raises(TypeError, match="exclude must be None, str, or Sequence"):
            broadcast(x, y, z, exclude=1)

        # Test with conflicting shapes ("d" is 6 vs 7 and NOT excluded)
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 7))

        with pytest.raises(ValueError, match="Dimension .* has conflicting shapes"):
            broadcast(x, y, z)

    def test_broadcast_no_input(self):
        # Empty call returns an empty tuple, matching xarray
        assert broadcast() == xr_broadcast()
        assert broadcast(exclude=("a",)) == xr_broadcast(exclude=("a",))

    def test_broadcast_single_input(self):
        """Test broadcasting a single input."""
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        # Broadcast with a single input can still imply a transpose via the exclude parameter
        outs = [
            *broadcast(x),
            *broadcast(x, exclude=("a", "b")),
            *broadcast(x, exclude=("b", "a")),
            *broadcast(x, exclude=("b",)),
        ]

        fn = xr_function([x], outs)
        x_test = xr_arange_like(x)
        results = fn(x_test)
        expected_results = [
            *xr_broadcast(x_test),
            *xr_broadcast(x_test, exclude=("a", "b")),
            *xr_broadcast(x_test, exclude=("b", "a")),
            *xr_broadcast(x_test, exclude=("b",)),
        ]
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)

    @pytest.mark.parametrize("exclude", [None, ["b"], ["b", "c"]])
    def test_broadcast_like(self, exclude):
        """Test broadcast_like method"""
        # Create test data
        x = xtensor("x", dims=("a", "b"), shape=(3, 4))
        y = xtensor("y", dims=("c", "d"), shape=(5, 6))
        z = xtensor("z", dims=("b", "d"), shape=(4, 6))

        # Order matters so we test both orders
        outs = [
            x.broadcast_like(y, exclude=exclude),
            y.broadcast_like(x, exclude=exclude),
            y.broadcast_like(z, exclude=exclude),
            z.broadcast_like(y, exclude=exclude),
        ]

        x_test = xr_arange_like(x)
        y_test = xr_arange_like(y)
        z_test = xr_arange_like(z)
        fn = xr_function([x, y, z], outs)
        results = fn(x_test, y_test, z_test)
        expected_results = [
            x_test.broadcast_like(y_test, exclude=exclude),
            y_test.broadcast_like(x_test, exclude=exclude),
            y_test.broadcast_like(z_test, exclude=exclude),
            z_test.broadcast_like(y_test, exclude=exclude),
        ]
        for res, expected_res in zip(results, expected_results, strict=True):
            xr_assert_allclose(res, expected_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论