提交 fdb40877 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix mean, var and std of XTensorVariables

上级 b4522d23
......@@ -81,7 +81,7 @@ any = partial(bool_reduce, binary_op=ps.or_)
def _infer_reduced_size(original_var, reduced_var):
reduced_dims = reduced_var.dims
return variadic_mul(
*[size for dim, size in original_var.sizes if dim not in reduced_dims]
*[size for dim, size in original_var.sizes.items() if dim not in reduced_dims]
)
......@@ -96,7 +96,7 @@ def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
x = as_xtensor(x)
x_mean = mean(x, dim)
n = _infer_reduced_size(x, x_mean)
return square(x - x_mean) / (n - ddof)
return square(x - x_mean).sum(dim) / (n - ddof)
def std(x, dim: REDUCE_DIM, *, ddof: int = 0):
......
......@@ -692,11 +692,11 @@ class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]):
def sum(self, dim=None):
return px.reduction.sum(self, dim)
def std(self, dim=None):
return px.reduction.std(self, dim)
def std(self, dim=None, ddof=0):
return px.reduction.std(self, dim, ddof=ddof)
def var(self, dim=None):
return px.reduction.var(self, dim)
def var(self, dim=None, ddof=0):
return px.reduction.var(self, dim, ddof=ddof)
def cumsum(self, dim=None):
return px.reduction.cumsum(self, dim)
......
......@@ -12,7 +12,8 @@ from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
)
@pytest.mark.parametrize(
"method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:]
"method",
["sum", "prod", "all", "any", "max", "min", "mean", "cumsum", "cumprod"],
)
def test_reduction(method, dim):
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
......@@ -25,3 +26,29 @@ def test_reduction(method, dim):
fn(x_test),
getattr(x_test, method)(dim=dim),
)
@pytest.mark.parametrize(
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
)
@pytest.mark.parametrize("method", ["std", "var"])
def test_std_var(method, dim):
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
out = [
getattr(x, method)(dim=dim),
getattr(x, method)(dim=dim, ddof=2),
]
fn = xr_function([x], out)
x_test = xr_arange_like(x)
results = fn(x_test)
xr_assert_allclose(
results[0],
getattr(x_test, method)(dim=dim),
)
xr_assert_allclose(
results[1],
getattr(x_test, method)(dim=dim, ddof=2),
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论