Commit 9d99267c, authored by ricardoV94, committed by Ricardo Vieira

Implement vectorize_node for XOps

Parent 9dd929ab
......@@ -163,7 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"]
"tests/xtensor/**/test_*.py" = ["E402"]
"tests/xtensor/**/*.py" = ["E402"]
......
......@@ -2,6 +2,7 @@ from collections.abc import Sequence
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
......@@ -17,6 +18,9 @@ class XOp(Op):
def do_constant_folding(self, fgraph, node):
return False
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]:
    """Return the outputs of `node` rebuilt with batched `new_inputs`.

    Base implementation always raises; concrete XOps that support
    vectorization must override this.
    """
    raise NotImplementedError(f"Vectorized node not implemented for {self}")
class XTypeCastOp(TypeCastingOp):
"""Base class for Ops that type cast between TensorType and XTensorType.
......@@ -27,6 +31,9 @@ class XTypeCastOp(TypeCastingOp):
def infer_shape(self, fgraph, node, input_shapes):
return input_shapes
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]:
    """Return the outputs of `node` rebuilt with batched `new_inputs`.

    Base implementation always raises; concrete type-casting Ops that
    support vectorization must override this.
    """
    raise NotImplementedError(f"Vectorized node not implemented for {self}")
class TensorFromXTensor(XTypeCastOp):
__props__ = ()
......@@ -42,6 +49,16 @@ class TensorFromXTensor(XTypeCastOp):
[g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x):
    """Vectorize TensorFromXTensor.

    Moves every pre-existing (core) dimension of the input to the right so
    that the single batch dimension ends up leftmost in the output tensor.
    Refuses more than one batch dimension, since their relative placement
    in the resulting tensor would be ambiguous.
    """
    [core_x] = node.inputs
    n_batch_dims = new_x.ndim - core_x.ndim
    if n_batch_dims > 1:
        raise NotImplementedError(
            f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
            "You can call vectorize_graph one batch dimension at a time."
        )
    # Batch dim(s) stay on the left; core dims follow in their original order.
    reordered = new_x.transpose(..., *core_x.dims)
    return [self(reordered)]
tensor_from_xtensor = TensorFromXTensor()
......@@ -63,6 +80,15 @@ class XTensorFromTensor(XTypeCastOp):
[g_out] = g_outs
return [tensor_from_xtensor(g_out)]
def vectorize_node(self, node, new_x):
    """Vectorize XTensorFromTensor.

    Only the trivial (unbatched) case is supported: a batched tensor input
    would require inventing labels for the new dimensions, which this Op
    cannot infer.
    """
    [core_x] = node.inputs
    if new_x.ndim == core_x.ndim:
        # No batch dimensions were added; rebuild the node as-is.
        return [self(new_x)]
    raise NotImplementedError(
        f"Vectorization of {self} with batched inputs not implemented, "
        "as it can't infer new dimension labels"
    )
def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)
......@@ -85,6 +111,16 @@ class Rename(XTypeCastOp):
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x):
    """Vectorize Rename by extending the rename map over batched dims.

    Pre-existing dims keep their configured new names (even if reordered);
    dims introduced by vectorization pass through unrenamed.
    """
    [core_x] = node.inputs
    rename_map = dict(zip(core_x.dims, self.new_dims, strict=True))
    updated_dims = tuple(rename_map.get(d, d) for d in new_x.dims)
    return [type(self)(updated_dims)(new_x)]
def rename(x, name_dict: dict[str, str] | None = None, **names: str):
if name_dict is not None:
......
......@@ -4,6 +4,7 @@
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from itertools import chain
from typing import Literal
from pytensor.graph.basic import Apply, Constant, Variable
......@@ -11,6 +12,7 @@ from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.shape import broadcast
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
......@@ -195,6 +197,15 @@ class Index(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output])
def vectorize_node(self, node, new_x, *new_idxs):
    """Vectorize indexing: pair each pre-existing dim of x with its index;
    any dim introduced by vectorization gets a full slice(None).
    """
    # new_x may have dims in different order
    # we pair each pre-existing dim to the respective index
    # with new dims having simply a slice(None)
    old_x, *_ = node.inputs
    # strict=False: presumably there can be fewer indices than x has dims,
    # with trailing dims implicitly fully sliced — TODO confirm against make_node
    dims_to_idxs = dict(zip(old_x.dims, new_idxs, strict=False))
    new_idxs = tuple(dims_to_idxs.get(dim, slice(None)) for dim in new_x.dims)
    return [self(new_x, *new_idxs)]
index = Index()
......@@ -226,6 +237,29 @@ class IndexUpdate(XOp):
out = x.type()
return Apply(self, [x, y, *idxs], [out])
def vectorize_node(self, node, *new_inputs):
    """Vectorize an index update (set/inc).

    If y or the indices gained batch dimensions, x must first be broadcast
    to include them; the pre-existing (core) dims are then moved to the
    front, because indices map to indexed dimensions positionally.
    """
    # If y or the indices have new dimensions we need to broadcast_x
    # Collect every dim that existed in the original graph: those must not
    # be touched by the broadcast below.
    exclude: set[str] = set(
        chain.from_iterable(
            old_inp.dims
            for old_inp in node.inputs
            if isinstance(old_inp.type, XTensorType)
        )
    )
    old_x, *_ = node.inputs
    # Broadcast all xtensor inputs together so x picks up any batch dims
    # introduced on y or the indices; only the first result (x) is kept.
    new_x, *_ = broadcast(
        *[
            new_inp
            for new_inp in new_inputs
            if isinstance(new_inp.type, XTensorType)
        ],
        exclude=tuple(exclude),
    )
    # New batch dimensions must go on the right since indices map to indexed dimensions positionally in the Op
    new_x = new_x.transpose(*old_x.dims, ...)
    _, new_y, *new_idxs = new_inputs
    return [self(new_x, new_y, *new_idxs)]
index_assignment = IndexUpdate("set")
index_increment = IndexUpdate("inc")
......@@ -46,6 +46,9 @@ class XReduce(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
    """Reductions act on named dims, so batched inputs need no adjustment."""
    return [self(new_x)]
def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str):
......@@ -117,6 +120,9 @@ class XCumReduce(XOp):
out = x.type()
return Apply(self, [x], [out])
def vectorize_node(self, node, new_x):
    """Cumulative reductions act on named dims, so batched inputs need no adjustment."""
    return [self(new_x)]
def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
x = as_xtensor(x)
......
......@@ -68,6 +68,9 @@ class Stack(XOp):
)
return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
    """Stacking is defined by named dims, so batched inputs need no adjustment."""
    return [self(new_x)]
def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
if dim is not None:
......@@ -146,6 +149,14 @@ class UnStack(XOp):
)
return Apply(self, [x, *unstacked_lengths], [output])
def vectorize_node(self, node, new_x, *new_unstacked_length):
    """Vectorize UnStack.

    The unstacked lengths must remain scalars (possibly after squeezing
    degenerate batch dims); batched lengths are not supported.

    Raises
    ------
    NotImplementedError
        If any unstacked length is still non-scalar after squeezing.
    """
    new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
    if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
        # Fixed dangling ", " at the end of the original message.
        raise NotImplementedError(
            f"Vectorization of {self} with batched unstacked_length not implemented"
        )
    return [self(new_x, *new_unstacked_length)]
def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
if dim is not None:
......@@ -189,6 +200,11 @@ class Transpose(XOp):
)
return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
    """Vectorize Transpose: batch dims (those not in the original transpose
    order) are kept leftmost, followed by the original dim order."""
    core_order = self.dims
    batch_dims = tuple(d for d in new_x.dims if d not in core_order)
    full_order = batch_dims + tuple(core_order)
    return [type(self)(dims=full_order)(new_x)]
def transpose(
x,
......@@ -302,6 +318,9 @@ class Concat(XOp):
output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output])
def vectorize_node(self, node, *new_inputs):
    """Concatenation is defined by named dims, so batched inputs need no adjustment."""
    return [self(*new_inputs)]
def concat(xtensors, dim: str):
"""Concatenate a sequence of XTensorVariables along a specified dimension.
......@@ -383,6 +402,9 @@ class Squeeze(XOp):
)
return Apply(self, [x], [out])
def vectorize_node(self, node, new_x):
    """Squeezing is defined by named dims, so batched inputs need no adjustment."""
    return [self(new_x)]
