提交 18449731 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix discrete xtensor reduction

上级 4ce8ef3c
......@@ -9,7 +9,7 @@ from pytensor.tensor.math import variadic_mul
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.math import neq, sqrt
from pytensor.xtensor.math import sqr as square
from pytensor.xtensor.type import as_xtensor, xtensor
from pytensor.xtensor.type import XTensorVariable, as_xtensor, xtensor
REDUCE_DIM = str | Sequence[str] | EllipsisType | None
......@@ -47,22 +47,24 @@ class XReduce(XOp):
return Apply(self, [x], [output])
def _process_user_dims(x, dim: REDUCE_DIM) -> Sequence[str]:
def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str):
return (dim,)
elif dim is None or dim is Ellipsis:
x = as_xtensor(x)
return typing.cast(tuple[str], x.type.dims)
return dim
def reduce(x, dim: REDUCE_DIM = None, *, binary_op, upcast_discrete_inp: bool = False):
    """Reduce an xtensor over ``dim`` with the scalar ``binary_op``.

    Parameters
    ----------
    x
        Input, converted with :func:`as_xtensor`.
    dim
        Dim name(s) to reduce over; ``None``/``...`` reduce over all dims.
    binary_op
        Scalar binary Op (e.g. ``ps.add``) applied by :class:`XReduce`.
    upcast_discrete_inp
        If True, bool/int inputs are upcast to 64-bit ("uint64" for unsigned,
        "int64" otherwise) before reducing, to prevent overflow — mirroring
        NumPy/xarray behavior for sum/prod.
    """
    x = as_xtensor(x)
    dims = _process_user_dims(x, dim)
    if upcast_discrete_inp and ((x_kind := x.type.numpy_dtype.kind) in "ibu"):
        # dtype kinds: "i" signed int, "b" bool, "u" unsigned int
        x = x.astype("uint64" if x_kind == "u" else "int64")
    return XReduce(binary_op=binary_op, dims=dims)(x)
# sum/prod upcast discrete inputs to 64-bit to avoid overflow (NumPy/xarray
# semantics); max/min keep the input dtype, since they cannot overflow.
sum = partial(reduce, binary_op=ps.add, upcast_discrete_inp=True)
prod = partial(reduce, binary_op=ps.mul, upcast_discrete_inp=True)
max = partial(reduce, binary_op=ps.maximum)
min = partial(reduce, binary_op=ps.minimum)
......@@ -117,7 +119,10 @@ class XCumReduce(XOp):
def cumreduce(x, dim: REDUCE_DIM, *, binary_op, upcast_discrete_inp: bool = True):
    """Cumulative reduction of an xtensor along ``dim``.

    Parameters
    ----------
    x
        Input, converted with :func:`as_xtensor`.
    dim
        Dim name(s) to accumulate along; ``None``/``...`` select all dims.
    binary_op
        Scalar binary Op (e.g. ``ps.add``) applied by :class:`XCumReduce`.
    upcast_discrete_inp
        If True (default), bool/int inputs are upcast to 64-bit ("uint64"
        for unsigned, "int64" otherwise) before accumulating, to prevent
        overflow in cumsum/cumprod. Mirrors the flag of :func:`reduce`;
        callers wrapping overflow-free ops (e.g. cumulative max/min) can
        pass False to keep the input dtype.
    """
    x = as_xtensor(x)
    dims = _process_user_dims(x, dim)
    if upcast_discrete_inp and ((x_kind := x.type.numpy_dtype.kind) in "ibu"):
        # dtype kinds: "i" signed int, "b" bool, "u" unsigned int
        x = x.astype("uint64" if x_kind == "u" else "int64")
    return XCumReduce(dims=dims, binary_op=binary_op)(x)
......
......@@ -4,6 +4,9 @@ import pytest
pytest.importorskip("xarray")
pytestmark = pytest.mark.filterwarnings("error")
import numpy as np
import xarray as xr
from pytensor.xtensor.type import xtensor
from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
......@@ -52,3 +55,47 @@ def test_std_var(method, dim):
results[1],
getattr(x_test, method)(dim=dim, ddof=2),
)
@pytest.mark.parametrize("signed", [True, False])
def test_discrete_reduction_upcasting(signed):
    # Test that sum, prod, cumsum, cumprod on discrete inputs are upcast to
    # 64-bit to prevent overflow.
    # This is also a regression test for lower_xtensor, which would raise by
    # returning a different dtype.
    in_dtype = "int8" if signed else "uint8"
    out_dtype = "int64" if signed else "uint64"
    test_val = 127 if signed else 255  # max value allowed by in_dtype

    x = xtensor("x", dtype=in_dtype, dims=("a",), shape=(2,))
    x_val = xr.DataArray(np.array([test_val, test_val], dtype=in_dtype), dims="a")
    assert x_val.dtype == in_dtype

    # (method, expected result): reductions return a scalar that would
    # overflow in_dtype; cumulative reductions return an array whose last
    # entry would overflow.
    cases = [
        ("sum", test_val * 2),
        ("prod", test_val**2),
        ("cumsum", [test_val, test_val * 2]),
        ("cumprod", [test_val, test_val**2]),
    ]
    for method, expected in cases:
        out = getattr(x, method)()
        assert out.dtype == out_dtype, method
        fn = xr_function([x], out)
        res = fn(x_val)
        np.testing.assert_allclose(res, expected)
        # Result must also agree with xarray's own (upcasting) semantics
        xr_assert_allclose(res, getattr(x_val, method)())
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论