Commit 33998b20 authored by kc611, committed by Brandon T. Willard

Added more optimizations to the Numba cheap pass-manager

This only applies to reduction `Op`s (e.g. `CAReduce`).
Parent 66760618
import operator import operator
import warnings import warnings
from contextlib import contextmanager
from functools import singledispatch from functools import singledispatch
import numba import numba
...@@ -57,14 +58,31 @@ def numba_vectorize(*args, **kwargs): ...@@ -57,14 +58,31 @@ def numba_vectorize(*args, **kwargs):
def get_numba_type( def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False aesara_type: Type,
layout: str = "A",
force_scalar: bool = False,
reduce_to_scalar: bool = False,
) -> numba.types.Type: ) -> numba.types.Type:
"""Create a Numba type object for a ``Type``.""" r"""Create a Numba type object for a :class:`Type`.
Parameters
----------
aesara_type
The :class:`Type` to convert.
layout
The :class:`numpy.ndarray` layout to use.
force_scalar
Ignore dimension information and return the corresponding Numba scalar types.
reduce_to_scalar
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""
if isinstance(aesara_type, TensorType): if isinstance(aesara_type, TensorType):
dtype = aesara_type.numpy_dtype dtype = aesara_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype) numba_dtype = numba.from_dtype(dtype)
if force_scalar: if force_scalar or (
reduce_to_scalar and getattr(aesara_type, "ndim", None) == 0
):
return numba_dtype return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout) return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar): elif isinstance(aesara_type, Scalar):
...@@ -75,15 +93,25 @@ def get_numba_type( ...@@ -75,15 +93,25 @@ def get_numba_type(
raise NotImplementedError(f"Numba type not implemented for {aesara_type}") raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
def create_numba_signature(node: Apply, force_scalar: bool = False) -> numba.types.Type: def create_numba_signature(
node: Apply, force_scalar: bool = False, reduce_to_scalar: bool = False
) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node.""" """Create a Numba type for the signature of an ``Apply`` node."""
input_types = [] input_types = []
for inp in node.inputs: for inp in node.inputs:
input_types.append(get_numba_type(inp.type, force_scalar=force_scalar)) input_types.append(
get_numba_type(
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
output_types = [] output_types = []
for out in node.outputs: for out in node.outputs:
output_types.append(get_numba_type(out.type, force_scalar=force_scalar)) output_types.append(
get_numba_type(
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
if len(output_types) > 1: if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types) return numba.types.Tuple(output_types)(*input_types)
...@@ -263,6 +291,23 @@ def create_arg_string(x): ...@@ -263,6 +291,23 @@ def create_arg_string(x):
return args return args
@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
    """Temporarily swap Numba's "cheap" module pass-manager for an optimized one.

    While the context is active, the CPU target's internal code generator uses
    a pass-manager built with loop/SLP vectorization at ``opt=3``; the previous
    pass-manager is restored on exit, even if the managed body raises.
    """
    # Imported lazily so that merely importing this module does not pull in
    # Numba's target registry.
    from numba.core.registry import cpu_target

    codegen = cpu_target.target_context._internal_codegen
    saved_pm = codegen._mpm_cheap
    codegen._mpm_cheap = codegen._module_pass_manager(
        loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
    )
    try:
        yield
    finally:
        # Always restore the original pass-manager.
        codegen._mpm_cheap = saved_pm
@singledispatch @singledispatch
def numba_typify(data, dtype=None, **kwargs): def numba_typify(data, dtype=None, **kwargs):
return data return data
......
...@@ -9,12 +9,14 @@ import numpy as np ...@@ -9,12 +9,14 @@ import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem from numba.cpython.unsafe.tuple import tuple_setitem
from aesara import config from aesara import config
from aesara.graph.basic import Apply
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import ( from aesara.link.numba.dispatch.basic import (
create_numba_signature, create_numba_signature,
create_tuple_creator, create_tuple_creator,
numba_funcify, numba_funcify,
use_optimized_cheap_pass,
) )
from aesara.link.utils import ( from aesara.link.utils import (
compile_function_src, compile_function_src,
...@@ -27,99 +29,20 @@ from aesara.scalar.basic import ( ...@@ -27,99 +29,20 @@ from aesara.scalar.basic import (
XOR, XOR,
Add, Add,
IntDiv, IntDiv,
Mean,
Mul, Mul,
ScalarMaximum, ScalarMaximum,
ScalarMinimum,
Sub, Sub,
TrueDiv, TrueDiv,
) )
from aesara.scalar.basic import add as add_as from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
    """Create a Numba ``@vectorize``-compiled function for an `Elemwise`'s scalar `Op`.

    Parameters
    ----------
    op
        The `Elemwise` `Op` whose ``scalar_op`` is vectorized.
    node
        The `Apply` node for which the function is generated.
    use_signature
        Whether to pass an explicit scalar signature to ``numba.vectorize``.
    identity
        Optional identity value forwarded to ``numba.vectorize``.
    """
    scalar_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)

    if len(node.outputs) > 1:
        raise NotImplementedError(
            "Multi-output Elemwise Ops are not supported by the Numba backend"
        )

    signatures = (
        [create_numba_signature(node, force_scalar=True)] if use_signature else []
    )

    # A per-node tag can override the configured vectorize target.
    vectorize_target = (
        getattr(node.tag, "numba__vectorize_target", None)
        or config.numba__vectorize_target
    )
    decorate = numba_basic.numba_vectorize(
        signatures,
        identity=identity,
        target=vectorize_target,
        fastmath=config.numba__fastmath,
    )

    elemwise_fn = decorate(scalar_fn)
    # Keep a handle on the pure-Python scalar function for later introspection
    # (e.g. to recover its parameter names).
    elemwise_fn.py_scalar_func = getattr(scalar_fn, "py_func", scalar_fn)

    return elemwise_fn
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    """Generate a Numba function for an `Elemwise` `Op`.

    The scalar `Op` is vectorized with ``numba.vectorize``; when the node has
    an in-place pattern, a small wrapper is generated from source that passes
    the updated input as the vectorized function's output buffer.
    """
    elemwise_fn = create_vectorize_func(op, node, use_signature=False)
    elemwise_fn_name = elemwise_fn.__name__

    if op.inplace_pattern:
        # NOTE(review): only the input destined for output 0 is considered —
        # presumably Elemwise in-place patterns here map output 0 only; confirm.
        input_idx = op.inplace_pattern[0]
        sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
        input_names = list(sign_obj.parameters.keys())

        # Make sure the generated argument names cannot collide with the
        # names injected into the wrapper's globals below.
        unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
        input_names = [unique_names(i, force_unique=True) for i in input_names]

        updated_input_name = input_names[input_idx]

        inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}

        inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"

        input_signature_str = ", ".join(input_names)

        if node.inputs[input_idx].ndim > 0:
            # Array case: pass the updated input again as the explicit output
            # argument of the vectorized function.
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
        else:
            # We can't perform in-place updates on Numba scalars, so we need to
            # convert them to NumPy scalars.
            # TODO: We should really prevent the rewrites from creating
            # in-place updates on scalars when the Numba mode is selected (or
            # in general?).
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    {updated_input_name}_scalar = np.asarray({updated_input_name})
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""

        inplace_elemwise_fn = compile_function_src(
            inplace_elemwise_src,
            inplace_elemwise_fn_name,
            {**globals(), **inplace_global_env},
        )
        return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)(
            inplace_elemwise_fn
        )

    return elemwise_fn
@singledispatch @singledispatch
def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str): def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
"""Return code for an in-place update on an array using a binary scalar :class:`Op`. """Return code for an in-place update on an array using a binary scalar :class:`Op`.
...@@ -135,7 +58,7 @@ def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str): ...@@ -135,7 +58,7 @@ def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
arr arr
The symbol name for the second input. The symbol name for the second input.
""" """
return f"{res}[{idx}] = {op.nfunc_spec[0]}({res}[{idx}], arr)" raise NotImplementedError()
@scalar_in_place_fn.register(Add) @scalar_in_place_fn.register(Add)
...@@ -143,14 +66,24 @@ def scalar_in_place_fn_Add(op, idx, res, arr): ...@@ -143,14 +66,24 @@ def scalar_in_place_fn_Add(op, idx, res, arr):
return f"{res}[{idx}] += {arr}" return f"{res}[{idx}] += {arr}"
@scalar_in_place_fn.register(Sub)
def scalar_in_place_fn_Sub(op, idx, res, arr):
    """Return source for an in-place subtraction update on an array."""
    return "{0}[{1}] -= {2}".format(res, idx, arr)