def squeeze(x, dim: str | Sequence[str] | None = None):
"""Remove dimensions of size 1 from an XTensorVariable."""
......@@ -442,6 +464,14 @@ class ExpandDims(XOp):
)
return Apply(self, [x, size], [out])
def vectorize_node(self, node, new_x, new_size):
    """Vectorize ExpandDims.

    The size of the new dim must remain a scalar (possibly after squeezing
    degenerate batch dims); batched sizes are not supported.

    Raises
    ------
    NotImplementedError
        If `new_size` is still non-scalar after squeezing.
    """
    new_size = new_size.squeeze()
    if new_size.type.ndim != 0:
        # Fixed dangling ", " at the end of the original message.
        raise NotImplementedError(
            f"Vectorization of {self} with batched new_size not implemented"
        )
    return [self(new_x, new_size)]
def expand_dims(x, dim=None, axis=None, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable."""
......@@ -537,6 +567,19 @@ class Broadcast(XOp):
return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs):
    """Vectorize Broadcast.

    Batched inputs can simply be re-broadcast, except when a batch dimension
    introduced by vectorization has the same name as an excluded dimension:
    excluded dims are deliberately left unbroadcast by this Op, so the
    semantics would be ambiguous and we raise instead.

    Raises
    ------
    NotImplementedError
        If any input gained a dimension that is in `self.exclude`.
    """
    if exclude_set := set(self.exclude):
        # BUG FIX: the original zip bound the old inputs to `new_x` and the
        # new inputs to `old_x`, so the check computed (old - new) dims and
        # could never fire. Bind them in the correct order.
        for old_x, new_x in zip(node.inputs, new_inputs, strict=True):
            if invalid_excluded := (
                (set(new_x.dims) - set(old_x.dims)) & exclude_set
            ):
                raise NotImplementedError(
                    f"Vectorize of {self} is undefined because one of the inputs {new_x} "
                    f"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
                )
    return self(*new_inputs, return_list=True)
def broadcast(
*args, exclude: str | Sequence[str] | None = None
......
......@@ -1044,7 +1044,7 @@ def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None)
if isinstance(x, Variable):
if isinstance(x.type, XTensorType):
if (dims is None) or (x.type.dims == dims):
if (dims is None) or (x.type.dims == tuple(dims)):
return x
else:
raise ValueError(
......
......@@ -6,6 +6,8 @@ import numpy as np
from pytensor import scalar as ps
from pytensor import shared
from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import discrete_dtypes
from pytensor.tensor import tensor
from pytensor.tensor.random.op import RNGConsumerOp
......@@ -14,8 +16,11 @@ from pytensor.tensor.utils import (
get_static_shape_from_size_variables,
)
from pytensor.utils import unzip
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
from pytensor.xtensor.basic import (
XOp,
XTypeCastOp,
)
from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor
def combine_dims_and_shape(
......@@ -74,6 +79,9 @@ class XElemwise(XOp):
]
return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs):
    """Elemwise ops broadcast by dim name, so batched inputs need no adjustment."""
    return self(*new_inputs, return_list=True)
class XBlockwise(XOp):
__props__ = ("core_op", "core_dims")
......@@ -141,6 +149,9 @@ class XBlockwise(XOp):
]
return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs):
    """Blockwise ops broadcast by dim name, so batched inputs need no adjustment."""
    return self(*new_inputs, return_list=True)
class XRV(XOp, RNGConsumerOp):
"""Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
......@@ -288,3 +299,54 @@ class XRV(XOp, RNGConsumerOp):
)
return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
def vectorize_node(self, node, *new_inputs):
    """Vectorize an XRV node.

    Distribution parameters broadcast by dim name and need no adjustment.
    The extra-dim lengths, however, must remain scalars (possibly after
    squeezing degenerate batch dims); batched lengths are not supported.

    Raises
    ------
    NotImplementedError
        If any extra-dim length is still non-scalar after squeezing.
    """
    new_rng, *new_extra_dim_lengths_and_params = new_inputs
    # Inputs are laid out as [rng, *extra_dim_lengths, *params].
    k = len(self.extra_dims)
    new_extra_dim_lengths, new_params = (
        new_extra_dim_lengths_and_params[:k],
        new_extra_dim_lengths_and_params[k:],
    )
    new_extra_dim_lengths = [dl.squeeze() for dl in new_extra_dim_lengths]
    if not all(dl.type.ndim == 0 for dl in new_extra_dim_lengths):
        # Fixed dangling ", " at the end of the original message.
        raise NotImplementedError(
            f"Vectorization of {self} with batched extra_dim_lengths not implemented"
        )
    return self.make_node(new_rng, *new_extra_dim_lengths, *new_params).outputs
