Commit 30b50fda authored by ricardoV94, committed by Ricardo Vieira

Implement concat for XTensorVariables

Parent 010e0f97
...@@ -2,6 +2,7 @@ import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
    XTensorType,
    as_xtensor,
...
from pytensor.graph import node_rewriter
from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Concat, Stack
@register_lower_xtensor
...@@ -27,3 +27,46 @@ def lower_stack(fgraph, node):
    new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
    return [new_out]
@register_lower_xtensor
@node_rewriter(tracks=[Concat])
def lower_concat(fgraph, node):
    """Lower an XTensor ``Concat`` node to plain-tensor ops.

    Aligns every input to the output dim order via ``dimshuffle``,
    broadcasts the non-concatenated axes to a common shape, and joins
    along the concatenation axis.
    """
    dims_out = node.outputs[0].type.dims
    axis = dims_out.index(node.op.dim)

    # Align each input with the output dim order, inserting broadcastable
    # ("x") axes for any dims the input is missing.
    aligned = []
    for xt_inp in node.inputs:
        in_dims = xt_inp.type.dims
        shuffle = [in_dims.index(d) if d in in_dims else "x" for d in dims_out]
        aligned.append(tensor_from_xtensor(xt_inp).dimshuffle(shuffle))

    # Pick a reference length for every non-concatenated axis: the first
    # non-broadcastable occurrence across the aligned inputs.
    # TODO: This is assuming the graph is correct and every non-concat
    # dimension matches in shape at runtime; lowering is "shape_unsafe"
    # to keep the returned graph simple.
    target_shape = [None] * len(dims_out)
    for tensor_inp in aligned:
        for i, (is_bcast, length) in enumerate(
            zip(tensor_inp.type.broadcastable, tensor_inp.shape)
        ):
            if not (is_bcast or i == axis or target_shape[i] is not None):
                target_shape[i] = length

    # Only the concat axis slot should still be undetermined.
    assert target_shape.count(None) == 1

    # Broadcast each input to the common shape; the concat-axis slot is
    # overwritten in place per input since the list isn't needed afterwards.
    broadcasted = []
    for tensor_inp in aligned:
        target_shape[axis] = tensor_inp.shape[axis]
        broadcasted.append(broadcast_to(tensor_inp, target_shape))

    joined = join(axis, *broadcasted)
    return [xtensor_from_tensor(joined, dims=dims_out)]
from collections.abc import Sequence

from pytensor.graph import Apply
from pytensor.scalar import upcast
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
...@@ -69,3 +70,55 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
    )
    y = Stack(new_dim_name, tuple(stacked_dims))(y)
    return y
class Concat(XOp):
    """Concatenate XTensor inputs along a named dimension.

    Inputs may list their dims in any order; dims absent from an input are
    treated as broadcastable, and a ``dim`` not present in any input creates
    a brand-new leading dimension of length ``len(inputs)``.
    """

    __props__ = ("dim",)

    def __init__(self, dim: str):
        # Name of the dimension along which inputs are concatenated.
        self.dim = dim
        super().__init__()

    def make_node(self, *inputs):
        inputs = [as_xtensor(inp) for inp in inputs]
        concat_dim = self.dim

        # Union of all input dims mapped to their static length
        # (None = unknown). The concat dim accumulates lengths; every
        # other dim must have a single consistent length.
        known: dict[str, int | None] = {}
        for inp in inputs:
            for dim, dim_length in zip(inp.type.dims, inp.type.shape):
                if dim not in known:
                    known[dim] = dim_length
                    continue
                if dim == concat_dim:
                    if dim_length is None:
                        # One unknown contribution makes the total unknown.
                        known[dim] = None
                    elif known[dim] is not None:
                        known[dim] += dim_length
                elif dim_length is not None:
                    prior = known[dim]
                    if prior is not None and prior != dim_length:
                        raise ValueError(
                            f"Non-concatenated dimension {dim} has conflicting shapes"
                        )
                    # Keep the non-None shape
                    known[dim] = dim_length

        if concat_dim not in known:
            # It's a new dim, that should be located at the start
            known = {concat_dim: len(inputs)} | known
        elif known[concat_dim] is not None:
            # Every input lacking the concat dim contributes a length-1 slice.
            known[concat_dim] += sum(
                1 for inp in inputs if concat_dim not in inp.type.dims
            )

        dims, shape = zip(*known.items())
        dtype = upcast(*[x.type.dtype for x in inputs])
        output = xtensor(dtype=dtype, dims=dims, shape=shape)
        return Apply(self, inputs, [output])
def concat(xtensors, dim: str):
    """Concatenate a sequence of XTensorVariables along dimension ``dim``."""
    op = Concat(dim=dim)
    return op(*xtensors)
...@@ -6,12 +6,16 @@ pytest.importorskip("xarray")
from itertools import chain, combinations

import numpy as np
from xarray import concat as xr_concat

from pytensor.xtensor.shape import concat, stack
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
    xr_arange_like,
    xr_assert_allclose,
    xr_function,
    xr_random_like,
)
...@@ -65,3 +69,63 @@ def test_multiple_stacks():
    res = fn(x_test)
    expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
    xr_assert_allclose(res[0], expected_res)
@pytest.mark.parametrize("dim", ("a", "b", "new"))
def test_concat(dim):
    """Concat three 2D tensors along an existing dim or a brand-new one."""
    rng = np.random.default_rng(sum(map(ord, dim)))

    # x3's length along the concat dim may differ; other dims must match.
    shape_a = 4 if dim == "a" else 2
    shape_b = 5 if dim == "b" else 3
    x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3))
    x2 = xtensor("x2", dims=("b", "a"), shape=(3, 2))
    x3 = xtensor("x3", dims=("a", "b"), shape=(shape_a, shape_b))
    xs = [x1, x2, x3]

    out = concat(xs, dim=dim)
    fn = xr_function(xs, out)

    test_values = [xr_random_like(x, rng) for x in xs]
    res = fn(*test_values)
    expected_res = xr_concat(test_values, dim=dim)
    xr_assert_allclose(res, expected_res)
@pytest.mark.parametrize("dim", ("a", "b", "c", "d", "new"))
def test_concat_with_broadcast(dim):
    """Concat inputs with disjoint/partial dims, including a scalar."""
    rng = np.random.default_rng(sum(map(ord, dim)) + 1)

    xs = [
        xtensor("x1", dims=("a", "b"), shape=(2, 3)),
        xtensor("x2", dims=("b", "c"), shape=(3, 5)),
        xtensor("x3", dims=("c", "d"), shape=(5, 7)),
        xtensor("x4", dims=(), shape=()),
    ]

    out = concat(xs, dim=dim)
    fn = xr_function(xs, out)

    test_values = [xr_random_like(x, rng) for x in xs]
    res = fn(*test_values)
    expected_res = xr_concat(test_values, dim=dim)
    xr_assert_allclose(res, expected_res)
def test_concat_scalar():
    """Concatenating scalars along a new dim stacks them into a vector."""
    x1 = xtensor("x1", dims=(), shape=())
    x2 = xtensor("x2", dims=(), shape=())

    out = concat([x1, x2], dim="new_dim")
    fn = xr_function([x1, x2], out)

    x1_test = xr_random_like(x1)
    x2_test = xr_random_like(x2)
    res = fn(x1_test, x2_test)
    expected_res = xr_concat([x1_test, x2_test], dim="new_dim")
    xr_assert_allclose(res, expected_res)
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment