提交 010e0f97 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Implement reduction operations for XTensorVariables

上级 cdb026f7
......@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)
class CumsumOp(Op):
__props__ = ("axis",)
def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "add"
return obj
class CumprodOp(Op):
__props__ = ("axis",)
def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "mul"
return obj
def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.
......
import warnings
import pytensor.xtensor.rewriting
from pytensor.xtensor import (
linalg,
)
from pytensor.xtensor import linalg
from pytensor.xtensor.type import (
XTensorType,
as_xtensor,
......
......@@ -134,3 +134,8 @@ def cast(x, dtype):
if dtype not in _xelemwise_cast_op:
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
return _xelemwise_cast_op[dtype](x)
def softmax(x, dim=None):
exp_x = exp(x)
return exp_x / exp_x.sum(dim=dim)
import typing
from collections.abc import Sequence
from functools import partial
from types import EllipsisType
import pytensor.scalar as ps
from pytensor.graph.basic import Apply
from pytensor.tensor.math import variadic_mul
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.math import neq, sqrt
from pytensor.xtensor.math import sqr as square
from pytensor.xtensor.type import as_xtensor, xtensor
REDUCE_DIM = str | Sequence[str] | EllipsisType | None
class XReduce(XOp):
__slots__ = ("binary_op", "dims")
def __init__(self, binary_op, dims: Sequence[str]):
super().__init__()
self.binary_op = binary_op
# Order of reduce dims doesn't change the behavior of the Op
self.dims = tuple(sorted(dims))
def make_node(self, x):
x = as_xtensor(x)
x_dims = x.type.dims
x_dims_set = set(x_dims)
reduce_dims_set = set(self.dims)
if x_dims_set == reduce_dims_set:
out_dims, out_shape = [], []
else:
if not reduce_dims_set.issubset(x_dims_set):
raise ValueError(
f"Reduced dims {self.dims} not found in array dimensions {x_dims}."
)
out_dims, out_shape = zip(
*[
(d, s)
for d, s in zip(x_dims, x.type.shape)
if d not in reduce_dims_set
]
)
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output])
def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str):
return (dim,)
elif dim is None or dim is Ellipsis:
x = as_xtensor(x)
return typing.cast(tuple[str], x.type.dims)
return dim
def reduce(x, dim: REDUCE_DIM = None, *, binary_op):
dims = _process_user_dims(x, dim)
return XReduce(binary_op=binary_op, dims=dims)(x)
sum = partial(reduce, binary_op=ps.add)
prod = partial(reduce, binary_op=ps.mul)
max = partial(reduce, binary_op=ps.scalar_maximum)
min = partial(reduce, binary_op=ps.scalar_minimum)
def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
x = as_xtensor(x)
if x.type.dtype != "bool":
x = neq(x, 0)
return reduce(x, dim=dim, binary_op=binary_op)
all = partial(bool_reduce, binary_op=ps.and_)
any = partial(bool_reduce, binary_op=ps.or_)
def _infer_reduced_size(original_var, reduced_var):
reduced_dims = reduced_var.dims
return variadic_mul(
*[size for dim, size in original_var.sizes if dim not in reduced_dims]
)
def mean(x, dim: REDUCE_DIM):
x = as_xtensor(x)
sum_x = sum(x, dim)
n = _infer_reduced_size(x, sum_x)
return sum_x / n
def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
x = as_xtensor(x)
x_mean = mean(x, dim)
n = _infer_reduced_size(x, x_mean)
return square(x - x_mean) / (n - ddof)
def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
return sqrt(var(x, dim, ddof=ddof))
class XCumReduce(XOp):
__props__ = ("binary_op", "dims")
def __init__(self, binary_op, dims: Sequence[str]):
self.binary_op = binary_op
self.dims = tuple(sorted(dims)) # Order doesn't matter
def make_node(self, x):
x = as_xtensor(x)
out = x.type()
return Apply(self, [x], [out])
def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
dims = _process_user_dims(x, dim)
return XCumReduce(dims=dims, binary_op=binary_op)(x)
cumsum = partial(cumreduce, binary_op=ps.add)
cumprod = partial(cumreduce, binary_op=ps.mul)
import pytensor.xtensor.rewriting.basic
import pytensor.xtensor.rewriting.reduction
import pytensor.xtensor.rewriting.shape
import pytensor.xtensor.rewriting.vectorization
from functools import partial
import pytensor.scalar as ps
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.extra_ops import CumOp
from pytensor.tensor.math import All, Any, CAReduce, Max, Min, Prod, Sum
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.reduction import XCumReduce, XReduce
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
@register_lower_xtensor
@node_rewriter(tracks=[XReduce])
def lower_reduce(fgraph, node):
[x] = node.inputs
[out] = node.outputs
x_dims = x.type.dims
reduce_dims = node.op.dims
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
if not reduce_axis:
return [x]
match node.op.binary_op:
case ps.add:
tensor_op_class = Sum
case ps.mul:
tensor_op_class = Prod
case ps.and_:
tensor_op_class = All
case ps.or_:
tensor_op_class = Any
case ps.scalar_maximum:
tensor_op_class = Max
case ps.scalar_minimum:
tensor_op_class = Min
case _:
# Case without known/predefined Ops
tensor_op_class = partial(CAReduce, scalar_op=node.op.binary_op)
x_tensor = tensor_from_xtensor(x)
out_tensor = tensor_op_class(axis=reduce_axis)(x_tensor)
new_out = xtensor_from_tensor(out_tensor, out.type.dims)
return [new_out]
@register_lower_xtensor
@node_rewriter(tracks=[XCumReduce])
def lower_cumreduce(fgraph, node):
[x] = node.inputs
x_dims = x.type.dims
reduce_dims = node.op.dims
reduce_axis = [x_dims.index(dim) for dim in reduce_dims]
if not reduce_axis:
return [x]
match node.op.binary_op:
case ps.add:
tensor_op_class = partial(CumOp, mode="add")
case ps.mul:
tensor_op_class = partial(CumOp, mode="mul")
case _:
# We don't know how to convert an arbitrary binary cum/reduce Op
return None
# Each dim corresponds to an application of Cumsum/Cumprod
out_tensor = tensor_from_xtensor(x)
for axis in reduce_axis:
out_tensor = tensor_op_class(axis=axis)(out_tensor)
out = xtensor_from_tensor(out_tensor, x.type.dims)
return [out]
......@@ -438,6 +438,41 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def real(self):
return px.math.real(self)
# Aggregation
# https://docs.xarray.dev/en/latest/api.html#id6
def all(self, dim=None):
return px.reduction.all(self, dim)
def any(self, dim=None):
return px.reduction.any(self, dim)
def max(self, dim=None):
return px.reduction.max(self, dim)
def min(self, dim=None):
return px.reduction.min(self, dim)
def mean(self, dim=None):
return px.reduction.mean(self, dim)
def prod(self, dim=None):
return px.reduction.prod(self, dim)
def sum(self, dim=None):
return px.reduction.sum(self, dim)
def std(self, dim=None):
return px.reduction.std(self, dim)
def var(self, dim=None):
return px.reduction.var(self, dim)
def cumsum(self, dim=None):
return px.reduction.cumsum(self, dim)
def cumprod(self, dim=None):
return px.reduction.cumprod(self, dim)
# Reshaping and reorganizing
# https://docs.xarray.dev/en/latest/api.html#id8
def stack(self, dim, **dims):
......
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
@pytest.mark.parametrize(
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
)
@pytest.mark.parametrize(
"method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:]
)
def test_reduction(method, dim):
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
out = getattr(x, method)(dim=dim)
fn = xr_function([x], out)
x_test = xr_arange_like(x)
xr_assert_allclose(
fn(x_test),
getattr(x_test, method)(dim=dim),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论