@_vectorize_node.register(XOp)
@_vectorize_node.register(XTypeCastOp)
def vectorize_xop(op: XOp, node, *new_inputs) -> Sequence[Variable]:
    """Vectorize dispatch for xtensor Ops.

    Validates the batched inputs before delegating to ``op.vectorize_node``:
    each new input must keep every dim the original had, and any dim it
    gained must not collide with a dim name already used anywhere in the
    original node (inputs or outputs).

    Raises
    ------
    ValueError
        If a new input drops a pre-existing dim, or introduces a dim whose
        name was already present in the original graph.
    """
    old_inp_dims = [
        inp.dims for inp in node.inputs if isinstance(inp.type, XTensorType)
    ]
    old_out_dims = [
        out.dims for out in node.outputs if isinstance(out.type, XTensorType)
    ]
    # BUG FIX: old_out_dims must be unpacked like old_inp_dims. Previously the
    # list itself was chained, so the set contained dim *tuples* (never equal
    # to a dim name) and output dims were silently excluded from the check.
    all_old_dims_set = set(chain.from_iterable((*old_inp_dims, *old_out_dims)))
    for new_inp, old_inp in zip(new_inputs, node.inputs, strict=True):
        if not (
            isinstance(new_inp.type, XTensorType)
            and isinstance(old_inp.type, XTensorType)
        ):
            continue
        old_dims_set = set(old_inp.dims)
        new_dims_set = set(new_inp.dims)
        # Validate that new inputs didn't drop pre-existing dims
        if missing_dims := old_dims_set - new_dims_set:
            raise ValueError(
                f"Vectorized input {new_inp} is missing pre-existing dims: {sorted(missing_dims)}"
            )
        # Or have new dimensions that were already in the graph
        if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set):
            raise ValueError(
                f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
            )
    return op.vectorize_node(node, *new_inputs)
