提交 5961b23f authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Use numpy function to normalize axis argument

上级 670b5caa
...@@ -6,6 +6,7 @@ from typing import Any, Callable, Optional, Union ...@@ -6,6 +6,7 @@ from typing import Any, Callable, Optional, Union
import numba import numba
import numpy as np import numpy as np
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
from pytensor import config from pytensor import config
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
...@@ -164,18 +165,6 @@ def create_vectorize_func( ...@@ -164,18 +165,6 @@ def create_vectorize_func(
return elemwise_fn return elemwise_fn
def normalize_axis(axis, ndim):
if axis is None:
return axis
if axis < 0:
axis = ndim + axis
if axis < 0 or axis >= ndim:
raise np.AxisError(ndim=ndim, axis=axis)
return axis
def create_axis_reducer( def create_axis_reducer(
scalar_op: Op, scalar_op: Op,
identity: Union[np.ndarray, Number], identity: Union[np.ndarray, Number],
...@@ -230,7 +219,7 @@ def create_axis_reducer( ...@@ -230,7 +219,7 @@ def create_axis_reducer(
""" """
axis = normalize_axis(axis, ndim) axis = normalize_axis_index(axis, ndim)
reduce_elemwise_fn_name = "careduce_axis" reduce_elemwise_fn_name = "careduce_axis"
...@@ -354,7 +343,7 @@ def create_multiaxis_reducer( ...@@ -354,7 +343,7 @@ def create_multiaxis_reducer(
if len(axes) == 1: if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
axes = [normalize_axis(axis, ndim) for axis in axes] axes = normalize_axis_tuple(axes, ndim)
careduce_fn_name = f"careduce_{scalar_op}" careduce_fn_name = f"careduce_{scalar_op}"
global_env = {} global_env = {}
...@@ -425,7 +414,7 @@ def jit_compile_reducer(node, fn, **kwds): ...@@ -425,7 +414,7 @@ def jit_compile_reducer(node, fn, **kwds):
def create_axis_apply_fn(fn, axis, ndim, dtype): def create_axis_apply_fn(fn, axis, ndim, dtype):
axis = normalize_axis(axis, ndim) axis = normalize_axis_index(axis, ndim)
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,) reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
...@@ -627,9 +616,8 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -627,9 +616,8 @@ def numba_funcify_Softmax(op, node, **kwargs):
x_dtype = numba.np.numpy_support.from_dtype(x_dtype) x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis axis = op.axis
axis = normalize_axis(axis, x_at.ndim)
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer( reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
...@@ -666,8 +654,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): ...@@ -666,8 +654,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype) sm_dtype = numba.np.numpy_support.from_dtype(sm_dtype)
axis = op.axis axis = op.axis
axis = normalize_axis(axis, sm_at.ndim)
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, sm_at.ndim)
reduce_sum_py = create_axis_reducer( reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
) )
...@@ -697,9 +685,9 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -697,9 +685,9 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
x_dtype = x_at.type.numpy_dtype x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype) x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
axis = op.axis axis = op.axis
axis = normalize_axis(axis, x_at.ndim)
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer( reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论