提交 42936163 authored 作者: Oriol (ProDesk)'s avatar Oriol (ProDesk) 提交者: Ricardo Vieira

Implement unstack for XTensorVariables

上级 133ec80e
from pytensor.graph import node_rewriter from pytensor.graph import node_rewriter
from pytensor.tensor import broadcast_to, join, moveaxis from pytensor.tensor import broadcast_to, join, moveaxis, specify_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.basic import register_lower_xtensor from pytensor.xtensor.rewriting.basic import register_lower_xtensor
from pytensor.xtensor.shape import Concat, Stack, Transpose from pytensor.xtensor.shape import Concat, Stack, Transpose, UnStack
@register_lower_xtensor @register_lower_xtensor
...@@ -29,6 +29,25 @@ def lower_stack(fgraph, node): ...@@ -29,6 +29,25 @@ def lower_stack(fgraph, node):
return [new_out] return [new_out]
@register_lower_xtensor
@node_rewriter(tracks=[UnStack])
def lower_unstack(fgraph, node):
    """Lower an UnStack node to plain tensor operations.

    The stacked axis is moved to the last position, then reshaped into the
    requested new lengths.  Static shape information lost by the reshape is
    reintroduced with ``specify_shape`` before wrapping back into an xtensor.
    """
    x, *new_lengths = node.inputs
    out = node.outputs[0]

    # Position of the dimension being unstacked in the input.
    unstack_axis = x.type.dims.index(node.op.old_dim_name)

    tensor_x = tensor_from_xtensor(x)
    # Put the stacked axis last so the reshape splits it into the new dims.
    tensor_x = moveaxis(tensor_x, source=[unstack_axis], destination=[-1])
    reshaped = tensor_x.reshape((*tensor_x.shape[:-1], *new_lengths))

    # Reintroduce any static shape information that was lost during the reshape
    reshaped = specify_shape(reshaped, out.type.shape)

    return [xtensor_from_tensor(reshaped, dims=out.type.dims)]
@register_lower_xtensor @register_lower_xtensor
@node_rewriter(tracks=[Concat]) @node_rewriter(tracks=[Concat])
def lower_concat(fgraph, node): def lower_concat(fgraph, node):
......
...@@ -5,7 +5,9 @@ from types import EllipsisType ...@@ -5,7 +5,9 @@ from types import EllipsisType
from typing import Literal from typing import Literal
from pytensor.graph import Apply from pytensor.graph import Apply
from pytensor.scalar import upcast from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.xtensor.basic import XOp from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor from pytensor.xtensor.type import as_xtensor, xtensor
...@@ -76,6 +78,89 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]) ...@@ -76,6 +78,89 @@ def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str])
return y return y
class UnStack(XOp):
    """Split one dimension of an XTensorVariable into several new dimensions.

    The inverse of ``Stack``: the dimension ``old_dim_name`` is replaced by
    the dimensions in ``unstacked_dims``, whose lengths are supplied as extra
    scalar inputs to the node.
    """

    __props__ = ("old_dim_name", "unstacked_dims")

    def __init__(
        self,
        old_dim_name: str,
        unstacked_dims: tuple[str, ...],
    ):
        super().__init__()
        # Validate the dim specification eagerly, at Op construction time.
        if old_dim_name in unstacked_dims:
            raise ValueError(
                f"Dim to be unstacked {old_dim_name} can't be in {unstacked_dims}"
            )
        if not unstacked_dims:
            raise ValueError("Dims to unstack into can't be empty.")
        if len(unstacked_dims) == 1:
            raise ValueError("Only one dimension to unstack into, use rename instead")
        self.old_dim_name = old_dim_name
        self.unstacked_dims = unstacked_dims

    def make_node(self, x, *unstacked_length):
        """Build the Apply node; lengths become extra scalar inputs.

        Raises if the dim to unstack is missing, if any target dim already
        exists on ``x``, if the number of lengths doesn't match the number of
        target dims, or if any length is not of a discrete dtype.
        """
        x = as_xtensor(x)
        if self.old_dim_name not in x.type.dims:
            raise ValueError(
                f"Dim to unstack {self.old_dim_name} must be in {x.type.dims}"
            )
        if not set(self.unstacked_dims).isdisjoint(x.type.dims):
            raise ValueError(
                f"Dims to unstack into {self.unstacked_dims} must not be in {x.type.dims}"
            )
        if len(unstacked_length) != len(self.unstacked_dims):
            raise ValueError(
                f"Number of unstacked lengths {len(unstacked_length)} must match number of unstacked dims {len(self.unstacked_dims)}"
            )

        length_vars = [as_tensor(length, ndim=0) for length in unstacked_length]
        if not all(length.dtype in discrete_dtypes for length in length_vars):
            raise TypeError("Unstacked lengths must be discrete dtypes.")

        # Every dim of x except the one being unstacked survives unchanged.
        batch = [
            (dim, size)
            for dim, size in zip(x.type.dims, x.type.shape)
            if dim != self.old_dim_name
        ]
        batch_dims = tuple(dim for dim, _ in batch)
        batch_shape = tuple(size for _, size in batch)

        # Recover static lengths where possible so the output type keeps them.
        static_lengths: list[int | None] = []
        for length in length_vars:
            try:
                static_lengths.append(int(get_scalar_constant_value(length)))
            except NotScalarConstantError:
                static_lengths.append(None)

        output = xtensor(
            dtype=x.type.dtype,
            shape=(*batch_shape, *static_lengths),
            dims=(*batch_dims, *self.unstacked_dims),
        )
        return Apply(self, [x, *length_vars], [output])