import pytest
pytest.importorskip("xarray")
import numpy as np
from pytensor import function
from pytensor.xtensor.basic import Rename
from pytensor.graph import vectorize_graph
from pytensor.tensor import matrix, vector
from pytensor.xtensor.basic import (
Rename,
rename,
tensor_from_xtensor,
xtensor_from_tensor,
)
from pytensor.xtensor.type import xtensor
from tests.unittest_tools import assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import check_vectorization
def test_shape_feature_does_not_see_xop():
......@@ -24,3 +40,36 @@ def test_shape_feature_does_not_see_xop():
fn = function([x], out)
np.testing.assert_allclose(fn([1, 2, 3]), [0, 0, 0])
assert not CALLED
def test_rename_vectorize():
    # Renaming a single dim must survive vectorization unchanged.
    ab = xtensor("ab", dims=("a", "b"), shape=(2, 3), dtype="float64")
    check_vectorization(ab, rename(ab, a="c"))
def test_xtensor_from_tensor_vectorize():
    # XTensorFromTensor cannot infer labels for new batch dims, so
    # vectorizing with a batched tensor input must raise.
    t = vector("t")
    x = xtensor_from_tensor(t, dims=("a",))

    t_batched = matrix("t_batched")
    with pytest.raises(
        NotImplementedError, match=r"Vectorization of .* not implemented"
    ):
        vectorize_graph([x], {t: t_batched})
