Commit 66760618, authored by kc611, committed by Brandon T. Willard

Updated careduce_axis implementation for Numba

Parent commit: 39235a34
import inspect import inspect
from functools import singledispatch
from numbers import Number from numbers import Number
from textwrap import indent from textwrap import indent
from typing import Union from typing import Union
...@@ -8,7 +9,7 @@ import numpy as np ...@@ -8,7 +9,7 @@ import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem from numba.cpython.unsafe.tuple import tuple_setitem
from aesara import config from aesara import config
from aesara.graph.basic import Apply from aesara.graph.op import Op
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import ( from aesara.link.numba.dispatch.basic import (
create_numba_signature, create_numba_signature,
...@@ -20,10 +21,22 @@ from aesara.link.utils import ( ...@@ -20,10 +21,22 @@ from aesara.link.utils import (
get_name_for_object, get_name_for_object,
unique_name_generator, unique_name_generator,
) )
from aesara.scalar.basic import (
AND,
OR,
XOR,
Add,
IntDiv,
Mul,
ScalarMaximum,
Sub,
TrueDiv,
)
from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.type import tensor
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs): def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
...@@ -107,8 +120,74 @@ def {inplace_elemwise_fn_name}({input_signature_str}): ...@@ -107,8 +120,74 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
return elemwise_fn return elemwise_fn
@singledispatch
def scalar_in_place_fn(op: "Op", idx: str, res: str, arr: str):
    """Return code for an in-place update on an array using a binary scalar :class:`Op`.

    This generic fallback emits a call to the `Op`'s NumPy counterpart
    (the function named by ``op.nfunc_spec[0]``).  Specific scalar ``Op``\\s
    register faster augmented-assignment forms below.

    Parameters
    ----------
    op
        The scalar :class:`Op`.
    idx
        The index of `res` that needs to be updated.
    res
        The symbol name for the first input and results/output.
    arr
        The symbol name for the second input.
    """
    # FIX: interpolate the `arr` symbol name (previously the literal string
    # "arr" was emitted, producing code that referenced a non-existent
    # variable), and qualify the NumPy function with `np.` so it resolves in
    # the generated function's global namespace (which maps "np" -> numpy).
    return f"{res}[{idx}] = np.{op.nfunc_spec[0]}({res}[{idx}], {arr})"
@scalar_in_place_fn.register(Add)
def scalar_in_place_fn_Add(op, idx, res, arr):
    """Generate an in-place addition update: ``res[idx] += arr``."""
    return "{}[{}] += {}".format(res, idx, arr)
@scalar_in_place_fn.register(Mul)
def scalar_in_place_fn_Mul(op, idx, res, arr):
    """Generate an in-place multiplication update: ``res[idx] *= arr``."""
    return "{}[{}] *= {}".format(res, idx, arr)
@scalar_in_place_fn.register(Sub)
def scalar_in_place_fn_Sub(op, idx, res, arr):
    """Generate an in-place subtraction update: ``res[idx] -= arr``."""
    return "{}[{}] -= {}".format(res, idx, arr)
@scalar_in_place_fn.register(AND)
def scalar_in_place_fn_AND(op, idx, res, arr):
    """Generate an in-place bitwise-AND update: ``res[idx] &= arr``."""
    return "{}[{}] &= {}".format(res, idx, arr)
@scalar_in_place_fn.register(OR)
def scalar_in_place_fn_OR(op, idx, res, arr):
    """Generate an in-place bitwise-OR update: ``res[idx] |= arr``."""
    return "{}[{}] |= {}".format(res, idx, arr)
@scalar_in_place_fn.register(XOR)
def scalar_in_place_fn_XOR(op, idx, res, arr):
    """Generate an in-place bitwise-XOR update: ``res[idx] ^= arr``."""
    return "{}[{}] ^= {}".format(res, idx, arr)
@scalar_in_place_fn.register(TrueDiv)
def scalar_in_place_fn_TrueDiv(op, idx, res, arr):
    """Generate an in-place true-division update: ``res[idx] /= arr``."""
    return "{}[{}] /= {}".format(res, idx, arr)
@scalar_in_place_fn.register(IntDiv)
def scalar_in_place_fn_IntDiv(op, idx, res, arr):
    """Generate an in-place floor-division update: ``res[idx] //= arr``."""
    return "{}[{}] //= {}".format(res, idx, arr)
@scalar_in_place_fn.register(ScalarMaximum)
def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr):
    """Generate an in-place maximum update: keep the larger of ``res[idx]`` and ``arr``."""
    target = f"{res}[{idx}]"
    # Assemble the exact multi-line snippet: leading and trailing newlines
    # are significant, since the caller indents this block into generated code.
    return "\n".join(
        [
            "",
            f"if {target} < {arr}:",
            f"    {target} = {arr}",
            "",
        ]
    )
def create_axis_reducer( def create_axis_reducer(
reduce_fn: numba.np.ufunc.dufunc.DUFunc, scalar_op: Op,
identity: Union[np.ndarray, Number], identity: Union[np.ndarray, Number],
axis: int, axis: int,
ndim: int, ndim: int,
...@@ -141,9 +220,8 @@ def create_axis_reducer( ...@@ -141,9 +220,8 @@ def create_axis_reducer(
Parameters Parameters
========== ==========
reduce_fn: scalar_op:
The Numba ``ufunc`` representing a binary op that can perform the The scalar :class:`Op` that performs the desired reduction.
reduction on arbitrary ``ndarray``\s.
identity: identity:
The identity value for the reduction. The identity value for the reduction.
axis: axis:
...@@ -155,64 +233,108 @@ def create_axis_reducer( ...@@ -155,64 +233,108 @@ def create_axis_reducer(
keepdims: keepdims:
Determines whether or not the reduced dimension is retained. Determines whether or not the reduced dimension is retained.
""" """
if ndim > 1:
reduce_elemwise_fn_name = "careduce_axis"
if ndim > 1:
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
if keepdims: if keepdims:
set_out_dims = numba_basic.numba_njit(
lambda x: np.expand_dims(x, axis), inline="always"
)
else:
set_out_dims = numba_basic.numba_njit(lambda x: x, inline="always")
@numba_basic.numba_njit(inline="always") else:
def set_out_dims(x):
return np.expand_dims(x, axis) @numba_basic.numba_njit
def res_shape_tuple_ctor(args):
return 1
if keepdims:
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x, dtype), inline="always"
)
else: else:
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x[0], dtype), inline="always"
)
@numba_basic.numba_njit(inline="always") identity = str(identity)
def set_out_dims(x): if identity == "inf":
return x identity = "np.inf"
elif identity == "-inf":
identity = "-np.inf"
res_shape_tuple_ctor = create_tuple_creator( if ndim > 1:
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1 res_indices = []
arr_indices = []
count = 0
for i in range(ndim):
if i == axis:
arr_indices.append("i")
else:
res_indices.append(f"idx_arr[{count}]")
arr_indices.append(f"idx_arr[{count}]")
count = count + 1
res_indices = ", ".join(res_indices)
arr_indices = ", ".join(arr_indices)
inplace_update_statement = scalar_in_place_fn(
scalar_op, res_indices, "res", f"x[{arr_indices}]"
) )
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis) reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
@numba_basic.numba_njit(boundscheck=False) res_shape = res_shape_tuple_ctor(x.shape)
def careduce_axis(x): res = np.full(res_shape, numba_basic.to_scalar({identity}))
res_shape = res_shape_tuple_ctor(x.shape)
x_axis_first = x.transpose(reaxis_first)
res = np.full(res_shape, numba_basic.to_scalar(identity), dtype=dtype) axis_shape = x.shape[{axis}]
for m in numba.prange(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
return set_out_dims(res) for idx_arr in np.ndindex(res_shape):
for i in range(axis_shape):
{inplace_update_statement}
return set_out_dims(res)
"""
else: else:
inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
if keepdims: reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
@numba_basic.numba_njit(inline="always") res_shape = res_shape_tuple_ctor(x.shape)
def set_out_dims(x): res = np.full(res_shape, numba_basic.to_scalar({identity}))
return np.array([x], dtype)
else: axis_shape = x.shape[{axis}]
@numba_basic.numba_njit(inline="always") for i in range(axis_shape):
def set_out_dims(x): {inplace_update_statement}
return numba_basic.direct_cast(x, dtype)
@numba_basic.numba_njit(boundscheck=False) return set_out_dims(res)
def careduce_axis(x): """
res = numba_basic.to_scalar(identity)
x_ravel = x.ravel()
for i in numba.prange(x_ravel.size):
res = reduce_fn(res, x_ravel[i])
return set_out_dims(res)
return careduce_axis global_env = {
"np": np,
"res_shape_tuple_ctor": res_shape_tuple_ctor,
"numba_basic": numba_basic,
"set_out_dims": set_out_dims,
}
reduce_elemwise_fn_py = compile_function_src(
reduce_elemwise_def_src, reduce_elemwise_fn_name, global_env
)
return numba_basic.numba_njit(boundscheck=False)(reduce_elemwise_fn_py)
def create_multiaxis_reducer( def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input" scalar_op, identity, axes, ndim, dtype, input_name="input"
): ):
r"""Construct a function that reduces multiple axes. r"""Construct a function that reduces multiple axes.
...@@ -233,9 +355,8 @@ def create_multiaxis_reducer( ...@@ -233,9 +355,8 @@ def create_multiaxis_reducer(
Parameters Parameters
========== ==========
reduce_fn: scalar_op:
The Numba ``ufunc`` representing a binary op that can perform the The scalar :class:`Op` that performs the desired reduction.
reduction on arbitrary ``ndarray``\s.
identity: identity:
The identity value for the reduction. The identity value for the reduction.
axes: axes:
...@@ -247,9 +368,9 @@ def create_multiaxis_reducer( ...@@ -247,9 +368,9 @@ def create_multiaxis_reducer(
""" """
if len(axes) == 1: if len(axes) == 1:
return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype) return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}" careduce_fn_name = f"careduce_{scalar_op}"
global_env = {} global_env = {}
to_reduce = reversed(sorted(axes)) to_reduce = reversed(sorted(axes))
careduce_lines_src = [] careduce_lines_src = []
...@@ -258,7 +379,7 @@ def create_multiaxis_reducer( ...@@ -258,7 +379,7 @@ def create_multiaxis_reducer(
for i, axis in enumerate(to_reduce): for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}" careducer_axes_fn_name = f"careduce_axes_fn_{i}"
global_env[careducer_axes_fn_name] = create_axis_reducer( global_env[careducer_axes_fn_name] = create_axis_reducer(
reduce_fn, identity, axis - i, ndim, dtype scalar_op, identity, axis, ndim, dtype
) )
ndim -= 1 ndim -= 1
last_var_name = var_name last_var_name = var_name
...@@ -311,24 +432,15 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -311,24 +432,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype) scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
scalar_nfunc_spec = op.scalar_op.nfunc_spec
# We construct a dummy `Apply` that has the minimum required number of
# inputs for the scalar `Op`. Without this, we would get a scalar function
# with too few arguments.
dummy_node = Apply(
op,
[tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
# TODO: Use `scalar_op_identity`?
elemwise_fn = create_vectorize_func(op, dummy_node, use_signature=True, **kwargs)
input_name = get_name_for_object(node.inputs[0]) input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer( careduce_fn = create_multiaxis_reducer(
elemwise_fn, scalar_op_identity, axes, ndim, np_acc_dtype, input_name=input_name op.scalar_op,
scalar_op_identity,
axes,
ndim,
np_acc_dtype,
input_name=input_name,
) )
return careduce_fn return careduce_fn
...@@ -422,10 +534,10 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -422,10 +534,10 @@ def numba_funcify_Softmax(op, node, **kwargs):
if axis is not None: if axis is not None:
reduce_max = create_axis_reducer( reduce_max = create_axis_reducer(
np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
reduce_sum = create_axis_reducer( reduce_sum = create_axis_reducer(
np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
) )
else: else:
reduce_max = np.max reduce_max = np.max
...@@ -452,7 +564,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): ...@@ -452,7 +564,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
axis = op.axis axis = op.axis
if axis is not None: if axis is not None:
reduce_sum = create_axis_reducer( reduce_sum = create_axis_reducer(
np.add, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
) )
else: else:
reduce_sum = np.sum reduce_sum = np.sum
...@@ -477,10 +589,10 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -477,10 +589,10 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None: if axis is not None:
reduce_max = create_axis_reducer( reduce_max = create_axis_reducer(
np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
reduce_sum = create_axis_reducer( reduce_sum = create_axis_reducer(
np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
) )
else: else:
reduce_max = np.max reduce_max = np.max
...@@ -518,7 +630,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): ...@@ -518,7 +630,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
keep_axes = tuple(i for i in range(x_ndim) if i not in axes) keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = create_multiaxis_reducer( reduce_max = create_multiaxis_reducer(
np.maximum, -np.inf, axes, x_ndim, x_dtype scalar_maximum, -np.inf, axes, x_ndim, x_dtype
) )
reduced_x_ndim = x_ndim - len(axes) + 1 reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn( argmax_axis = create_axis_apply_fn(
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment