提交 5961b23f authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use numpy function to normalize axis argument

上级 670b5caa
......@@ -6,6 +6,7 @@ from typing import Any, Callable, Optional, Union
import numba
import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
from pytensor import config
from pytensor.graph.basic import Apply
......@@ -164,18 +165,6 @@ def create_vectorize_func(
return elemwise_fn
def normalize_axis(axis, ndim):
if axis is None:
return axis
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise np.AxisError(ndim=ndim, axis=axis)
return axis
def create_axis_reducer(
scalar_op: Op,
identity: Union[np.ndarray, Number],
......@@ -230,7 +219,7 @@ def create_axis_reducer(
"""
axis = normalize_axis(axis, ndim)
axis = normalize_axis_index(axis, ndim)
reduce_elemwise_fn_name = "careduce_axis"
......@@ -354,7 +343,7 @@ def create_multiaxis_reducer(
if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
axes = [normalize_axis(axis, ndim) for axis in axes]
axes = normalize_axis_tuple(axes, ndim)
careduce_fn_name = f"careduce_{scalar_op}"
global_env = {}
......@@ -425,7 +414,7 @@ def jit_compile_reducer(node, fn, **kwds):
def create_axis_apply_fn(fn, axis, ndim, dtype):
axis = normalize_axis(axis, ndim)
axis = normalize_axis_index(axis, ndim)
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
......@@ -627,9 +616,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
axis = normalize_axis(axis, x_at.ndim)
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
......@@ -666,8 +654,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
axis = op.axis
axis = normalize_axis(axis, sm_at.ndim)
if axis is not None:
axis = normalize_axis_index(axis, sm_at.ndim)
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
)
......@@ -697,9 +685,9 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis
axis = normalize_axis(axis, x_at.ndim)
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论