def test_tensor_from_xtensor_vectorize():
    x = xtensor("x", dims=("a",), shape=(3,))
    y = tensor_from_xtensor(x)

    x_batched = xtensor("x", dims=("a", "b"), shape=(3, 5))
    y_batched = vectorize_graph(y, {x: x_batched})
    # vectorize_graph should place output batch dimension on the left
    assert y_batched.type.shape == (5, 3)
    assert_equal_computations([y_batched], [x_batched.transpose("b", ...).values])

    x_batched = xtensor("x", dims=("c", "a", "b"), shape=(7, 3, 5))
    # vectorize_graph can't handle multiple batch dimensions safely
    with pytest.raises(NotImplementedError):
        vectorize_graph(y, {x: x_batched})
......@@ -14,6 +14,7 @@ from pytensor.tensor import tensor
from pytensor.xtensor import xtensor
from tests.unittest_tools import assert_equal_computations
from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
......@@ -542,3 +543,43 @@ def test_empty_update_index():
fn = xr_function([x], out1)
x_test = xr_random_like(x)
xr_assert_allclose(fn(x_test), x_test + 1)
def test_indexing_vectorize():
    # Cover indexing with indices that reuse, rename-to-existing, and
    # rename-to-new dims, plus multiple simultaneous indexed dims.
    abc = xtensor(dims=("a", "b", "c"), shape=(3, 5, 7))
    a_idx = xtensor(dims=("a",), shape=(5,), dtype="int64")
    c_idx = xtensor(dims=("c",), shape=(3,), dtype="int64")

    abc_val = xr_random_like(abc)
    a_idx_val = DataArray([0, 1, 0, 2, 0], dims=("a",))
    c_idx_val = DataArray([0, 5, 6], dims=("c",))

    check_vectorization([abc, a_idx], [abc.isel(a=a_idx)], [abc_val, a_idx_val])
    check_vectorization(
        [abc, a_idx], [abc.isel(a=a_idx.rename(a="b"))], [abc_val, a_idx_val]
    )
    check_vectorization(
        [abc, a_idx], [abc.isel(a=a_idx.rename(a="d"))], [abc_val, a_idx_val]
    )
    check_vectorization([abc, a_idx], [abc.isel(c=a_idx[:3])], [abc_val, a_idx_val])
    check_vectorization(
        [abc, a_idx], [abc.isel(a=a_idx, c=a_idx)], [abc_val, a_idx_val]
    )
    check_vectorization(
        [abc, a_idx, c_idx],
        [abc.isel(a=a_idx, c=c_idx)],
        [abc_val, a_idx_val, c_idx_val],
    )
def test_index_update_vectorize():
    # Exercise both IndexUpdate modes ("set" and "inc") under vectorization.
    x = xtensor("x", dims=("a", "b"), shape=(3, 5))
    idx = xtensor("idx", dims=("b*",), shape=(7,), dtype=int)
    y = xtensor("y", dims=("b*",), shape=(7,))

    x_val = xr_random_like(x)
    idx_val = DataArray([2, 0, 4, 0, 1, 0, 3], dims=("b*",))
    y_val = xr_random_like(y)

    check_vectorization([x, idx, y], [x.isel(b=idx).set(y)], [x_val, idx_val, y_val])
    check_vectorization([x, idx, y], [x.isel(b=idx).inc(y)], [x_val, idx_val, y_val])
......@@ -16,7 +16,7 @@ from xarray_einstats.linalg import (
from pytensor.xtensor.linalg import cholesky, solve
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function
from tests.xtensor.util import check_vectorization, xr_assert_allclose, xr_function
def test_cholesky():
......@@ -74,3 +74,22 @@ def test_solve_matrix_b():
fn(a_test, b_test),
xr_solve(a_test, b_test, dims=["country", "city", "district"]),
)
def test_linalg_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    # NOTE(review): the variable names and the name strings look swapped
    # (`a` is named "b" and `ab` is named "a") — harmless, but worth confirming.
    a = xtensor("b", dims=("a",), shape=(3,))
    ab = xtensor("a", dims=("a", "b"), shape=(3, 3))

    # Symmetric positive-definite matrix so cholesky is well-defined.
    test_spd = np.random.randn(3, 3)
    test_spd = test_spd @ test_spd.T

    check_vectorization(
        [ab],
        [cholesky(ab, dims=("b", "a"))],
        input_vals=[DataArray(test_spd, dims=("a", "b"))],
    )

    check_vectorization(
        [ab, a],
        [solve(ab, a, dims=("a", "b"))],
    )
......@@ -17,7 +17,12 @@ from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import rename
from pytensor.xtensor.math import add, exp, logsumexp
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
)
def test_all_scalar_ops_are_wrapped():
......@@ -340,3 +345,11 @@ def test_dot_errors():
match=r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)",
):
fn(x_test, y_test)
def test_xelemwise_vectorize():
    # One unary and one binary elemwise op are enough; logic is op-agnostic.
    ab = xtensor("ab", dims=("a", "b"), shape=(2, 3))
    bc = xtensor("bc", dims=("b", "c"), shape=(3, 5))
    check_vectorization([ab], [exp(ab)])
    check_vectorization([ab, bc], [ab + bc])
......@@ -9,6 +9,7 @@ import re
from copy import deepcopy
import numpy as np
from xarray import DataArray
import pytensor.tensor.random as ptr
import pytensor.xtensor.random as pxr
......@@ -26,6 +27,7 @@ from pytensor.xtensor.random import (
normal,
)
from pytensor.xtensor.vectorization import XRV
from tests.xtensor.util import check_vectorization
def lower_rewrite(vars):
......@@ -438,3 +440,27 @@ def test_multivariate_normal():
):
# cov must have both core_dims
multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "missing_cols"))
def test_xrv_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    n = xtensor("n", dims=("n",), shape=(3,), dtype=int)
    pna = xtensor("p", dims=("p", "n", "a"), shape=(5, 3, 2))
    out = multinomial(n, pna, core_dims=("p",), extra_dims={"extra": 5})

    check_vectorization(
        [n, pna],
        [out],
        input_vals=[
            DataArray([3, 5, 10], dims=("n",)),
            # Valid probability vectors along "p" (each column sums to 1).
            DataArray(
                np.random.multinomial(n=1, pvals=np.ones(5) / 5, size=(2, 3)).T,
                dims=("p", "n", "a"),
            ),
        ],
    )
def test_xrv_batch_extra_dim_vectorize():
    # TODO: Check it raises NotImplementedError when we try to batch the extra_dim of an xrv
    pass