@scalar_in_place_fn.register(Mean)
def scalar_in_place_fn_Mean(op, idx, res, arr):
    """Return source for an in-place running-mean update on an array."""
    # Incremental mean: new_mean = old_mean + (x - old_mean) / count.
    # NOTE(review): the generated statement references a loop variable named
    # ``i`` that must exist at the point where the statement is inlined (the
    # reduction loop counter) — confirm against the reducer's generated source.
    return f"{res}[{idx}] += ({arr} - {res}[{idx}]) / (i + 1)"
@scalar_in_place_fn.register(Mul) @scalar_in_place_fn.register(Mul)
def scalar_in_place_fn_Mul(op, idx, res, arr): def scalar_in_place_fn_Mul(op, idx, res, arr):
return f"{res}[{idx}] *= {arr}" return f"{res}[{idx}] *= {arr}"
@scalar_in_place_fn.register(Sub) @scalar_in_place_fn.register(MulWithoutZeros)
def scalar_in_place_fn_Sub(op, idx, res, arr): def scalar_in_place_fn_MulWithoutZeros(op, idx, res, arr):
return f"{res}[{idx}] -= {arr}" return f"{res}[{idx}] = {arr} if {res}[{idx}] == 0 else ({res}[{idx}] if {arr} == 0 else {res}[{idx}] * {arr})"
@scalar_in_place_fn.register(AND) @scalar_in_place_fn.register(AND)
...@@ -186,6 +119,44 @@ if {res}[{idx}] < {arr}: ...@@ -186,6 +119,44 @@ if {res}[{idx}] < {arr}:
""" """
@scalar_in_place_fn.register(ScalarMinimum)
def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
    """Return source for an in-place minimum update on an array."""
    # Keeps the smaller of the current accumulator value and ``arr``
    # (mirrors the `ScalarMaximum` variant with the comparison flipped).
    return f"""
if {res}[{idx}] > {arr}:
    {res}[{idx}] = {arr}
"""
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
    """Create a Numba ``@vectorize``-compiled function for an `Elemwise`'s scalar `Op`.

    Parameters
    ----------
    op
        The `Elemwise` `Op` whose ``scalar_op`` is vectorized.
    node
        The `Apply` node for which the function is generated.
    use_signature
        Whether to pass an explicit scalar signature to ``numba.vectorize``.
    identity
        Optional identity value forwarded to ``numba.vectorize``.
    """
    scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)

    if len(node.outputs) > 1:
        raise NotImplementedError(
            "Multi-output Elemwise Ops are not supported by the Numba backend"
        )

    if use_signature:
        signature = [create_numba_signature(node, force_scalar=True)]
    else:
        signature = []

    # A per-node tag can override the configured vectorize target.
    target = (
        getattr(node.tag, "numba__vectorize_target", None)
        or config.numba__vectorize_target
    )
    numba_vectorized_fn = numba_basic.numba_vectorize(
        signature, identity=identity, target=target, fastmath=config.numba__fastmath
    )

    # Keep a handle on the pure-Python scalar function for later introspection
    # (e.g. to recover its parameter names).
    py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)

    elemwise_fn = numba_vectorized_fn(scalar_op_fn)
    elemwise_fn.py_scalar_func = py_scalar_func

    return elemwise_fn