def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
if dim is not None:
if dims:
raise ValueError(
"Cannot use both positional dim and keyword dims in unstack"
)
dims = dim
y = x
for old_dim_name, unstacked_dict in dims.items():
y = UnStack(old_dim_name, tuple(unstacked_dict.keys()))(
y, *tuple(unstacked_dict.values())
)
return y
class Transpose(XOp): class Transpose(XOp):
__props__ = ("dims",) __props__ = ("dims",)
......
...@@ -523,6 +523,9 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): ...@@ -523,6 +523,9 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def stack(self, dim, **dims): def stack(self, dim, **dims):
return px.shape.stack(self, dim, **dims) return px.shape.stack(self, dim, **dims)
def unstack(self, dim, **dims):
    """Unstack dimension(s) of this variable into multiple new dimensions.

    Thin wrapper delegating to ``px.shape.unstack``; see that function for
    the accepted ``dim`` / ``dims`` formats.
    """
    return px.shape.unstack(self, dim, **dims)
class XTensorConstantSignature(TensorConstantSignature): class XTensorConstantSignature(TensorConstantSignature):
pass pass
......
...@@ -8,9 +8,10 @@ import re ...@@ -8,9 +8,10 @@ import re
from itertools import chain, combinations from itertools import chain, combinations
import numpy as np import numpy as np
from xarray import DataArray
from xarray import concat as xr_concat from xarray import concat as xr_concat
from pytensor.xtensor.shape import concat, stack, transpose from pytensor.xtensor.shape import concat, stack, transpose, unstack
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import ( from tests.xtensor.util import (
xr_arange_like, xr_arange_like,
...@@ -154,6 +155,49 @@ def test_multiple_stacks(): ...@@ -154,6 +155,49 @@ def test_multiple_stacks():
xr_assert_allclose(res[0], expected_res) xr_assert_allclose(res[0], expected_res)
def test_unstack_constant_size():
    """Unstack with statically-known lengths propagates static type info."""
    x = xtensor("x", dims=("a", "bc", "d"), shape=(2, 3 * 5, 7))
    y = unstack(x, bc={"b": 3, "c": 5})

    # Batch dims come first, followed by the newly unstacked dims.
    assert y.type.dims == ("a", "d", "b", "c")
    assert y.type.shape == (2, 7, 3, 5)

    x_test = xr_arange_like(x)
    fn = xr_function([x], y)
    res = fn(x_test)

    raw = x_test.values
    expected = (
        DataArray(raw.reshape(2, 3, 5, 7), dims=("a", "b", "c", "d"))
        .stack(bc=("b", "c"))
        .unstack("bc")
    )
    xr_assert_allclose(res, expected)
def test_unstack_symbolic_size():
    """Unstack with lengths taken from symbolic sizes of the input."""
    x = xtensor(dims=("a", "b", "c"))
    stacked = stack(x, bc=("b", "c"))
    normalized = stacked / stacked.sum("bc")
    z = unstack(normalized, bc={"b": x.sizes["b"], "c": x.sizes["c"]})

    x_test = xr_arange_like(xtensor(dims=x.dims, shape=(2, 3, 5)))
    fn = xr_function([x], z)
    res = fn(x_test)
    expected = x_test / x_test.sum(["b", "c"])
    xr_assert_allclose(res, expected)
def test_stack_unstack():
    """Round-tripping stack -> unstack preserves values but reorders dims."""
    x = xtensor("x", dims=("a", "b", "c", "d"), shape=(2, 3, 5, 7))
    round_trip = unstack(stack(x, bd=("b", "d")), bd={"b": 3, "d": 7})

    x_test = xr_arange_like(x)
    fn = xr_function([x], round_trip)
    res = fn(x_test)
    # Stacked dims come back at the end after unstacking.
    expected = x_test.transpose("a", "c", "b", "d")
    xr_assert_allclose(res, expected)
@pytest.mark.parametrize("dim", ("a", "b", "new")) @pytest.mark.parametrize("dim", ("a", "b", "new"))
def test_concat(dim): def test_concat(dim):
rng = np.random.default_rng(sum(map(ord, dim))) rng = np.random.default_rng(sum(map(ord, dim)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论