......@@ -8,7 +8,12 @@ import numpy as np
import xarray as xr
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
)
@pytest.mark.parametrize(
......@@ -99,3 +104,16 @@ def test_discrete_reduction_upcasting(signed):
res = fn(x_val)
np.testing.assert_allclose(res, [test_val, test_val**2])
xr_assert_allclose(res, x_val.cumprod())
def test_reduction_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    abc = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
    # Plain reductions: single dim, multiple dims, and all dims.
    check_vectorization([abc], [abc.sum(dim="a")])
    check_vectorization([abc], [abc.max(dim=("a", "c"))])
    check_vectorization([abc], [abc.all()])
    # Cumulative reductions: single dim, multiple dims, and all dims.
    check_vectorization([abc], [abc.cumsum(dim="b")])
    check_vectorization([abc], [abc.cumsum(dim=("c", "b"))])
    check_vectorization([abc], [abc.cumprod()])
......@@ -15,7 +15,8 @@ from xarray import full_like as xr_full_like
from xarray import ones_like as xr_ones_like
from xarray import zeros_like as xr_zeros_like
from pytensor.tensor import scalar
from pytensor.graph import vectorize_graph
from pytensor.tensor import scalar, vector
from pytensor.xtensor.shape import (
broadcast,
concat,
......@@ -25,8 +26,9 @@ from pytensor.xtensor.shape import (
unstack,
zeros_like,
)
from pytensor.xtensor.type import xtensor
from pytensor.xtensor.type import as_xtensor, xtensor
from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
......@@ -800,3 +802,90 @@ def test_zeros_like():
expected1 = xr_zeros_like(x_test)
xr_assert_allclose(result1, expected1)
assert result1.dtype == expected1.dtype
def test_shape_ops_vectorize():
    # Smoke-test vectorization across the shape-manipulation Ops.
    a1 = xtensor("a1", dims=("a", "1"), shape=(2, 1), dtype="float64")
    ab = xtensor("ab", dims=("a", "b"), shape=(2, 3), dtype="float64")
    abc = xtensor("abc", dims=("a", "b", "c"), shape=(2, 3, 5), dtype="float64")
    a_bc_d = xtensor("a_bc_d", dims=("a", "bc", "d"), shape=(4, 15, 7))

    check_vectorization(abc, abc.transpose("b", "c", "a"))
    check_vectorization(abc, abc.transpose("b", ...))
    check_vectorization(abc, stack(abc, new_dim=("a", "c")))
    check_vectorization(a_bc_d, unstack(a_bc_d, bc=dict(b=3, c=5)))
    check_vectorization([abc, ab], concat([abc, ab], dim="a"))
    check_vectorization(a1, a1.squeeze("1"))
    check_vectorization(abc, abc.expand_dims(d=5))
    check_vectorization([ab, abc], broadcast(ab, abc))
    check_vectorization([ab, abc, a1], broadcast(ab, abc, a1, exclude="1"))
    # a is longer in a_bc_d than in ab and abc, helper can't handle that
    # check_vectorization([ab, abc, a_bc_d], broadcast(ab, abc, a_bc_d, exclude="a"))
def test_broadcast_exclude_vectorize():
    # Vectorized broadcast-with-exclude must match looping xr_broadcast over
    # each batch entry and concatenating the results back along "batch".
    x = xtensor("x", dims=("a", "b"), shape=(3, 4))
    y = xtensor("y", dims=("b", "c"), shape=(7, 5))
    out_x, out_y = broadcast(x, y, exclude=("b",))

    x_val = xr_random_like(x)
    y_val = xr_random_like(y)
    x_batch_val = x_val.expand_dims({"batch": 2})
    y_batch_val = y_val.expand_dims({"batch": 2})
    x_batch = as_xtensor(x_batch_val).type("x_batch")
    y_batch = as_xtensor(y_batch_val).type("y_batch")

    [out_x_vec, out_y_vec] = vectorize_graph([out_x, out_y], {x: x_batch, y: y_batch})
    fn = xr_function([x_batch, y_batch], [out_x_vec, out_y_vec])
    res_x, res_y = fn(x_batch_val, y_batch_val)

    # Reference: apply the un-vectorized broadcast per batch entry.
    expected_x = []
    expected_y = []
    for i in range(2):
        ex_x, ex_y = xr_broadcast(
            x_batch_val.isel(batch=i), y_batch_val.isel(batch=i), exclude=("b",)
        )
        expected_x.append(ex_x)
        expected_y.append(ex_y)
    expected_x = xr_concat(expected_x, dim="batch")
    expected_y = xr_concat(expected_y, dim="batch")
    xr_assert_allclose(res_x, expected_x)
    xr_assert_allclose(res_y, expected_y)
def test_expand_dims_batch_length_vectorize():
    # Batching the length of an expanded dim is unsupported and must raise.
    x = xtensor("x", dims=("a",), shape=(3,))
    l = scalar("l", dtype="int64")
    y = x.expand_dims(b=l)

    x_batch = as_xtensor(xr_random_like(x).expand_dims(batch=2)).type("x_batch")
    l_batch = vector("l_batch", dtype="int64")
    with pytest.raises(
        NotImplementedError, match=r"Vectorization of .* not implemented"
    ):
        vectorize_graph([y], {x: x_batch, l: l_batch})
def test_unstack_batch_length_vectorize():
    # Batching an unstacked length is unsupported and must raise.
    x = xtensor("x", dims=("ab",), shape=(12,))
    l = scalar("l", dtype="int64")
    y = unstack(x, ab={"a": l, "b": x.sizes["ab"] // l})

    x_batch = as_xtensor(xr_random_like(x).expand_dims(batch=2)).type("x_batch")
    l_batch = vector("l_batch", dtype="int64")
    with pytest.raises(
        NotImplementedError, match=r"Vectorization of .* not implemented"
    ):
        vectorize_graph([y], {x: x_batch, l: l_batch})
......@@ -12,7 +12,12 @@ from xarray import apply_ufunc
from pytensor.xtensor.signal import convolve1d
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
)
@pytest.mark.parametrize("mode", ("full", "valid", "same"))
......@@ -68,3 +73,14 @@ def test_convolve_1d_invalid():
match=re.escape("Input 1 has invalid core dims ['time']. Allowed: ('kernel',)"),
):
convolve1d(in1, in2.rename({"batch": "time"}), dims=("time", "kernel"))
def test_signal_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    ab = xtensor("a", dims=("a", "b"), shape=(3, 3))
    c = xtensor(name="c", dims=("c",), shape=(7,))
    check_vectorization(
        [ab, c],
        [convolve1d(ab, c, dims=("a", "c"))],
    )
import pytest
pytest.importorskip("xarray")
xr = pytest.importorskip("xarray")
from itertools import chain
import numpy as np
from xarray import DataArray
from xarray.testing import assert_allclose
from pytensor import function
from pytensor.xtensor.type import XTensorType
from pytensor.graph import vectorize_graph
from pytensor.xtensor.type import XTensorType, as_xtensor
def xr_function(*args, **kwargs):
......@@ -76,3 +79,57 @@ def xr_random_like(x, rng=None):
return DataArray(
rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims
)
def check_vectorization(inputs, outputs, input_vals=None, rng=None):
    """Check that vectorize_graph(outputs) matches xarray's apply_ufunc.

    Compiles a core function for the graph, batches every input with its own
    uniquely-sized batch dim, vectorizes the graph, and compares the result
    against ``xr.apply_ufunc(..., vectorize=True)``, which loops the core
    function over the non-core dims.

    Parameters
    ----------
    inputs, outputs
        Core graph inputs/outputs (a single variable or a list/tuple).
    input_vals
        Optional xarray values for the core inputs; random values are drawn
        when omitted.
    rng
        Optional seed/generator for the random values.
    """
    # Create core graph and function
    if not isinstance(inputs, list | tuple):
        inputs = (inputs,)
    if not isinstance(outputs, list | tuple):
        outputs = (outputs,)

    # apply_ufunc isn't happy with list output or single entry
    _core_fn = function(inputs, outputs)

    def core_fn(*args, _core_fn=_core_fn):
        res = _core_fn(*args)
        if len(res) == 1:
            return res[0]
        else:
            return tuple(res)

    if input_vals is None:
        rng = np.random.default_rng(rng)
        input_vals = [xr_random_like(inp, rng) for inp in inputs]

    # Create vectorized inputs
    batch_inputs = []
    batch_input_vals = []
    for i, (inp, val) in enumerate(zip(inputs, input_vals)):
        # Distinct sizes (2, 4, 8, ...) so misaligned batch dims are caught.
        new_val = val.expand_dims({f"batch_{i}": 2 ** (i + 1)})
        new_inp = as_xtensor(new_val).type(f"batch_{inp.name or f'input{i}'}")
        batch_inputs.append(new_inp)
        batch_input_vals.append(new_val)

    # Create vectorized function
    new_outputs = vectorize_graph(outputs, dict(zip(inputs, batch_inputs)))
    vec_fn = xr_function(batch_inputs, new_outputs)
    vec_res = vec_fn(*batch_input_vals)

    # xarray.apply_ufunc with vectorize=True loops over non-core dims
    input_core_dims = [i.dims for i in inputs]
    output_core_dims = [o.dims for o in outputs]
    expected_res = xr.apply_ufunc(
        core_fn,
        *batch_input_vals,
        input_core_dims=input_core_dims,
        output_core_dims=output_core_dims,
        # Core dims may change size/order across outputs; exclude them all.
        exclude_dims=set(chain.from_iterable((*input_core_dims, *output_core_dims))),
        vectorize=True,
    )
    if not isinstance(expected_res, list | tuple):
        expected_res = (expected_res,)

    # Align dim order before comparing, since batch placement may differ.
    for v_r, e_r in zip(vec_res, expected_res):
        xr_assert_allclose(v_r, e_r.transpose(*v_r.dims))
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment