提交 e92e5e30 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Update docstrings for Numba reducer functions

上级 d6fecc46
...@@ -505,13 +505,35 @@ def create_axis_reducer( ...@@ -505,13 +505,35 @@ def create_axis_reducer(
dtype: numba.types.Type, dtype: numba.types.Type,
keepdims: bool = False, keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher: ) -> numba.core.dispatcher.Dispatcher:
"""Create a Numba JITed function that performs a NumPy reduction on a given axis. r"""Create a Numba JITed function that performs a NumPy reduction on a given axis.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_axis(x):
res_shape = tuple(shape[i] if i < axis else shape[i + 1] for i in range(ndim - 1))
res = np.full(res_shape, identity, dtype=dtype)
x_axis_first = x.transpose(reaxis_first)
for m in range(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
if keepdims:
return np.expand_dims(res, axis)
else:
return res
This can be removed/replaced when
https://github.com/numba/numba/issues/4504 is implemented.
Parameters Parameters
========== ==========
reduce_fn: reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``s. reduction on arbitrary ``ndarray``\s.
identity: identity:
The identity value for the reduction. The identity value for the reduction.
axis: axis:
...@@ -581,6 +603,38 @@ def create_axis_reducer( ...@@ -581,6 +603,38 @@ def create_axis_reducer(
def create_multiaxis_reducer( def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input" reduce_fn, identity, axes, ndim, dtype, input_name="input"
): ):
r"""Construct a function that reduces multiple axes.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_maximum(input):
axis_0_res = careduce_axes_fn_0(input)
axis_1_res = careduce_axes_fn_1(axis_0_res)
...
axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
return axis_N_res
The range 0-N is determined by the `axes` argument (i.e. the
axes to be reduced).
Parameters
==========
reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``\s.
identity:
The identity value for the reduction.
axes:
The axes to reduce.
ndim:
The number of dimensions of the result.
dtype:
The data type of the result.
"""
if len(axes) == 1: if len(axes) == 1:
return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype) return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论