def create_axis_reducer( def create_axis_reducer(
scalar_op: Op, scalar_op: Op,
identity: Union[np.ndarray, Number], identity: Union[np.ndarray, Number],
...@@ -194,7 +165,7 @@ def create_axis_reducer( ...@@ -194,7 +165,7 @@ def create_axis_reducer(
dtype: numba.types.Type, dtype: numba.types.Type,
keepdims: bool = False, keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher: ) -> numba.core.dispatcher.Dispatcher:
r"""Create a Numba JITed function that performs a NumPy reduction on a given axis. r"""Create Python function that performs a NumPy-like reduction on a given axis.
The functions generated by this function take the following form: The functions generated by this function take the following form:
...@@ -232,35 +203,15 @@ def create_axis_reducer( ...@@ -232,35 +203,15 @@ def create_axis_reducer(
The data type of the result. The data type of the result.
keepdims: keepdims:
Determines whether or not the reduced dimension is retained. Determines whether or not the reduced dimension is retained.
"""
reduce_elemwise_fn_name = "careduce_axis"
if ndim > 1: Returns
res_shape_tuple_ctor = create_tuple_creator( =======
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1 A Python function that can be JITed.
)
if keepdims:
set_out_dims = numba_basic.numba_njit(
lambda x: np.expand_dims(x, axis), inline="always"
)
else:
set_out_dims = numba_basic.numba_njit(lambda x: x, inline="always")
else:
@numba_basic.numba_njit """
def res_shape_tuple_ctor(args):
return 1
if keepdims: reduce_elemwise_fn_name = "careduce_axis"
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x, dtype), inline="always"
)
else:
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x[0], dtype), inline="always"
)
identity = str(identity) identity = str(identity)
if identity == "inf": if identity == "inf":
...@@ -268,7 +219,18 @@ def create_axis_reducer( ...@@ -268,7 +219,18 @@ def create_axis_reducer(
elif identity == "-inf": elif identity == "-inf":
identity = "-np.inf" identity = "-np.inf"
global_env = {
"np": np,
"numba_basic": numba_basic,
"out_dtype": dtype,
}
if ndim > 1: if ndim > 1:
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor
res_indices = [] res_indices = []
arr_indices = [] arr_indices = []
count = 0 count = 0
...@@ -289,48 +251,45 @@ def create_axis_reducer( ...@@ -289,48 +251,45 @@ def create_axis_reducer(
) )
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3) inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res"
reduce_elemwise_def_src = f""" reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x): def {reduce_elemwise_fn_name}(x):
res_shape = res_shape_tuple_ctor(x.shape) x_shape = np.shape(x)
res = np.full(res_shape, numba_basic.to_scalar({identity})) res_shape = res_shape_tuple_ctor(x_shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}] axis_shape = x.shape[{axis}]
for idx_arr in np.ndindex(res_shape): for idx_arr in np.ndindex(res_shape):
for i in range(axis_shape): for i in range(axis_shape):
{inplace_update_statement} {inplace_update_statement}
return set_out_dims(res) return {return_expr}
""" """
else: else:
inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]") inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3) inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
return_expr = "res" if keepdims else "res.item()"
reduce_elemwise_def_src = f""" reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x): def {reduce_elemwise_fn_name}(x):
res_shape = res_shape_tuple_ctor(x.shape) res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype)
res = np.full(res_shape, numba_basic.to_scalar({identity}))
axis_shape = x.shape[{axis}] axis_shape = x.shape[{axis}]
for i in range(axis_shape): for i in range(axis_shape):
{inplace_update_statement} {inplace_update_statement}
return set_out_dims(res) return {return_expr}
""" """
global_env = {
"np": np,
"res_shape_tuple_ctor": res_shape_tuple_ctor,
"numba_basic": numba_basic,
"set_out_dims": set_out_dims,
}
reduce_elemwise_fn_py = compile_function_src( reduce_elemwise_fn_py = compile_function_src(
reduce_elemwise_def_src, reduce_elemwise_fn_name, global_env reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env}
) )
return numba_basic.numba_njit(boundscheck=False)(reduce_elemwise_fn_py)
return reduce_elemwise_fn_py
def create_multiaxis_reducer( def create_multiaxis_reducer(
...@@ -366,6 +325,10 @@ def create_multiaxis_reducer( ...@@ -366,6 +325,10 @@ def create_multiaxis_reducer(
dtype: dtype:
The data type of the result. The data type of the result.
Returns
=======
A Python function that can be JITed.
""" """
if len(axes) == 1: if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
...@@ -378,9 +341,13 @@ def create_multiaxis_reducer( ...@@ -378,9 +341,13 @@ def create_multiaxis_reducer(
for i, axis in enumerate(to_reduce): for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}" careducer_axes_fn_name = f"careduce_axes_fn_{i}"
global_env[careducer_axes_fn_name] = create_axis_reducer( reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype)
scalar_op, identity, axis, ndim, dtype reducer_fn = numba_basic.numba_njit(
) boundscheck=False, fastmath=config.numba__fastmath
)(reducer_py_fn)
global_env[careducer_axes_fn_name] = reducer_fn
ndim -= 1 ndim -= 1
last_var_name = var_name last_var_name = var_name
var_name = f"axis_{i}_res" var_name = f"axis_{i}_res"
...@@ -398,7 +365,40 @@ def {careduce_fn_name}({input_name}): ...@@ -398,7 +365,40 @@ def {careduce_fn_name}({input_name}):
careduce_fn = compile_function_src( careduce_fn = compile_function_src(
careduce_def_src, careduce_fn_name, {**globals(), **global_env} careduce_def_src, careduce_fn_name, {**globals(), **global_env}
) )
return numba_basic.numba_njit(fastmath=config.numba__fastmath)(careduce_fn)
return careduce_fn
def jit_compile_reducer(node, fn, **kwds):
    """Compile Python source for reduction loops using additional optimizations.

    Parameters
    ==========
    node
        A node from which the signature can be derived.
    fn
        The Python function object to compile.
    kwds
        Extra keywords to be added to the :func:`numba.njit` function.

    Returns
    =======
    A :func:`numba.njit`-compiled function.

    """
    # Zero-dimensional tensor arguments/results are lowered to Numba scalars.
    signature = create_numba_signature(node, reduce_to_scalar=True)

    # Eagerly compile the function using increased optimizations. This should
    # help improve nested loop reductions.
    with use_optimized_cheap_pass():
        res = numba_basic.numba_njit(
            signature,
            boundscheck=False,
            fastmath=config.numba__fastmath,
            **kwds,
        )(fn)

    return res
def create_axis_apply_fn(fn, axis, ndim, dtype): def create_axis_apply_fn(fn, axis, ndim, dtype):
...@@ -417,6 +417,57 @@ def create_axis_apply_fn(fn, axis, ndim, dtype): ...@@ -417,6 +417,57 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
return axis_apply_fn return axis_apply_fn
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    """Generate a Numba function for an `Elemwise` `Op`.

    The scalar `Op` is vectorized with ``numba.vectorize``; when the node has
    an in-place pattern, a small wrapper is generated from source that passes
    the updated input as the vectorized function's output buffer.
    """
    elemwise_fn = create_vectorize_func(op, node, use_signature=False)
    elemwise_fn_name = elemwise_fn.__name__

    if op.inplace_pattern:
        # NOTE(review): only the input destined for output 0 is considered —
        # presumably Elemwise in-place patterns here map output 0 only; confirm.
        input_idx = op.inplace_pattern[0]
        sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
        input_names = list(sign_obj.parameters.keys())

        # Make sure the generated argument names cannot collide with the
        # names injected into the wrapper's globals below.
        unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
        input_names = [unique_names(i, force_unique=True) for i in input_names]

        updated_input_name = input_names[input_idx]

        inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}

        inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"

        input_signature_str = ", ".join(input_names)

        if node.inputs[input_idx].ndim > 0:
            # Array case: pass the updated input again as the explicit output
            # argument of the vectorized function.
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
        else:
            # We can't perform in-place updates on Numba scalars, so we need to
            # convert them to NumPy scalars.
            # TODO: We should really prevent the rewrites from creating
            # in-place updates on scalars when the Numba mode is selected (or
            # in general?).
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    {updated_input_name}_scalar = np.asarray({updated_input_name})
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""

        inplace_elemwise_fn = compile_function_src(
            inplace_elemwise_src,
            inplace_elemwise_fn_name,
            {**globals(), **inplace_global_env},
        )
        return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)(
            inplace_elemwise_fn
        )

    return elemwise_fn
@numba_funcify.register(CAReduce) @numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs): def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis axes = op.axis
...@@ -434,15 +485,16 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -434,15 +485,16 @@ def numba_funcify_CAReduce(op, node, **kwargs):
input_name = get_name_for_object(node.inputs[0]) input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer( careduce_py_fn = create_multiaxis_reducer(
op.scalar_op, op.scalar_op,
scalar_op_identity, scalar_op_identity,
axes, axes,
ndim, ndim,
np_acc_dtype, np.dtype(node.outputs[0].type.dtype),
input_name=input_name, input_name=input_name,
) )
careduce_fn = jit_compile_reducer(node, careduce_py_fn)
return careduce_fn return careduce_fn
...@@ -533,24 +585,31 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -533,24 +585,31 @@ def numba_funcify_Softmax(op, node, **kwargs):
axis = op.axis axis = op.axis
if axis is not None: if axis is not None:
reduce_max = create_axis_reducer( reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
reduce_sum = create_axis_reducer( reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else: else:
reduce_max = np.max reduce_max = np.max
reduce_sum = np.sum reduce_sum = np.sum
@numba_basic.numba_njit def softmax_py_fn(x):
def softmax(x):
z = reduce_max(x) z = reduce_max(x)
e_x = np.exp(x - z) e_x = np.exp(x - z)
w = reduce_sum(e_x) w = reduce_sum(e_x)
sm = e_x / w sm = e_x / w
return sm return sm
softmax = jit_compile_reducer(node, softmax_py_fn)
return softmax return softmax
...@@ -563,19 +622,25 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): ...@@ -563,19 +622,25 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
axis = op.axis axis = op.axis
if axis is not None: if axis is not None:
reduce_sum = create_axis_reducer( reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_sum = jit_fn(reduce_sum_py)
else: else:
reduce_sum = np.sum reduce_sum = np.sum
@numba_basic.numba_njit def softmax_grad_py_fn(dy, sm):
def softmax_grad(dy, sm):
dy_times_sm = dy * sm dy_times_sm = dy * sm
sum_dy_times_sm = reduce_sum(dy_times_sm) sum_dy_times_sm = reduce_sum(dy_times_sm)
dx = dy_times_sm - sum_dy_times_sm * sm dx = dy_times_sm - sum_dy_times_sm * sm
return dx return dx
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
return softmax_grad return softmax_grad
...@@ -588,22 +653,28 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -588,22 +653,28 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
axis = op.axis axis = op.axis
if axis is not None: if axis is not None:
reduce_max = create_axis_reducer( reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
reduce_sum = create_axis_reducer( reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else: else:
reduce_max = np.max reduce_max = np.max
reduce_sum = np.sum reduce_sum = np.sum
@numba_basic.numba_njit def log_softmax_py_fn(x):
def log_softmax(x):
xdev = x - reduce_max(x) xdev = x - reduce_max(x)
lsm = xdev - np.log(reduce_sum(np.exp(xdev))) lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm return lsm
log_softmax = jit_compile_reducer(node, log_softmax_py_fn)
return log_softmax return log_softmax
...@@ -629,9 +700,13 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs): ...@@ -629,9 +700,13 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
# work-around # work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes) keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = create_multiaxis_reducer( reduce_max_py_fn = create_multiaxis_reducer(
scalar_maximum, -np.inf, axes, x_ndim, x_dtype scalar_maximum, -np.inf, axes, x_ndim, x_dtype
) )
reduce_max = jit_compile_reducer(
Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn
)
reduced_x_ndim = x_ndim - len(axes) + 1 reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn( argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64 np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
......
...@@ -37,6 +37,7 @@ from aesara.tensor import elemwise as at_elemwise ...@@ -37,6 +37,7 @@ from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg from aesara.tensor import extra_ops, nlinalg, slinalg
from aesara.tensor import subtensor as at_subtensor from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...@@ -1049,94 +1050,132 @@ def test_ARange(start, stop, step, dtype): ...@@ -1049,94 +1050,132 @@ def test_ARange(start, stop, step, dtype):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"careduce_fn, axis, v, keepdims", "careduce_fn, axis, v",
[ [
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0, 0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
), ),
( (
at.all, lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0, 0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0, 0,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(0, 1), (0, 1),
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(1, 0), (1, 0),
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None, None,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1, 1,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.prod, lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0, 0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
), ),
( (
at.prod, lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0, 0,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.prod, lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1, 1,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.max, lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None, None,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
True,
), ),
], ],
) )
def test_CAReduce(careduce_fn, axis, v, keepdims): def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis, keepdims=keepdims) g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment