Commit 30b50fda authored by ricardoV94, committed by Ricardo Vieira

Implement concat for XTensorVariables

Parent 010e0f97
......@@ -2,6 +2,7 @@ import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg
from pytensor.xtensor.shape import concat
from pytensor.xtensor.type import (
XTensorType,
as_xtensor,
......
from pytensor.graph import node_rewriter
from pytensor.tensor import moveaxis
from pytensor.tensor import broadcast_to, join, moveaxis
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Stack
from pytensor.xtensor.shape import Concat, Stack
@register_lower_xtensor
......@@ -27,3 +27,46 @@ def lower_stack(fgraph, node):
new_out = xtensor_from_tensor(final_tensor, dims=node.outputs[0].type.dims)
return [new_out]
@register_lower_xtensor
@node_rewriter(tracks=[Concat])
def lower_concat(fgraph, node):
    """Lower an XTensor `Concat` to plain tensor ops (dimshuffle/broadcast_to/join)."""
    out_dims = node.outputs[0].type.dims
    axis = out_dims.index(node.op.dim)

    # Align every input with the output dim order, inserting broadcastable
    # dims ("x") wherever an input lacks one of the output dims.
    aligned = []
    for xt_inp in node.inputs:
        in_dims = xt_inp.type.dims
        shuffle_pattern = [
            in_dims.index(d) if d in in_dims else "x" for d in out_dims
        ]
        aligned.append(tensor_from_xtensor(xt_inp).dimshuffle(shuffle_pattern))

    # For each non-concatenated dim, take a symbolic length from the first
    # input that is not statically broadcastable there.
    # NOTE(review): assumes every non-concat dim shape matches at runtime
    # (rewrite is registered as shape-unsafe), and that at least one input is
    # non-broadcastable along each non-concat dim — TODO confirm the latter.
    target_shape = [None] * len(out_dims)
    for t in aligned:
        for i, (is_bcast, length) in enumerate(
            zip(t.type.broadcastable, t.shape)
        ):
            if i == axis or is_bcast or target_shape[i] is not None:
                continue
            target_shape[i] = length
    # Only the concat axis should still be unresolved at this point
    assert target_shape.count(None) == 1

    # Broadcast each input to the common non-concat shape; the concat-axis
    # slot is reused in place per input since the list is not needed later.
    broadcasted = []
    for t in aligned:
        target_shape[axis] = t.shape[axis]
        broadcasted.append(broadcast_to(t, target_shape))

    joined = join(axis, *broadcasted)
    return [xtensor_from_tensor(joined, dims=out_dims)]
from collections.abc import Sequence
from pytensor.graph import Apply
from pytensor.scalar import upcast
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
......@@ -69,3 +70,55 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
)
y = Stack(new_dim_name, tuple(stacked_dims))(y)
return y
class Concat(XOp):
    """Concatenate XTensor inputs along a (possibly new) named dimension.

    Mirrors ``xarray.concat``: inputs that lack the concat dimension
    contribute one length-1 slice along it, and all non-concatenated
    dimensions must have consistent static lengths.
    """

    __props__ = ("dim",)

    def __init__(self, dim: str):
        self.dim = dim
        super().__init__()

    def make_node(self, *inputs):
        inputs = [as_xtensor(inp) for inp in inputs]
        concat_dim = self.dim

        # Union of all input dims with their static lengths (None = unknown),
        # in order of first appearance.
        known: dict[str, int | None] = {}
        for inp in inputs:
            for dim, length in zip(inp.type.dims, inp.type.shape):
                if dim not in known:
                    known[dim] = length
                elif dim == concat_dim:
                    # Lengths accumulate along the concat dim; any unknown
                    # piece makes the total unknown.
                    known[dim] = (
                        None
                        if (length is None or known[dim] is None)
                        else known[dim] + length
                    )
                elif length is not None:
                    # Non-concat dims must agree wherever both are static
                    if known[dim] is not None and known[dim] != length:
                        raise ValueError(
                            f"Non-concatenated dimension {dim} has conflicting shapes"
                        )
                    # Prefer the static (non-None) length
                    known[dim] = length

        if concat_dim not in known:
            # Brand-new dim: placed first, one slot per input
            known = {concat_dim: len(inputs), **known}
        elif known[concat_dim] is not None:
            # Each input missing the concat dim adds a length-1 slice
            known[concat_dim] += sum(
                concat_dim not in inp.type.dims for inp in inputs
            )

        dims, shape = zip(*known.items())
        dtype = upcast(*[x.type.dtype for x in inputs])
        output = xtensor(dtype=dtype, dims=dims, shape=shape)
        return Apply(self, inputs, [output])
def concat(xtensors, dim: str):
    """Concatenate a sequence of xtensors along the named dimension ``dim``."""
    op = Concat(dim=dim)
    return op(*xtensors)
......@@ -6,12 +6,16 @@ pytest.importorskip("xarray")
from itertools import chain, combinations
from pytensor.xtensor.shape import stack
import numpy as np
from xarray import concat as xr_concat
from pytensor.xtensor.shape import concat, stack
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import (
xr_arange_like,
xr_assert_allclose,
xr_function,
xr_random_like,
)
......@@ -65,3 +69,63 @@ def test_multiple_stacks():
res = fn(x_test)
expected_res = x_test.stack(new_dim1=("a", "b"), new_dim2=("c", "d"))
xr_assert_allclose(res[0], expected_res)
@pytest.mark.parametrize("dim", ("a", "b", "new"))
def test_concat(dim):
rng = np.random.default_rng(sum(map(ord, dim)))
x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3))
x2 = xtensor("x2", dims=("b", "a"), shape=(3, 2))
x3_shape0 = 4 if dim == "a" else 2
x3_shape1 = 5 if dim == "b" else 3
x3 = xtensor("x3", dims=("a", "b"), shape=(x3_shape0, x3_shape1))
out = concat([x1, x2, x3], dim=dim)
fn = xr_function([x1, x2, x3], out)
x1_test = xr_random_like(x1, rng)
x2_test = xr_random_like(x2, rng)
x3_test = xr_random_like(x3, rng)
res = fn(x1_test, x2_test, x3_test)
expected_res = xr_concat([x1_test, x2_test, x3_test], dim=dim)
xr_assert_allclose(res, expected_res)
@pytest.mark.parametrize("dim", ("a", "b", "c", "d", "new"))
def test_concat_with_broadcast(dim):
rng = np.random.default_rng(sum(map(ord, dim)) + 1)
x1 = xtensor("x1", dims=("a", "b"), shape=(2, 3))
x2 = xtensor("x2", dims=("b", "c"), shape=(3, 5))
x3 = xtensor("x3", dims=("c", "d"), shape=(5, 7))
x4 = xtensor("x4", dims=(), shape=())
out = concat([x1, x2, x3, x4], dim=dim)
fn = xr_function([x1, x2, x3, x4], out)
x1_test = xr_random_like(x1, rng)
x2_test = xr_random_like(x2, rng)
x3_test = xr_random_like(x3, rng)
x4_test = xr_random_like(x4, rng)
res = fn(x1_test, x2_test, x3_test, x4_test)
expected_res = xr_concat([x1_test, x2_test, x3_test, x4_test], dim=dim)
xr_assert_allclose(res, expected_res)
def test_concat_scalar():
    """Concatenating 0-d xtensors creates a brand-new dimension."""
    scalars = [
        xtensor("x1", dims=(), shape=()),
        xtensor("x2", dims=(), shape=()),
    ]
    out = concat(scalars, dim="new_dim")
    fn = xr_function(scalars, out)

    values = [xr_random_like(s) for s in scalars]
    res = fn(*values)
    expected_res = xr_concat(values, dim="new_dim")
    xr_assert_allclose(res, expected_res)
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment