提交 0933d203 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Normalize negative axes

上级 ee9a6ff2
......@@ -164,6 +164,18 @@ def create_vectorize_func(
return elemwise_fn
def normalize_axis(axis, ndim):
if axis is None:
return axis
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise np.AxisError(ndim=ndim, axis=axis)
return axis
def create_axis_reducer(
scalar_op: Op,
identity: Union[np.ndarray, Number],
......@@ -218,6 +230,8 @@ def create_axis_reducer(
"""
axis = normalize_axis(axis, ndim)
reduce_elemwise_fn_name = "careduce_axis"
identity = str(identity)
......@@ -340,6 +354,8 @@ def create_multiaxis_reducer(
if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
axes = [normalize_axis(axis, ndim) for axis in axes]
careduce_fn_name = f"careduce_{scalar_op}"
global_env = {}
to_reduce = reversed(sorted(axes))
......@@ -409,6 +425,8 @@ def jit_compile_reducer(node, fn, **kwds):
def create_axis_apply_fn(fn, axis, ndim, dtype):
axis = normalize_axis(axis, ndim)
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
@numba_basic.numba_njit(boundscheck=False)
......@@ -609,6 +627,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
axis = normalize_axis(axis, x_at.ndim)
if axis is not None:
reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
......@@ -646,6 +666,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
axis = op.axis
axis = normalize_axis(axis, sm_at.ndim)
if axis is not None:
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
......@@ -676,6 +697,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
axis = normalize_axis(axis, x_at.ndim)
if axis is not None:
reduce_max_py = create_axis_reducer(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论