提交 66760618 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Updated careduce_axis implementation for Numba

上级 39235a34
import inspect
from functools import singledispatch
from numbers import Number
from textwrap import indent
from typing import Union
......@@ -8,7 +9,7 @@ import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem
from aesara import config
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
create_numba_signature,
......@@ -20,10 +21,22 @@ from aesara.link.utils import (
get_name_for_object,
unique_name_generator,
)
from aesara.scalar.basic import (
AND,
OR,
XOR,
Add,
IntDiv,
Mul,
ScalarMaximum,
Sub,
TrueDiv,
)
from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
from aesara.tensor.type import tensor
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
......@@ -107,8 +120,74 @@ def {inplace_elemwise_fn_name}({input_signature_str}):
return elemwise_fn
@singledispatch
def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
    """Return code for an in-place update on an array using a binary scalar :class:`Op`.

    This is the generic fallback of a ``singledispatch`` family: it is used
    for any scalar :class:`Op` type that has no dedicated registration.  The
    update is expressed as a call to whatever ``op.nfunc_spec[0]`` names.

    Parameters
    ----------
    op
        The scalar :class:`Op`.  ``op.nfunc_spec[0]`` is emitted verbatim as
        the callee of the generated statement.
        NOTE(review): the emitted callee must be resolvable in the namespace
        in which the generated source is compiled -- confirm at call sites.
    idx
        The index of `res` that needs to be updated.
    res
        The symbol name for the first input and results/output.
    arr
        The symbol name for the second input.

    Returns
    -------
    str
        A single assignment statement of the form
        ``res[idx] = fn(res[idx], arr)`` with every symbol substituted.

    """
    # `arr` has to be interpolated like the other symbols; emitting the
    # literal text "arr" would make the generated code reference a name that
    # does not necessarily exist where it is executed.
    return f"{res}[{idx}] = {op.nfunc_spec[0]}({res}[{idx}], {arr})"
@scalar_in_place_fn.register(Add)
def scalar_in_place_fn_Add(op, idx, res, arr):
    """Return an augmented ``+=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    target = f"{res}[{idx}]"
    return f"{target} += {arr}"
@scalar_in_place_fn.register(Mul)
def scalar_in_place_fn_Mul(op, idx, res, arr):
    """Return an augmented ``*=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    target = f"{res}[{idx}]"
    return f"{target} *= {arr}"
@scalar_in_place_fn.register(Sub)
def scalar_in_place_fn_Sub(op, idx, res, arr):
    """Return an augmented ``-=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    target = f"{res}[{idx}]"
    return f"{target} -= {arr}"
@scalar_in_place_fn.register(AND)
def scalar_in_place_fn_AND(op, idx, res, arr):
    """Return an augmented ``&=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    return "{}[{}] &= {}".format(res, idx, arr)
@scalar_in_place_fn.register(OR)
def scalar_in_place_fn_OR(op, idx, res, arr):
    """Return an augmented ``|=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    return "{}[{}] |= {}".format(res, idx, arr)
@scalar_in_place_fn.register(XOR)
def scalar_in_place_fn_XOR(op, idx, res, arr):
    """Return an augmented ``^=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    return "{}[{}] ^= {}".format(res, idx, arr)
@scalar_in_place_fn.register(TrueDiv)
def scalar_in_place_fn_TrueDiv(op, idx, res, arr):
    """Return an augmented ``/=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    target = f"{res}[{idx}]"
    return f"{target} /= {arr}"
@scalar_in_place_fn.register(IntDiv)
def scalar_in_place_fn_IntDiv(op, idx, res, arr):
    """Return an augmented ``//=`` statement updating ``res[idx]`` with `arr`.

    `op` is only used for dispatch; `idx`, `res` and `arr` are substituted
    into the generated source text as-is.
    """
    target = f"{res}[{idx}]"
    return f"{target} //= {arr}"
@scalar_in_place_fn.register(ScalarMaximum)
def scalar_in_place_fn_ScalarMaximum(op, idx, res, arr):
    """Return a compare-and-assign snippet that keeps the larger value in ``res[idx]``.

    The generated code overwrites ``res[idx]`` with `arr` only when the
    current value is strictly smaller than `arr`.

    Parameters
    ----------
    op
        The scalar :class:`Op`; used only for dispatch.
    idx
        The index of `res` that needs to be updated.
    res
        The symbol name for the first input and results/output.
    arr
        The symbol name for the second input.

    Returns
    -------
    str
        A multi-line snippet that starts and ends with a newline and contains
        an ``if`` statement whose body is indented by four spaces.

    """
    # The assignment must be indented relative to the `if` header; with both
    # lines at the same column the generated source is not valid Python.
    # Uniformly re-indenting the whole snippet afterwards keeps the relative
    # indentation intact.
    return f"""
if {res}[{idx}] < {arr}:
    {res}[{idx}] = {arr}
"""
def create_axis_reducer(
reduce_fn: numba.np.ufunc.dufunc.DUFunc,
scalar_op: Op,
identity: Union[np.ndarray, Number],
axis: int,
ndim: int,
......@@ -141,9 +220,8 @@ def create_axis_reducer(
Parameters
==========
reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``\s.
scalar_op:
The scalar :class:`Op` that performs the desired reduction.
identity:
The identity value for the reduction.
axis:
......@@ -155,64 +233,108 @@ def create_axis_reducer(
keepdims:
Determines whether or not the reduced dimension is retained.
"""
if ndim > 1:
reduce_elemwise_fn_name = "careduce_axis"
if ndim > 1:
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
if keepdims:
set_out_dims = numba_basic.numba_njit(
lambda x: np.expand_dims(x, axis), inline="always"
)
else:
set_out_dims = numba_basic.numba_njit(lambda x: x, inline="always")
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return np.expand_dims(x, axis)
else:
@numba_basic.numba_njit
def res_shape_tuple_ctor(args):
return 1
if keepdims:
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x, dtype), inline="always"
)
else:
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x[0], dtype), inline="always"
)
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return x
identity = str(identity)
if identity == "inf":
identity = "np.inf"
elif identity == "-inf":
identity = "-np.inf"
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
if ndim > 1:
res_indices = []
arr_indices = []
count = 0
for i in range(ndim):
if i == axis:
arr_indices.append("i")
else:
res_indices.append(f"idx_arr[{count}]")
arr_indices.append(f"idx_arr[{count}]")
count = count + 1
res_indices = ", ".join(res_indices)
arr_indices = ", ".join(arr_indices)
inplace_update_statement = scalar_in_place_fn(
scalar_op, res_indices, "res", f"x[{arr_indices}]"
)
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
@numba_basic.numba_njit(boundscheck=False)
def careduce_axis(x):
res_shape = res_shape_tuple_ctor(x.shape)
x_axis_first = x.transpose(reaxis_first)
res_shape = res_shape_tuple_ctor(x.shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}))
res = np.full(res_shape, numba_basic.to_scalar(identity), dtype=dtype)
for m in numba.prange(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
axis_shape = x.shape[{axis}]
return set_out_dims(res)
for idx_arr in np.ndindex(res_shape):
for i in range(axis_shape):
{inplace_update_statement}
return set_out_dims(res)
"""
else:
inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
if keepdims:
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return np.array([x], dtype)
res_shape = res_shape_tuple_ctor(x.shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}))
else:
axis_shape = x.shape[{axis}]
@numba_basic.numba_njit(inline="always")
def set_out_dims(x):
return numba_basic.direct_cast(x, dtype)
for i in range(axis_shape):
{inplace_update_statement}
@numba_basic.numba_njit(boundscheck=False)
def careduce_axis(x):
res = numba_basic.to_scalar(identity)
x_ravel = x.ravel()
for i in numba.prange(x_ravel.size):
res = reduce_fn(res, x_ravel[i])
return set_out_dims(res)
return set_out_dims(res)
"""
return careduce_axis
global_env = {
"np": np,
"res_shape_tuple_ctor": res_shape_tuple_ctor,
"numba_basic": numba_basic,
"set_out_dims": set_out_dims,
}
reduce_elemwise_fn_py = compile_function_src(
reduce_elemwise_def_src, reduce_elemwise_fn_name, global_env
)
return numba_basic.numba_njit(boundscheck=False)(reduce_elemwise_fn_py)
def create_multiaxis_reducer(
reduce_fn, identity, axes, ndim, dtype, input_name="input"
scalar_op, identity, axes, ndim, dtype, input_name="input"
):
r"""Construct a function that reduces multiple axes.
......@@ -233,9 +355,8 @@ def create_multiaxis_reducer(
Parameters
==========
reduce_fn:
The Numba ``ufunc`` representing a binary op that can perform the
reduction on arbitrary ``ndarray``\s.
scalar_op:
The scalar :class:`Op` that performs the desired reduction.
identity:
The identity value for the reduction.
axes:
......@@ -247,9 +368,9 @@ def create_multiaxis_reducer(
"""
if len(axes) == 1:
return create_axis_reducer(reduce_fn, identity, axes[0], ndim, dtype)
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"
careduce_fn_name = f"careduce_{scalar_op}"
global_env = {}
to_reduce = reversed(sorted(axes))
careduce_lines_src = []
......@@ -258,7 +379,7 @@ def create_multiaxis_reducer(
for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}"
global_env[careducer_axes_fn_name] = create_axis_reducer(
reduce_fn, identity, axis - i, ndim, dtype
scalar_op, identity, axis, ndim, dtype
)
ndim -= 1
last_var_name = var_name
......@@ -311,24 +432,15 @@ def numba_funcify_CAReduce(op, node, **kwargs):
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
scalar_nfunc_spec = op.scalar_op.nfunc_spec
# We construct a dummy `Apply` that has the minimum required number of
# inputs for the scalar `Op`. Without this, we would get a scalar function
# with too few arguments.
dummy_node = Apply(
op,
[tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
# TODO: Use `scalar_op_identity`?
elemwise_fn = create_vectorize_func(op, dummy_node, use_signature=True, **kwargs)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer(
elemwise_fn, scalar_op_identity, axes, ndim, np_acc_dtype, input_name=input_name
op.scalar_op,
scalar_op_identity,
axes,
ndim,
np_acc_dtype,
input_name=input_name,
)
return careduce_fn
......@@ -422,10 +534,10 @@ def numba_funcify_Softmax(op, node, **kwargs):
if axis is not None:
reduce_max = create_axis_reducer(
np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum = create_axis_reducer(
np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
)
else:
reduce_max = np.max
......@@ -452,7 +564,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
axis = op.axis
if axis is not None:
reduce_sum = create_axis_reducer(
np.add, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
)
else:
reduce_sum = np.sum
......@@ -477,10 +589,10 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None:
reduce_max = create_axis_reducer(
np.maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum = create_axis_reducer(
np.add, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
)
else:
reduce_max = np.max
......@@ -518,7 +630,7 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = create_multiaxis_reducer(
np.maximum, -np.inf, axes, x_ndim, x_dtype
scalar_maximum, -np.inf, axes, x_ndim, x_dtype
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论