提交 9d99267c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement vectorize_node for XOps

上级 9dd929ab
...@@ -163,7 +163,7 @@ lines-after-imports = 2 ...@@ -163,7 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py" = ["E402"] "tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"] "tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"] "tests/link/mlx/**/test_*.py" = ["E402"]
"tests/xtensor/**/test_*.py" = ["E402"] "tests/xtensor/**/*.py" = ["E402"]
......
...@@ -2,6 +2,7 @@ from collections.abc import Sequence ...@@ -2,6 +2,7 @@ from collections.abc import Sequence
from pytensor.compile.ops import TypeCastingOp from pytensor.compile.ops import TypeCastingOp
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
...@@ -17,6 +18,9 @@ class XOp(Op): ...@@ -17,6 +18,9 @@ class XOp(Op):
def do_constant_folding(self, fgraph, node): def do_constant_folding(self, fgraph, node):
return False return False
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]:
    """Return outputs of the node vectorized over ``new_inputs``.

    Default hook used by ``vectorize_xop`` dispatch; concrete XOps must
    override this to define their batching semantics.
    """
    raise NotImplementedError(f"Vectorized node not implemented for {self}")
class XTypeCastOp(TypeCastingOp): class XTypeCastOp(TypeCastingOp):
"""Base class for Ops that type cast between TensorType and XTensorType. """Base class for Ops that type cast between TensorType and XTensorType.
...@@ -27,6 +31,9 @@ class XTypeCastOp(TypeCastingOp): ...@@ -27,6 +31,9 @@ class XTypeCastOp(TypeCastingOp):
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
return input_shapes return input_shapes
def vectorize_node(self, node, *new_inputs) -> Sequence[Variable]:
    """Return outputs of the node vectorized over ``new_inputs``.

    Default hook for type-casting Ops; concrete subclasses must override.
    """
    raise NotImplementedError(f"Vectorized node not implemented for {self}")
class TensorFromXTensor(XTypeCastOp): class TensorFromXTensor(XTypeCastOp):
__props__ = () __props__ = ()
...@@ -42,6 +49,16 @@ class TensorFromXTensor(XTypeCastOp): ...@@ -42,6 +49,16 @@ class TensorFromXTensor(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)] return [xtensor_from_tensor(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x):
    """Vectorize TensorFromXTensor for an input with at most one extra batch dim.

    The pre-existing (core) dims are transposed to the right so that the
    single batch dimension (if any) ends up leftmost in the output tensor.
    """
    [old_x] = node.inputs
    # With more than one new batch dim their relative order in the output
    # tensor would be ambiguous, so we refuse to guess.
    if (new_x.ndim - old_x.ndim) > 1:
        raise NotImplementedError(
            f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
            "You can call vectorize_graph one batch dimension at a time."
        )
    # Ellipsis keeps the batch dim first; core dims follow in original order.
    new_x = new_x.transpose(..., *old_x.dims)
    return [self(new_x)]
tensor_from_xtensor = TensorFromXTensor() tensor_from_xtensor = TensorFromXTensor()
...@@ -63,6 +80,15 @@ class XTensorFromTensor(XTypeCastOp): ...@@ -63,6 +80,15 @@ class XTensorFromTensor(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [tensor_from_xtensor(g_out)] return [tensor_from_xtensor(g_out)]
def vectorize_node(self, node, new_x):
    """Vectorize XTensorFromTensor; batched inputs are rejected.

    A new tensor axis carries no dimension label, so there is no way to
    infer what the batched output dims should be named.
    """
    [old_x] = node.inputs
    if new_x.ndim != old_x.ndim:
        raise NotImplementedError(
            f"Vectorization of {self} with batched inputs not implemented, "
            "as it can't infer new dimension labels"
        )
    return [self(new_x)]
def xtensor_from_tensor(x, dims, name=None): def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name) return XTensorFromTensor(dims=dims)(x, name=name)
...@@ -85,6 +111,16 @@ class Rename(XTypeCastOp): ...@@ -85,6 +111,16 @@ class Rename(XTypeCastOp):
[g_out] = g_outs [g_out] = g_outs
return [rename(g_out, dims=x.type.dims)] return [rename(g_out, dims=x.type.dims)]
def vectorize_node(self, node, new_x):
    """Rebuild the rename on the batched input, leaving batch dims untouched."""
    [old_x] = node.inputs
    # Pair each pre-existing dim with its renamed counterpart.
    rename_map = {
        old_dim: renamed
        for old_dim, renamed in zip(old_x.dims, self.new_dims, strict=True)
    }
    # new_x.dims may mix old dims (possibly re-ordered) with batch dims;
    # batch dims are absent from the map and therefore keep their labels.
    updated_dims = tuple(rename_map.get(dim, dim) for dim in new_x.dims)
    return [type(self)(updated_dims)(new_x)]
def rename(x, name_dict: dict[str, str] | None = None, **names: str): def rename(x, name_dict: dict[str, str] | None = None, **names: str):
if name_dict is not None: if name_dict is not None:
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# https://numpy.org/neps/nep-0021-advanced-indexing.html # https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html # https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html # https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from itertools import chain
from typing import Literal from typing import Literal
from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.basic import Apply, Constant, Variable
...@@ -11,6 +12,7 @@ from pytensor.scalar.basic import discrete_dtypes ...@@ -11,6 +12,7 @@ from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.shape import broadcast
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
...@@ -195,6 +197,15 @@ class Index(XOp): ...@@ -195,6 +197,15 @@ class Index(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output]) return Apply(self, [x, *idxs], [output])
def vectorize_node(self, node, new_x, *new_idxs):
    """Vectorize Index: re-associate each index with its dim on the batched input."""
    # new_x may have dims in different order
    # we pair each pre-existing dim to the respective index
    # with new dims having simply a slice(None)
    old_x, *_ = node.inputs
    # strict=False: old_x may have more (trailing, un-indexed) dims than idxs.
    dims_to_idxs = dict(zip(old_x.dims, new_idxs, strict=False))
    new_idxs = tuple(dims_to_idxs.get(dim, slice(None)) for dim in new_x.dims)
    return [self(new_x, *new_idxs)]
index = Index() index = Index()
...@@ -226,6 +237,29 @@ class IndexUpdate(XOp): ...@@ -226,6 +237,29 @@ class IndexUpdate(XOp):
out = x.type() out = x.type()
return Apply(self, [x, y, *idxs], [out]) return Apply(self, [x, y, *idxs], [out])
def vectorize_node(self, node, *new_inputs):
    """Vectorize IndexUpdate, broadcasting batch dims into the updated buffer.

    If ``y`` or the indices gained batch dims, ``x`` must be broadcast to
    carry them too, since the update writes into ``x``. Dims that already
    existed on any input are excluded from that broadcast.
    """
    # If y or the indices have new dimensions we need to broadcast_x
    exclude: set[str] = set(
        chain.from_iterable(
            old_inp.dims
            for old_inp in node.inputs
            if isinstance(old_inp.type, XTensorType)
        )
    )
    old_x, *_ = node.inputs
    new_x, *_ = broadcast(
        *[
            new_inp
            for new_inp in new_inputs
            if isinstance(new_inp.type, XTensorType)
        ],
        exclude=tuple(exclude),
    )
    # New batch dimensions must go on the right since indices map to indexed dimensions positionally in the Op
    new_x = new_x.transpose(*old_x.dims, ...)
    _, new_y, *new_idxs = new_inputs
    return [self(new_x, new_y, *new_idxs)]
index_assignment = IndexUpdate("set") index_assignment = IndexUpdate("set")
index_increment = IndexUpdate("inc") index_increment = IndexUpdate("inc")
...@@ -46,6 +46,9 @@ class XReduce(XOp): ...@@ -46,6 +46,9 @@ class XReduce(XOp):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims) output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
    """Vectorize XReduce: reductions are dim-labeled, so just re-apply the Op."""
    return [self(new_x)]
def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]: def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str): if isinstance(dim, str):
...@@ -117,6 +120,9 @@ class XCumReduce(XOp): ...@@ -117,6 +120,9 @@ class XCumReduce(XOp):
out = x.type() out = x.type()
return Apply(self, [x], [out]) return Apply(self, [x], [out])
def vectorize_node(self, node, new_x):
    """Vectorize XCumReduce: cum-reductions are dim-labeled, so just re-apply the Op."""
    return [self(new_x)]
def cumreduce(x, dim: REDUCE_DIM, *, binary_op): def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
x = as_xtensor(x) x = as_xtensor(x)
......
...@@ -68,6 +68,9 @@ class Stack(XOp): ...@@ -68,6 +68,9 @@ class Stack(XOp):
) )
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
    """Vectorize Stack: stacking is dim-labeled, so just re-apply the Op."""
    return [self(new_x)]
def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]): def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
if dim is not None: if dim is not None:
...@@ -146,6 +149,14 @@ class UnStack(XOp): ...@@ -146,6 +149,14 @@ class UnStack(XOp):
) )
return Apply(self, [x, *unstacked_lengths], [output]) return Apply(self, [x, *unstacked_lengths], [output])
def vectorize_node(self, node, new_x, *new_unstacked_length):
    """Vectorize UnStack; batched unstacked lengths are rejected.

    Length inputs must remain scalars: a per-batch length would yield
    ragged unstacked dims, which XTensorType cannot represent.
    """
    # Squeeze away degenerate (size-1) batch dims on the length inputs.
    new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
    if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
        raise NotImplementedError(
            f"Vectorization of {self} with batched unstacked_length not implemented, "
        )
    return [self(new_x, *new_unstacked_length)]
def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]): def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
if dim is not None: if dim is not None:
...@@ -189,6 +200,11 @@ class Transpose(XOp): ...@@ -189,6 +200,11 @@ class Transpose(XOp):
) )
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def vectorize_node(self, node, new_x):
    """Re-apply the transpose with batch dims kept leftmost and untouched."""
    core_dims = self.dims
    # Any dim of new_x not mentioned in the original transpose is a batch dim.
    batch_dims = tuple(d for d in new_x.dims if d not in core_dims)
    return [type(self)(dims=batch_dims + core_dims)(new_x)]
def transpose( def transpose(
x, x,
...@@ -302,6 +318,9 @@ class Concat(XOp): ...@@ -302,6 +318,9 @@ class Concat(XOp):
output = xtensor(dtype=dtype, dims=dims, shape=shape) output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output]) return Apply(self, inputs, [output])
def vectorize_node(self, node, *new_inputs):
    """Vectorize Concat: concatenation is dim-labeled, so just re-apply the Op."""
    return [self(*new_inputs)]
def concat(xtensors, dim: str): def concat(xtensors, dim: str):
"""Concatenate a sequence of XTensorVariables along a specified dimension. """Concatenate a sequence of XTensorVariables along a specified dimension.
...@@ -383,6 +402,9 @@ class Squeeze(XOp): ...@@ -383,6 +402,9 @@ class Squeeze(XOp):
) )
return Apply(self, [x], [out]) return Apply(self, [x], [out])
def vectorize_node(self, node, new_x):
    """Vectorize Squeeze: squeezed dims are labeled, so just re-apply the Op."""
    return [self(new_x)]
def squeeze(x, dim: str | Sequence[str] | None = None): def squeeze(x, dim: str | Sequence[str] | None = None):
"""Remove dimensions of size 1 from an XTensorVariable.""" """Remove dimensions of size 1 from an XTensorVariable."""
...@@ -442,6 +464,14 @@ class ExpandDims(XOp): ...@@ -442,6 +464,14 @@ class ExpandDims(XOp):
) )
return Apply(self, [x, size], [out]) return Apply(self, [x, size], [out])
def vectorize_node(self, node, new_x, new_size):
    """Vectorize ExpandDims; a batched size input is rejected.

    The size must remain a scalar: a per-batch size would yield a ragged
    expanded dim, which XTensorType cannot represent.
    """
    # Squeeze away degenerate (size-1) batch dims on the size input.
    new_size = new_size.squeeze()
    if new_size.type.ndim != 0:
        raise NotImplementedError(
            f"Vectorization of {self} with batched new_size not implemented, "
        )
    return [self(new_x, new_size)]
def expand_dims(x, dim=None, axis=None, **dim_kwargs): def expand_dims(x, dim=None, axis=None, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable.""" """Add one or more new dimensions to an XTensorVariable."""
...@@ -537,6 +567,19 @@ class Broadcast(XOp): ...@@ -537,6 +567,19 @@ class Broadcast(XOp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs):
    """Vectorize Broadcast, validating excluded dims on the batched inputs.

    Raises
    ------
    NotImplementedError
        If a batched input gained a dimension that is in ``self.exclude``:
        excluded dims are not broadcast by this Op, so a newly-introduced
        excluded batch dim has no well-defined vectorized semantics.
    """
    if exclude_set := set(self.exclude):
        # BUG FIX: zip(node.inputs, new_inputs) yields (old, new) pairs; the
        # original unpacked them as (new_x, old_x), computing old - new dims
        # (always empty after upstream validation), so the check never fired.
        for old_x, new_x in zip(node.inputs, new_inputs, strict=True):
            if invalid_excluded := (
                (set(new_x.dims) - set(old_x.dims)) & exclude_set
            ):
                raise NotImplementedError(
                    f"Vectorize of {self} is undefined because one of the inputs {new_x} "
                    f"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
                )
    return self(*new_inputs, return_list=True)
def broadcast( def broadcast(
*args, exclude: str | Sequence[str] | None = None *args, exclude: str | Sequence[str] | None = None
......
...@@ -1044,7 +1044,7 @@ def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None) ...@@ -1044,7 +1044,7 @@ def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None)
if isinstance(x, Variable): if isinstance(x, Variable):
if isinstance(x.type, XTensorType): if isinstance(x.type, XTensorType):
if (dims is None) or (x.type.dims == dims): if (dims is None) or (x.type.dims == tuple(dims)):
return x return x
else: else:
raise ValueError( raise ValueError(
......
...@@ -6,6 +6,8 @@ import numpy as np ...@@ -6,6 +6,8 @@ import numpy as np
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor import shared from pytensor import shared
from pytensor.graph import Apply, Op from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable
from pytensor.graph.replace import _vectorize_node
from pytensor.scalar import discrete_dtypes from pytensor.scalar import discrete_dtypes
from pytensor.tensor import tensor from pytensor.tensor import tensor
from pytensor.tensor.random.op import RNGConsumerOp from pytensor.tensor.random.op import RNGConsumerOp
...@@ -14,8 +16,11 @@ from pytensor.tensor.utils import ( ...@@ -14,8 +16,11 @@ from pytensor.tensor.utils import (
get_static_shape_from_size_variables, get_static_shape_from_size_variables,
) )
from pytensor.utils import unzip from pytensor.utils import unzip
from pytensor.xtensor.basic import XOp from pytensor.xtensor.basic import (
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor XOp,
XTypeCastOp,
)
from pytensor.xtensor.type import XTensorType, XTensorVariable, as_xtensor, xtensor
def combine_dims_and_shape( def combine_dims_and_shape(
...@@ -74,6 +79,9 @@ class XElemwise(XOp): ...@@ -74,6 +79,9 @@ class XElemwise(XOp):
] ]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs):
    """Vectorize XElemwise: elemwise broadcasting is dim-labeled, so just re-apply the Op."""
    return self(*new_inputs, return_list=True)
class XBlockwise(XOp): class XBlockwise(XOp):
__props__ = ("core_op", "core_dims") __props__ = ("core_op", "core_dims")
...@@ -141,6 +149,9 @@ class XBlockwise(XOp): ...@@ -141,6 +149,9 @@ class XBlockwise(XOp):
] ]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def vectorize_node(self, node, *new_inputs):
    """Vectorize XBlockwise: batching follows the dim labels, so just re-apply the Op."""
    return self(*new_inputs, return_list=True)
class XRV(XOp, RNGConsumerOp): class XRV(XOp, RNGConsumerOp):
"""Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics. """Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
...@@ -288,3 +299,54 @@ class XRV(XOp, RNGConsumerOp): ...@@ -288,3 +299,54 @@ class XRV(XOp, RNGConsumerOp):
) )
return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out]) return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
def vectorize_node(self, node, *new_inputs):
    """Vectorize an XRV by re-making the node on the new inputs.

    Extra-dim length inputs must remain scalars; batched lengths would imply
    ragged draw shapes and are rejected.
    """
    # Node inputs are laid out as [rng, *extra_dim_lengths, *params].
    new_rng, *new_extra_dim_lengths_and_params = new_inputs
    k = len(self.extra_dims)
    new_extra_dim_lengths, new_params = (
        new_extra_dim_lengths_and_params[:k],
        new_extra_dim_lengths_and_params[k:],
    )
    # Squeeze away degenerate (size-1) batch dims on the length inputs.
    new_extra_dim_lengths = [dl.squeeze() for dl in new_extra_dim_lengths]
    if not all(dl.type.ndim == 0 for dl in new_extra_dim_lengths):
        raise NotImplementedError(
            f"Vectorization of {self} with batched extra_dim_lengths not implemented, "
        )
    # make_node returns [rng_out, draws]; return both outputs.
    return self.make_node(new_rng, *new_extra_dim_lengths, *new_params).outputs
@_vectorize_node.register(XOp)
@_vectorize_node.register(XTypeCastOp)
def vectorize_xop(op: XOp, node, *new_inputs) -> Sequence[Variable]:
    """Dispatch vectorization of XOps after validating the new inputs.

    Checks that each batched XTensor input keeps all of its pre-existing
    dims and that any newly-introduced dims were not already present
    elsewhere in the original node (inputs or outputs), then delegates to
    the Op-specific ``vectorize_node`` hook.
    """
    old_inp_dims = [
        inp.dims for inp in node.inputs if isinstance(inp.type, XTensorType)
    ]
    old_out_dims = [
        out.dims for out in node.outputs if isinstance(out.type, XTensorType)
    ]
    # BUG FIX: old_out_dims must be unpacked like old_inp_dims; otherwise the
    # dim-tuples themselves (not the individual dim names) landed in the set,
    # and output-only dims were never caught by the check below.
    all_old_dims_set = set(chain.from_iterable((*old_inp_dims, *old_out_dims)))
    for new_inp, old_inp in zip(new_inputs, node.inputs, strict=True):
        # Only XTensor inputs carry dims worth validating.
        if not (
            isinstance(new_inp.type, XTensorType)
            and isinstance(old_inp.type, XTensorType)
        ):
            continue
        old_dims_set = set(old_inp.dims)
        new_dims_set = set(new_inp.dims)
        # Validate that new inputs didn't drop pre-existing dims
        if missing_dims := old_dims_set - new_dims_set:
            raise ValueError(
                f"Vectorized input {new_inp} is missing pre-existing dims: {sorted(missing_dims)}"
            )
        # Or have new dimensions that were already in the graph
        if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set):
            raise ValueError(
                f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
            )
    return op.vectorize_node(node, *new_inputs)
import pytest
pytest.importorskip("xarray")
import numpy as np import numpy as np
from pytensor import function from pytensor import function
from pytensor.xtensor.basic import Rename from pytensor.graph import vectorize_graph
from pytensor.tensor import matrix, vector
from pytensor.xtensor.basic import (
Rename,
rename,
tensor_from_xtensor,
xtensor_from_tensor,
)
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.unittest_tools import assert_equal_computations
# from pytensor.xtensor.vectorization import vectorize_graph
from tests.xtensor.util import check_vectorization
def test_shape_feature_does_not_see_xop(): def test_shape_feature_does_not_see_xop():
...@@ -24,3 +40,36 @@ def test_shape_feature_does_not_see_xop(): ...@@ -24,3 +40,36 @@ def test_shape_feature_does_not_see_xop():
fn = function([x], out) fn = function([x], out)
np.testing.assert_allclose(fn([1, 2, 3]), [0, 0, 0]) np.testing.assert_allclose(fn([1, 2, 3]), [0, 0, 0])
assert not CALLED assert not CALLED
def test_rename_vectorize():
    """Rename vectorizes cleanly: batch dims pass through unrenamed."""
    ab = xtensor("ab", dims=("a", "b"), shape=(2, 3), dtype="float64")
    check_vectorization(ab, rename(ab, a="c"))
def test_xtensor_from_tensor_vectorize():
    """Batching the tensor input of XTensorFromTensor is rejected (no label to infer)."""
    t = vector("t")
    x = xtensor_from_tensor(t, dims=("a",))

    t_batched = matrix("t_batched")
    with pytest.raises(
        NotImplementedError, match=r"Vectorization of .* not implemented"
    ):
        vectorize_graph([x], {t: t_batched})
def test_tensor_from_xtensor_vectorize():
    """TensorFromXTensor supports exactly one batch dim, placed on the left."""
    x = xtensor("x", dims=("a",), shape=(3,))
    y = tensor_from_xtensor(x)

    x_batched = xtensor("x", dims=("a", "b"), shape=(3, 5))
    y_batched = vectorize_graph(y, {x: x_batched})
    # vectorize_graph should place output batch dimension on the left
    assert y_batched.type.shape == (5, 3)
    assert_equal_computations([y_batched], [x_batched.transpose("b", ...).values])

    x_batched = xtensor("x", dims=("c", "a", "b"), shape=(7, 3, 5))
    # vectorize_graph can't handle multiple batch dimensions safely
    with pytest.raises(NotImplementedError):
        vectorize_graph(y, {x: x_batched})
...@@ -14,6 +14,7 @@ from pytensor.tensor import tensor ...@@ -14,6 +14,7 @@ from pytensor.tensor import tensor
from pytensor.xtensor import xtensor from pytensor.xtensor import xtensor
from tests.unittest_tools import assert_equal_computations from tests.unittest_tools import assert_equal_computations
from tests.xtensor.util import ( from tests.xtensor.util import (
check_vectorization,
xr_arange_like, xr_arange_like,
xr_assert_allclose, xr_assert_allclose,
xr_function, xr_function,
...@@ -542,3 +543,43 @@ def test_empty_update_index(): ...@@ -542,3 +543,43 @@ def test_empty_update_index():
fn = xr_function([x], out1) fn = xr_function([x], out1)
x_test = xr_random_like(x) x_test = xr_random_like(x)
xr_assert_allclose(fn(x_test), x_test + 1) xr_assert_allclose(fn(x_test), x_test + 1)
def test_indexing_vectorize():
    """Index vectorizes with same-dim, renamed, and brand-new index dims."""
    abc = xtensor(dims=("a", "b", "c"), shape=(3, 5, 7))
    a_idx = xtensor(dims=("a",), shape=(5,), dtype="int64")
    c_idx = xtensor(dims=("c",), shape=(3,), dtype="int64")

    abc_val = xr_random_like(abc)
    a_idx_val = DataArray([0, 1, 0, 2, 0], dims=("a",))
    c_idx_val = DataArray([0, 5, 6], dims=("c",))

    # Index along the same dim, a renamed existing dim, and a new dim label.
    check_vectorization([abc, a_idx], [abc.isel(a=a_idx)], [abc_val, a_idx_val])
    check_vectorization(
        [abc, a_idx], [abc.isel(a=a_idx.rename(a="b"))], [abc_val, a_idx_val]
    )
    check_vectorization(
        [abc, a_idx], [abc.isel(a=a_idx.rename(a="d"))], [abc_val, a_idx_val]
    )
    check_vectorization([abc, a_idx], [abc.isel(c=a_idx[:3])], [abc_val, a_idx_val])
    check_vectorization(
        [abc, a_idx], [abc.isel(a=a_idx, c=a_idx)], [abc_val, a_idx_val]
    )
    check_vectorization(
        [abc, a_idx, c_idx],
        [abc.isel(a=a_idx, c=c_idx)],
        [abc_val, a_idx_val, c_idx_val],
    )
def test_index_update_vectorize():
    """IndexUpdate (set/inc) vectorizes under batched buffer, index, and value."""
    x = xtensor("x", dims=("a", "b"), shape=(3, 5))
    idx = xtensor("idx", dims=("b*",), shape=(7,), dtype=int)
    y = xtensor("y", dims=("b*",), shape=(7,))

    x_val = xr_random_like(x)
    idx_val = DataArray([2, 0, 4, 0, 1, 0, 3], dims=("b*",))
    y_val = xr_random_like(y)

    check_vectorization([x, idx, y], [x.isel(b=idx).set(y)], [x_val, idx_val, y_val])
    check_vectorization([x, idx, y], [x.isel(b=idx).inc(y)], [x_val, idx_val, y_val])
...@@ -16,7 +16,7 @@ from xarray_einstats.linalg import ( ...@@ -16,7 +16,7 @@ from xarray_einstats.linalg import (
from pytensor.xtensor.linalg import cholesky, solve from pytensor.xtensor.linalg import cholesky, solve
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_assert_allclose, xr_function from tests.xtensor.util import check_vectorization, xr_assert_allclose, xr_function
def test_cholesky(): def test_cholesky():
...@@ -74,3 +74,22 @@ def test_solve_matrix_b(): ...@@ -74,3 +74,22 @@ def test_solve_matrix_b():
fn(a_test, b_test), fn(a_test, b_test),
xr_solve(a_test, b_test, dims=["country", "city", "district"]), xr_solve(a_test, b_test, dims=["country", "city", "district"]),
) )
def test_linalg_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    # NOTE(review): the variable names and xtensor name strings look swapped
    # ("a" is named "b" and vice versa) — harmless here, but confirm intent.
    a = xtensor("b", dims=("a",), shape=(3,))
    ab = xtensor("a", dims=("a", "b"), shape=(3, 3))

    # Symmetric positive-definite test matrix for cholesky.
    test_spd = np.random.randn(3, 3)
    test_spd = test_spd @ test_spd.T
    check_vectorization(
        [ab],
        [cholesky(ab, dims=("b", "a"))],
        input_vals=[DataArray(test_spd, dims=("a", "b"))],
    )

    check_vectorization(
        [ab, a],
        [solve(ab, a, dims=("a", "b"))],
    )
...@@ -17,7 +17,12 @@ from pytensor.scalar import ScalarOp ...@@ -17,7 +17,12 @@ from pytensor.scalar import ScalarOp
from pytensor.xtensor.basic import rename from pytensor.xtensor.basic import rename
from pytensor.xtensor.math import add, exp, logsumexp from pytensor.xtensor.math import add, exp, logsumexp
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
)
def test_all_scalar_ops_are_wrapped(): def test_all_scalar_ops_are_wrapped():
...@@ -340,3 +345,11 @@ def test_dot_errors(): ...@@ -340,3 +345,11 @@ def test_dot_errors():
match=r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)", match=r"(Input operand 1 has a mismatch in its core dimension 0|incompatible array sizes for np.dot)",
): ):
fn(x_test, y_test) fn(x_test, y_test)
def test_xelemwise_vectorize():
    """XElemwise vectorizes for unary and broadcasting binary operations."""
    ab = xtensor("ab", dims=("a", "b"), shape=(2, 3))
    bc = xtensor("bc", dims=("b", "c"), shape=(3, 5))
    check_vectorization([ab], [exp(ab)])
    check_vectorization([ab, bc], [ab + bc])
...@@ -9,6 +9,7 @@ import re ...@@ -9,6 +9,7 @@ import re
from copy import deepcopy from copy import deepcopy
import numpy as np import numpy as np
from xarray import DataArray
import pytensor.tensor.random as ptr import pytensor.tensor.random as ptr
import pytensor.xtensor.random as pxr import pytensor.xtensor.random as pxr
...@@ -26,6 +27,7 @@ from pytensor.xtensor.random import ( ...@@ -26,6 +27,7 @@ from pytensor.xtensor.random import (
normal, normal,
) )
from pytensor.xtensor.vectorization import XRV from pytensor.xtensor.vectorization import XRV
from tests.xtensor.util import check_vectorization
def lower_rewrite(vars): def lower_rewrite(vars):
...@@ -438,3 +440,27 @@ def test_multivariate_normal(): ...@@ -438,3 +440,27 @@ def test_multivariate_normal():
): ):
# cov must have both core_dims # cov must have both core_dims
multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "missing_cols")) multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "missing_cols"))
def test_xrv_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    n = xtensor("n", dims=("n",), shape=(3,), dtype=int)
    pna = xtensor("p", dims=("p", "n", "a"), shape=(5, 3, 2))
    out = multinomial(n, pna, core_dims=("p",), extra_dims={"extra": 5})
    check_vectorization(
        [n, pna],
        [out],
        input_vals=[
            DataArray([3, 5, 10], dims=("n",)),
            # One-hot draws give valid probability vectors along "p".
            DataArray(
                np.random.multinomial(n=1, pvals=np.ones(5) / 5, size=(2, 3)).T,
                dims=("p", "n", "a"),
            ),
        ],
    )
def test_xrv_batch_extra_dim_vectorize():
    """Placeholder: batched extra_dims of an XRV should raise NotImplementedError."""
    # TODO: Check it raises NotImplementedError when we try to batch the extra_dim of an xrv
    pass
...@@ -8,7 +8,12 @@ import numpy as np ...@@ -8,7 +8,12 @@ import numpy as np
import xarray as xr import xarray as xr
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -99,3 +104,16 @@ def test_discrete_reduction_upcasting(signed): ...@@ -99,3 +104,16 @@ def test_discrete_reduction_upcasting(signed):
res = fn(x_val) res = fn(x_val)
np.testing.assert_allclose(res, [test_val, test_val**2]) np.testing.assert_allclose(res, [test_val, test_val**2])
xr_assert_allclose(res, x_val.cumprod()) xr_assert_allclose(res, x_val.cumprod())
def test_reduction_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    abc = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
    # Full and partial reductions, plus cumulative variants.
    check_vectorization([abc], [abc.sum(dim="a")])
    check_vectorization([abc], [abc.max(dim=("a", "c"))])
    check_vectorization([abc], [abc.all()])
    check_vectorization([abc], [abc.cumsum(dim="b")])
    check_vectorization([abc], [abc.cumsum(dim=("c", "b"))])
    check_vectorization([abc], [abc.cumprod()])
...@@ -15,7 +15,8 @@ from xarray import full_like as xr_full_like ...@@ -15,7 +15,8 @@ from xarray import full_like as xr_full_like
from xarray import ones_like as xr_ones_like from xarray import ones_like as xr_ones_like
from xarray import zeros_like as xr_zeros_like from xarray import zeros_like as xr_zeros_like
from pytensor.tensor import scalar from pytensor.graph import vectorize_graph
from pytensor.tensor import scalar, vector
from pytensor.xtensor.shape import ( from pytensor.xtensor.shape import (
broadcast, broadcast,
concat, concat,
...@@ -25,8 +26,9 @@ from pytensor.xtensor.shape import ( ...@@ -25,8 +26,9 @@ from pytensor.xtensor.shape import (
unstack, unstack,
zeros_like, zeros_like,
) )
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import as_xtensor, xtensor
from tests.xtensor.util import ( from tests.xtensor.util import (
check_vectorization,
xr_arange_like, xr_arange_like,
xr_assert_allclose, xr_assert_allclose,
xr_function, xr_function,
...@@ -800,3 +802,90 @@ def test_zeros_like(): ...@@ -800,3 +802,90 @@ def test_zeros_like():
expected1 = xr_zeros_like(x_test) expected1 = xr_zeros_like(x_test)
xr_assert_allclose(result1, expected1) xr_assert_allclose(result1, expected1)
assert result1.dtype == expected1.dtype assert result1.dtype == expected1.dtype
def test_shape_ops_vectorize():
    """Shape Ops (transpose/stack/unstack/concat/squeeze/expand_dims/broadcast) vectorize."""
    a1 = xtensor("a1", dims=("a", "1"), shape=(2, 1), dtype="float64")
    ab = xtensor("ab", dims=("a", "b"), shape=(2, 3), dtype="float64")
    abc = xtensor("abc", dims=("a", "b", "c"), shape=(2, 3, 5), dtype="float64")
    a_bc_d = xtensor("a_bc_d", dims=("a", "bc", "d"), shape=(4, 15, 7))

    check_vectorization(abc, abc.transpose("b", "c", "a"))
    check_vectorization(abc, abc.transpose("b", ...))
    check_vectorization(abc, stack(abc, new_dim=("a", "c")))
    check_vectorization(a_bc_d, unstack(a_bc_d, bc=dict(b=3, c=5)))
    check_vectorization([abc, ab], concat([abc, ab], dim="a"))
    check_vectorization(a1, a1.squeeze("1"))
    check_vectorization(abc, abc.expand_dims(d=5))
    check_vectorization([ab, abc], broadcast(ab, abc))
    check_vectorization([ab, abc, a1], broadcast(ab, abc, a1, exclude="1"))
    # a is longer in a_bc_d than in ab and abc, helper can't handle that
    # check_vectorization([ab, abc, a_bc_d], broadcast(ab, abc, a_bc_d, exclude="a"))
def test_broadcast_exclude_vectorize():
    """Vectorized broadcast-with-exclude matches per-batch xarray broadcast."""
    x = xtensor("x", dims=("a", "b"), shape=(3, 4))
    y = xtensor("y", dims=("b", "c"), shape=(7, 5))
    out_x, out_y = broadcast(x, y, exclude=("b",))

    x_val = xr_random_like(x)
    y_val = xr_random_like(y)
    x_batch_val = x_val.expand_dims({"batch": 2})
    y_batch_val = y_val.expand_dims({"batch": 2})
    x_batch = as_xtensor(x_batch_val).type("x_batch")
    y_batch = as_xtensor(y_batch_val).type("y_batch")

    [out_x_vec, out_y_vec] = vectorize_graph([out_x, out_y], {x: x_batch, y: y_batch})
    fn = xr_function([x_batch, y_batch], [out_x_vec, out_y_vec])
    res_x, res_y = fn(x_batch_val, y_batch_val)

    # Reference: broadcast each batch entry separately, then re-concatenate.
    expected_x = []
    expected_y = []
    for i in range(2):
        ex_x, ex_y = xr_broadcast(
            x_batch_val.isel(batch=i), y_batch_val.isel(batch=i), exclude=("b",)
        )
        expected_x.append(ex_x)
        expected_y.append(ex_y)
    expected_x = xr_concat(expected_x, dim="batch")
    expected_y = xr_concat(expected_y, dim="batch")

    xr_assert_allclose(res_x, expected_x)
    xr_assert_allclose(res_y, expected_y)
def test_expand_dims_batch_length_vectorize():
    """Batching the length input of expand_dims is rejected."""
    x = xtensor("x", dims=("a",), shape=(3,))
    length = scalar("l", dtype="int64")
    out = x.expand_dims(b=length)

    batched_x = as_xtensor(xr_random_like(x).expand_dims(batch=2)).type("x_batch")
    batched_length = vector("l_batch", dtype="int64")
    replacements = {x: batched_x, length: batched_length}
    with pytest.raises(
        NotImplementedError, match=r"Vectorization of .* not implemented"
    ):
        vectorize_graph([out], replacements)
def test_unstack_batch_length_vectorize():
    """Batching an unstacked length input is rejected."""
    x = xtensor("x", dims=("ab",), shape=(12,))
    length = scalar("l", dtype="int64")
    out = unstack(x, ab={"a": length, "b": x.sizes["ab"] // length})

    batched_x = as_xtensor(xr_random_like(x).expand_dims(batch=2)).type("x_batch")
    batched_length = vector("l_batch", dtype="int64")
    replacements = {x: batched_x, length: batched_length}
    with pytest.raises(
        NotImplementedError, match=r"Vectorization of .* not implemented"
    ):
        vectorize_graph([out], replacements)
...@@ -12,7 +12,12 @@ from xarray import apply_ufunc ...@@ -12,7 +12,12 @@ from xarray import apply_ufunc
from pytensor.xtensor.signal import convolve1d from pytensor.xtensor.signal import convolve1d
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function from tests.xtensor.util import (
check_vectorization,
xr_arange_like,
xr_assert_allclose,
xr_function,
)
@pytest.mark.parametrize("mode", ("full", "valid", "same")) @pytest.mark.parametrize("mode", ("full", "valid", "same"))
...@@ -68,3 +73,14 @@ def test_convolve_1d_invalid(): ...@@ -68,3 +73,14 @@ def test_convolve_1d_invalid():
match=re.escape("Input 1 has invalid core dims ['time']. Allowed: ('kernel',)"), match=re.escape("Input 1 has invalid core dims ['time']. Allowed: ('kernel',)"),
): ):
convolve1d(in1, in2.rename({"batch": "time"}), dims=("time", "kernel")) convolve1d(in1, in2.rename({"batch": "time"}), dims=("time", "kernel"))
def test_signal_vectorize():
    # Note: We only need to test a couple Ops, since the vectorization logic is not Op specific
    ab = xtensor("a", dims=("a", "b"), shape=(3, 3))
    c = xtensor(name="c", dims=("c",), shape=(7,))

    check_vectorization(
        [ab, c],
        [convolve1d(ab, c, dims=("a", "c"))],
    )
import pytest import pytest
pytest.importorskip("xarray") xr = pytest.importorskip("xarray")
from itertools import chain
import numpy as np import numpy as np
from xarray import DataArray from xarray import DataArray
from xarray.testing import assert_allclose from xarray.testing import assert_allclose
from pytensor import function from pytensor import function
from pytensor.xtensor.type import XTensorType from pytensor.graph import vectorize_graph
from pytensor.xtensor.type import XTensorType, as_xtensor
def xr_function(*args, **kwargs): def xr_function(*args, **kwargs):
...@@ -76,3 +79,57 @@ def xr_random_like(x, rng=None): ...@@ -76,3 +79,57 @@ def xr_random_like(x, rng=None):
return DataArray( return DataArray(
rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims rng.standard_normal(size=x.type.shape, dtype=x.type.dtype), dims=x.type.dims
) )
def check_vectorization(inputs, outputs, input_vals=None, rng=None):
    """Check that vectorize_graph of ``outputs`` matches a per-batch reference.

    Each input gets its own distinctly-sized batch dim (``batch_i`` of size
    ``2**(i+1)``); the vectorized graph is compared against
    ``xarray.apply_ufunc(..., vectorize=True)`` looping the core function
    over the batch dims.

    Parameters
    ----------
    inputs, outputs
        XTensorVariables (single variables or sequences) defining the core graph.
    input_vals
        Optional DataArray test values; random values are drawn when omitted.
    rng
        Seed/generator for the random test values.
    """
    # Create core graph and function
    if not isinstance(inputs, list | tuple):
        inputs = (inputs,)
    if not isinstance(outputs, list | tuple):
        outputs = (outputs,)

    # apply_ufunc isn't happy with list output or single entry
    _core_fn = function(inputs, outputs)

    def core_fn(*args, _core_fn=_core_fn):
        # Unwrap single outputs; apply_ufunc expects a bare array in that case.
        res = _core_fn(*args)
        if len(res) == 1:
            return res[0]
        else:
            return tuple(res)

    if input_vals is None:
        rng = np.random.default_rng(rng)
        input_vals = [xr_random_like(inp, rng) for inp in inputs]

    # Create vectorized inputs
    batch_inputs = []
    batch_input_vals = []
    for i, (inp, val) in enumerate(zip(inputs, input_vals)):
        # Distinct batch sizes per input catch accidental dim mix-ups.
        new_val = val.expand_dims({f"batch_{i}": 2 ** (i + 1)})
        new_inp = as_xtensor(new_val).type(f"batch_{inp.name or f'input{i}'}")
        batch_inputs.append(new_inp)
        batch_input_vals.append(new_val)

    # Create vectorized function
    new_outputs = vectorize_graph(outputs, dict(zip(inputs, batch_inputs)))
    vec_fn = xr_function(batch_inputs, new_outputs)
    vec_res = vec_fn(*batch_input_vals)

    # xarray.apply_ufunc with vectorize=True loops over non-core dims
    input_core_dims = [i.dims for i in inputs]
    output_core_dims = [o.dims for o in outputs]
    expected_res = xr.apply_ufunc(
        core_fn,
        *batch_input_vals,
        input_core_dims=input_core_dims,
        output_core_dims=output_core_dims,
        # exclude_dims lets core dim sizes differ between input and output.
        exclude_dims=set(chain.from_iterable((*input_core_dims, *output_core_dims))),
        vectorize=True,
    )
    if not isinstance(expected_res, list | tuple):
        expected_res = (expected_res,)
    # NOTE(review): assumes vec_res is a sequence aligned with expected_res —
    # confirm xr_function returns a list for list outputs.
    for v_r, e_r in zip(vec_res, expected_res):
        # Transpose reference to the vectorized result's dim order before comparing.
        xr_assert_allclose(v_r, e_r.transpose(*v_r.dims))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论