提交 18449731 · 作者: ricardoV94 · 提交者: Ricardo Vieira

Fix discrete xtensor reduction

上级 4ce8ef3c
...@@ -9,7 +9,7 @@ from pytensor.tensor.math import variadic_mul ...@@ -9,7 +9,7 @@ from pytensor.tensor.math import variadic_mul
from pytensor.xtensor.basic import XOp from pytensor.xtensor.basic import XOp
from pytensor.xtensor.math import neq, sqrt from pytensor.xtensor.math import neq, sqrt
from pytensor.xtensor.math import sqr as square from pytensor.xtensor.math import sqr as square
from pytensor.xtensor.type import as_xtensor, xtensor from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
REDUCE_DIM = str | Sequence[str] | EllipsisType | None REDUCE_DIM = str | Sequence[str] | EllipsisType | None
...@@ -47,22 +47,24 @@ class XReduce(XOp): ...@@ -47,22 +47,24 @@ class XReduce(XOp):
return Apply(self, [x], [output]) return Apply(self, [x], [output])
def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]: def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str): if isinstance(dim, str):
return (dim,) return (dim,)
elif dim is None or dim is Ellipsis: elif dim is None or dim is Ellipsis:
x = as_xtensor(x)
return typing.cast(tuple[str], x.type.dims) return typing.cast(tuple[str], x.type.dims)
return dim return dim
def reduce(x, dim: REDUCE_DIM = None, *, binary_op, upcast_discrete_inp: bool = False):
    """Reduce ``x`` over the requested dims with ``binary_op``.

    Parameters
    ----------
    x
        Input; converted with ``as_xtensor`` before anything else.
    dim
        Dimension name, sequence of names, or ``None``/``...`` for all dims.
    binary_op
        Binary scalar operation handed to ``XReduce``.
    upcast_discrete_inp
        When true and the input's numpy dtype kind is one of ``"i"``, ``"b"``
        or ``"u"``, the input is cast before reducing: to ``"uint64"`` for
        kind ``"u"``, otherwise to ``"int64"``.

    Returns
    -------
    The output of applying ``XReduce(binary_op=binary_op, dims=dims)`` to ``x``.
    """
    x = as_xtensor(x)
    dims = _process_user_dims(x, dim)
    # Cast the input (not the output) so the Op itself sees the wider dtype.
    if upcast_discrete_inp and ((x_kind := x.type.numpy_dtype.kind) in "ibu"):
        x = x.astype("uint64" if x_kind == "u" else "int64")
    return XReduce(binary_op=binary_op, dims=dims)(x)
# Accumulating reductions request upcasting of discrete inputs in `reduce`;
# max/min do not pass the flag, so their input dtype is left as-is.
sum = partial(reduce, binary_op=ps.add, upcast_discrete_inp=True)
prod = partial(reduce, binary_op=ps.mul, upcast_discrete_inp=True)
max = partial(reduce, binary_op=ps.maximum)
min = partial(reduce, binary_op=ps.minimum)
...@@ -117,7 +119,10 @@ class XCumReduce(XOp): ...@@ -117,7 +119,10 @@ class XCumReduce(XOp):
def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
    """Cumulatively reduce ``x`` over the requested dims with ``binary_op``.

    Parameters
    ----------
    x
        Input; converted with ``as_xtensor`` before anything else.
    dim
        Dimension name, sequence of names, or ``None``/``...`` for all dims.
    binary_op
        Binary scalar operation handed to ``XCumReduce``.

    Returns
    -------
    The output of applying ``XCumReduce(dims=dims, binary_op=binary_op)`` to
    ``x``. Inputs whose numpy dtype kind is ``"i"``, ``"b"`` or ``"u"`` are
    always cast first: to ``"uint64"`` for kind ``"u"``, otherwise ``"int64"``.
    """
    x = as_xtensor(x)
    dims = _process_user_dims(x, dim)
    # Unlike `reduce`, the upcast here is unconditional for discrete inputs.
    if (x_kind := x.type.numpy_dtype.kind) in "ibu":
        x = x.astype("uint64" if x_kind == "u" else "int64")
    return XCumReduce(dims=dims, binary_op=binary_op)(x)
......
...@@ -4,6 +4,9 @@ import pytest ...@@ -4,6 +4,9 @@ import pytest
pytest.importorskip("xarray") pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error") pytestmark = pytest.mark.filterwarnings("error")
import numpy as np
import xarray as xr
from pytensor.xtensor.type import xtensor from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
...@@ -52,3 +55,47 @@ def test_std_var(method, dim): ...@@ -52,3 +55,47 @@ def test_std_var(method, dim):
results[1], results[1],
getattr(x_test, method)(dim=dim, ddof=2), getattr(x_test, method)(dim=dim, ddof=2),
) )
@pytest.mark.parametrize("signed", [True, False])
def test_discrete_reduction_upcasting(signed):
    # Test that sum, prod reductions on discrete inputs are upcast to prevent overflow
    # This is also a regression test for lower_xtensor, which would raise by returning a different dtype
    in_dtype = "int8" if signed else "uint8"
    out_dtype = "int64" if signed else "uint64"
    test_val = 127 if signed else 255  # max value allowed by in_dtype

    x = xtensor("x", dtype=in_dtype, dims=("a",), shape=(2,))
    x_val = xr.DataArray(np.array([test_val, test_val], dtype=in_dtype), dims="a")
    assert x_val.dtype == in_dtype

    # sum
    out = x.sum()
    assert out.dtype == out_dtype
    fn = xr_function([x], out)
    res = fn(x_val)
    assert res == test_val * 2
    xr_assert_allclose(res, x_val.sum())

    # prod
    out = x.prod()
    assert out.dtype == out_dtype
    fn = xr_function([x], out)
    res = fn(x_val)
    assert res == test_val**2
    xr_assert_allclose(res, x_val.prod())

    # cumsum
    out = x.cumsum()
    assert out.dtype == out_dtype
    fn = xr_function([x], out)
    res = fn(x_val)
    np.testing.assert_allclose(res, [test_val, test_val * 2])
    xr_assert_allclose(res, x_val.cumsum())

    # cumprod
    out = x.cumprod()
    assert out.dtype == out_dtype
    fn = xr_function([x], out)
    res = fn(x_val)
    np.testing.assert_allclose(res, [test_val, test_val**2])
    xr_assert_allclose(res, x_val.cumprod())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论