提交 f277af71 authored 作者: Dhruvanshu-Joshi's avatar Dhruvanshu-Joshi 提交者: Ricardo Vieira

Implement median helper

上级 ed6ca162
......@@ -1566,6 +1566,48 @@ def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
return ret
def median(x: TensorLike, axis=None) -> TensorVariable:
"""
Computes the median along the given axis(es) of a tensor `input`.
Parameters
----------
x: TensorVariable
The input tensor.
axis: None or int or (list of int) (see `Sum`)
Compute the median along this axis of the tensor.
None means all axes (like numpy).
"""
from pytensor.ifelse import ifelse
x = as_tensor_variable(x)
x_ndim = x.type.ndim
if axis is None:
axis = list(range(x_ndim))
else:
axis = list(normalize_axis_tuple(axis, x_ndim))
non_axis = [i for i in range(x_ndim) if i not in axis]
non_axis_shape = [x.shape[i] for i in non_axis]
# Put axis at the end and unravel them
x_raveled = x.transpose(*non_axis, *axis)
if len(axis) > 1:
x_raveled = x_raveled.reshape((*non_axis_shape, -1))
raveled_size = x_raveled.shape[-1]
k = raveled_size // 2
# Sort the input tensor along the specified axis and pick median value
x_sorted = x_raveled.sort(axis=-1)
k_values = x_sorted[..., k]
km1_values = x_sorted[..., k - 1]
even_median = (k_values + km1_values) / 2.0
odd_median = k_values.astype(even_median.type.dtype)
even_k = eq(mod(raveled_size, 2), 0)
return ifelse(even_k, even_median, odd_median, name="median")
@scalar_elemwise(symbolname="scalar_maximum")
def maximum(x, y):
"""elemwise maximum. See max for the maximum in one tensor"""
......@@ -3015,6 +3057,7 @@ __all__ = [
"sum",
"prod",
"mean",
"median",
"var",
"std",
"std",
......
......@@ -93,6 +93,7 @@ from pytensor.tensor.math import (
max_and_argmax,
maximum,
mean,
median,
min,
minimum,
mod,
......@@ -3735,3 +3736,33 @@ def test_nan_to_num(nan, posinf, neginf):
out,
np.nan_to_num(y, nan=nan, posinf=posinf, neginf=neginf),
)
@pytest.mark.parametrize(
"ndim, axis",
[
(2, None),
(2, 1),
(2, (0, 1)),
(3, None),
(3, (1, 2)),
(4, (1, 3, 0)),
],
)
def test_median(ndim, axis):
# Generate random data with both odd and even lengths
shape_even = np.arange(1, ndim + 1) * 2
shape_odd = shape_even - 1
data_even = np.random.rand(*shape_even)
data_odd = np.random.rand(*shape_odd)
x = tensor(dtype="float64", shape=(None,) * ndim)
f = function([x], median(x, axis=axis))
result_odd = f(data_odd)
result_even = f(data_even)
expected_odd = np.median(data_odd, axis=axis)
expected_even = np.median(data_even, axis=axis)
assert np.allclose(result_odd, expected_odd)
assert np.allclose(result_even, expected_even)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论