提交 33998b20 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added more optimizations to the Numba cheap pass-manager

This only applies to reduction `Op`s (e.g. `CAReduce`).
上级 66760618
import operator
import warnings
from contextlib import contextmanager
from functools import singledispatch
import numba
......@@ -57,14 +58,31 @@ def numba_vectorize(*args, **kwargs):
def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False
aesara_type: Type,
layout: str = "A",
force_scalar: bool = False,
reduce_to_scalar: bool = False,
) -> numba.types.Type:
"""Create a Numba type object for a ``Type``."""
r"""Create a Numba type object for a :class:`Type`.
Parameters
----------
aesara_type
The :class:`Type` to convert.
layout
The :class:`numpy.ndarray` layout to use.
force_scalar
Ignore dimension information and return the corresponding Numba scalar types.
reduce_to_scalar
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""
if isinstance(aesara_type, TensorType):
dtype = aesara_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if force_scalar:
if force_scalar or (
reduce_to_scalar and getattr(aesara_type, "ndim", None) == 0
):
return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar):
......@@ -75,15 +93,25 @@ def get_numba_type(
raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
def create_numba_signature(node: Apply, force_scalar: bool = False) -> numba.types.Type:
def create_numba_signature(
node: Apply, force_scalar: bool = False, reduce_to_scalar: bool = False
) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node."""
input_types = []
for inp in node.inputs:
input_types.append(get_numba_type(inp.type, force_scalar=force_scalar))
input_types.append(
get_numba_type(
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
output_types = []
for out in node.outputs:
output_types.append(get_numba_type(out.type, force_scalar=force_scalar))
output_types.append(
get_numba_type(
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types)
......@@ -263,6 +291,23 @@ def create_arg_string(x):
return args
@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
    """Temporarily swap Numba's "cheap" module pass-manager for an optimized one.

    While the context is active, the CPU target's cheap LLVM pass-manager is
    replaced with a freshly built one that enables loop and SLP vectorization
    at ``-O3``.  The original pass-manager is restored on exit, even if the
    body raises.
    """
    from numba.core.registry import cpu_target

    # NOTE(review): this reaches into Numba internals (underscore-prefixed
    # attributes), so it may break across Numba versions — confirm on upgrade.
    codegen = cpu_target.target_context._internal_codegen
    saved_pm = codegen._mpm_cheap
    codegen._mpm_cheap = codegen._module_pass_manager(
        loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
    )
    try:
        yield
    finally:
        codegen._mpm_cheap = saved_pm
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
    """Convert `data` into a Numba-compatible value.

    This is the generic fallback, which returns the value unchanged; specific
    conversions are registered via ``numba_typify.register``.
    """
    return data
......
......@@ -9,12 +9,14 @@ import numpy as np
from numba.cpython.unsafe.tuple import tuple_setitem
from aesara import config
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch.basic import (
create_numba_signature,
create_tuple_creator,
numba_funcify,
use_optimized_cheap_pass,
)
from aesara.link.utils import (
compile_function_src,
......@@ -27,99 +29,20 @@ from aesara.scalar.basic import (
XOR,
Add,
IntDiv,
Mean,
Mul,
ScalarMaximum,
ScalarMinimum,
Sub,
TrueDiv,
)
from aesara.scalar.basic import add as add_as
from aesara.scalar.basic import scalar_maximum
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import MaxAndArgmax, MulWithoutZeros
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
    """Create a Numba-vectorized function for an `Elemwise`-like `Op`.

    Parameters
    ----------
    op
        The `Op` whose ``scalar_op`` is compiled and vectorized.
    node
        The `Apply` node; used to derive the scalar signature and to read a
        per-node vectorize target from ``node.tag``.
    use_signature
        If ``True``, build an explicit scalar Numba signature from `node`;
        otherwise let Numba infer types lazily.
    identity
        Forwarded to the vectorizer as the reduction identity.
    """
    scalar_op_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)

    if len(node.outputs) > 1:
        raise NotImplementedError(
            "Multi-output Elemwise Ops are not supported by the Numba backend"
        )

    if use_signature:
        signature = [create_numba_signature(node, force_scalar=True)]
    else:
        signature = []

    # A per-node target override (on the node's tag) takes precedence over the
    # global configuration value.
    target = (
        getattr(node.tag, "numba__vectorize_target", None)
        or config.numba__vectorize_target
    )

    numba_vectorized_fn = numba_basic.numba_vectorize(
        signature, identity=identity, target=target, fastmath=config.numba__fastmath
    )

    py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)

    elemwise_fn = numba_vectorized_fn(scalar_op_fn)
    # Keep the underlying Python scalar function around so callers (e.g. the
    # in-place Elemwise wrapper) can inspect its parameter names.
    elemwise_fn.py_scalar_func = py_scalar_func

    return elemwise_fn
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    # Implement an `Elemwise` `Op` by vectorizing its scalar `Op`.
    elemwise_fn = create_vectorize_func(op, node, use_signature=False)
    elemwise_fn_name = elemwise_fn.__name__

    if op.inplace_pattern:
        # The `Op` destructively updates one of its inputs; generate a wrapper
        # that appends that input as an extra argument to the vectorized
        # function (presumably the ufunc's output argument — TODO confirm).
        input_idx = op.inplace_pattern[0]
        sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
        input_names = list(sign_obj.parameters.keys())

        # Avoid collisions with names used inside the generated source.
        unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
        input_names = [unique_names(i, force_unique=True) for i in input_names]

        updated_input_name = input_names[input_idx]

        inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}

        inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"

        input_signature_str = ", ".join(input_names)

        if node.inputs[input_idx].ndim > 0:
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
        else:
            # We can't perform in-place updates on Numba scalars, so we need to
            # convert them to NumPy scalars.
            # TODO: We should really prevent the rewrites from creating
            # in-place updates on scalars when the Numba mode is selected (or
            # in general?).
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    {updated_input_name}_scalar = np.asarray({updated_input_name})
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""

        inplace_elemwise_fn = compile_function_src(
            inplace_elemwise_src,
            inplace_elemwise_fn_name,
            {**globals(), **inplace_global_env},
        )
        return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)(
            inplace_elemwise_fn
        )

    return elemwise_fn
@singledispatch
def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
"""Return code for an in-place update on an array using a binary scalar :class:`Op`.
......@@ -135,7 +58,7 @@ def scalar_in_place_fn(op: Op, idx: str, res: str, arr: str):
arr
The symbol name for the second input.
"""
return f"{res}[{idx}] = {op.nfunc_spec[0]}({res}[{idx}], arr)"
raise NotImplementedError()
@scalar_in_place_fn.register(Add)
......@@ -143,14 +66,24 @@ def scalar_in_place_fn_Add(op, idx, res, arr):
return f"{res}[{idx}] += {arr}"
@scalar_in_place_fn.register(Sub)
def scalar_in_place_fn_Sub(op, idx, res, arr):
    """Return source for an in-place subtraction: ``res[idx] -= arr``."""
    target = f"{res}[{idx}]"
    return f"{target} -= {arr}"
@scalar_in_place_fn.register(Mean)
def scalar_in_place_fn_Mean(op, idx, res, arr):
    """Return source for an in-place running-mean update of ``res[idx]``.

    The generated statement performs incremental averaging and relies on a
    loop index named ``i`` being in scope where the code is inserted.
    """
    acc = f"{res}[{idx}]"
    return f"{acc} += ({arr} - {acc}) / (i + 1)"
@scalar_in_place_fn.register(Mul)
def scalar_in_place_fn_Mul(op, idx, res, arr):
    """Return source for an in-place multiplication: ``res[idx] *= arr``."""
    target = f"{res}[{idx}]"
    return f"{target} *= {arr}"
@scalar_in_place_fn.register(Sub)
def scalar_in_place_fn_Sub(op, idx, res, arr):
    # Generate an in-place subtraction statement: ``res[idx] -= arr``.
    return f"{res}[{idx}] -= {arr}"
@scalar_in_place_fn.register(MulWithoutZeros)
def scalar_in_place_fn_MulWithoutZeros(op, idx, res, arr):
    """Return source that multiplies ``res[idx]`` by ``arr``, skipping zeros.

    A zero on either side leaves the other operand in place instead of
    zeroing the accumulator.
    """
    cur = f"{res}[{idx}]"
    return (
        f"{cur} = {arr} if {cur} == 0 else "
        f"({cur} if {arr} == 0 else {cur} * {arr})"
    )
@scalar_in_place_fn.register(AND)
......@@ -186,6 +119,44 @@ if {res}[{idx}] < {arr}:
"""
@scalar_in_place_fn.register(ScalarMinimum)
def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
    # Generate an in-place minimum update: replace ``res[idx]`` only when
    # ``arr`` is smaller.
    return f"""
if {res}[{idx}] > {arr}:
    {res}[{idx}] = {arr}
"""
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
    """Build a Numba-vectorized elemwise function for `op`.

    The scalar implementation of ``op.scalar_op`` is obtained via
    `numba_funcify` and wrapped by `numba_basic.numba_vectorize`.  The result
    carries the underlying Python scalar function as a ``py_scalar_func``
    attribute so callers can inspect its signature.
    """
    scalar_fn = numba_funcify(op.scalar_op, node=node, inline="always", **kwargs)

    if len(node.outputs) > 1:
        raise NotImplementedError(
            "Multi-output Elemwise Ops are not supported by the Numba backend"
        )

    signature = (
        [create_numba_signature(node, force_scalar=True)] if use_signature else []
    )

    # A target set on the node's tag overrides the global configuration.
    target = (
        getattr(node.tag, "numba__vectorize_target", None)
        or config.numba__vectorize_target
    )

    vectorizer = numba_basic.numba_vectorize(
        signature, identity=identity, target=target, fastmath=config.numba__fastmath
    )

    vectorized_fn = vectorizer(scalar_fn)
    vectorized_fn.py_scalar_func = getattr(scalar_fn, "py_func", scalar_fn)

    return vectorized_fn
def create_axis_reducer(
scalar_op: Op,
identity: Union[np.ndarray, Number],
......@@ -194,7 +165,7 @@ def create_axis_reducer(
dtype: numba.types.Type,
keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher:
r"""Create a Numba JITed function that performs a NumPy reduction on a given axis.
r"""Create Python function that performs a NumPy-like reduction on a given axis.
The functions generated by this function take the following form:
......@@ -232,35 +203,15 @@ def create_axis_reducer(
The data type of the result.
keepdims:
Determines whether or not the reduced dimension is retained.
"""
reduce_elemwise_fn_name = "careduce_axis"
if ndim > 1:
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
if keepdims:
set_out_dims = numba_basic.numba_njit(
lambda x: np.expand_dims(x, axis), inline="always"
)
else:
set_out_dims = numba_basic.numba_njit(lambda x: x, inline="always")
else:
Returns
=======
A Python function that can be JITed.
@numba_basic.numba_njit
def res_shape_tuple_ctor(args):
return 1
"""
if keepdims:
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x, dtype), inline="always"
)
else:
set_out_dims = numba_basic.numba_njit(
lambda x: numba_basic.direct_cast(x[0], dtype), inline="always"
)
reduce_elemwise_fn_name = "careduce_axis"
identity = str(identity)
if identity == "inf":
......@@ -268,7 +219,18 @@ def create_axis_reducer(
elif identity == "-inf":
identity = "-np.inf"
global_env = {
"np": np,
"numba_basic": numba_basic,
"out_dtype": dtype,
}
if ndim > 1:
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor
res_indices = []
arr_indices = []
count = 0
......@@ -289,48 +251,45 @@ def create_axis_reducer(
)
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res"
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
res_shape = res_shape_tuple_ctor(x.shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}))
x_shape = np.shape(x)
res_shape = res_shape_tuple_ctor(x_shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}]
for idx_arr in np.ndindex(res_shape):
for i in range(axis_shape):
{inplace_update_statement}
{inplace_update_statement}
return set_out_dims(res)
return {return_expr}
"""
else:
inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
return_expr = "res" if keepdims else "res.item()"
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
res_shape = res_shape_tuple_ctor(x.shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}))
res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}]
for i in range(axis_shape):
{inplace_update_statement}
{inplace_update_statement}
return set_out_dims(res)
return {return_expr}
"""
global_env = {
"np": np,
"res_shape_tuple_ctor": res_shape_tuple_ctor,
"numba_basic": numba_basic,
"set_out_dims": set_out_dims,
}
reduce_elemwise_fn_py = compile_function_src(
reduce_elemwise_def_src, reduce_elemwise_fn_name, global_env
reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env}
)
return numba_basic.numba_njit(boundscheck=False)(reduce_elemwise_fn_py)
return reduce_elemwise_fn_py
def create_multiaxis_reducer(
......@@ -366,6 +325,10 @@ def create_multiaxis_reducer(
dtype:
The data type of the result.
Returns
=======
A Python function that can be JITed.
"""
if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
......@@ -378,9 +341,13 @@ def create_multiaxis_reducer(
for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}"
global_env[careducer_axes_fn_name] = create_axis_reducer(
scalar_op, identity, axis, ndim, dtype
)
reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype)
reducer_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)(reducer_py_fn)
global_env[careducer_axes_fn_name] = reducer_fn
ndim -= 1
last_var_name = var_name
var_name = f"axis_{i}_res"
......@@ -398,7 +365,40 @@ def {careduce_fn_name}({input_name}):
careduce_fn = compile_function_src(
careduce_def_src, careduce_fn_name, {**globals(), **global_env}
)
return numba_basic.numba_njit(fastmath=config.numba__fastmath)(careduce_fn)
return careduce_fn
def jit_compile_reducer(node, fn, **kwds):
    """Eagerly JIT-compile a Python reduction-loop function with extra optimizations.

    Parameters
    ==========
    node
        The `Apply` node from which the Numba signature is derived.
    fn
        The Python function object to compile.
    kwds
        Extra keywords forwarded to the :func:`numba.njit` call.

    Returns
    =======
    A :func:`numba.njit`-compiled function.
    """
    sig = create_numba_signature(node, reduce_to_scalar=True)

    # Compiling eagerly (i.e. with an explicit signature) under the optimized
    # "cheap" pass-manager should help nested-loop reductions.
    with use_optimized_cheap_pass():
        njit_decorator = numba_basic.numba_njit(
            sig, boundscheck=False, fastmath=config.numba__fastmath, **kwds
        )
        return njit_decorator(fn)
def create_axis_apply_fn(fn, axis, ndim, dtype):
......@@ -417,6 +417,57 @@ def create_axis_apply_fn(fn, axis, ndim, dtype):
return axis_apply_fn
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    # Implement an `Elemwise` `Op` by vectorizing its scalar `Op`.
    elemwise_fn = create_vectorize_func(op, node, use_signature=False)
    elemwise_fn_name = elemwise_fn.__name__

    if op.inplace_pattern:
        # The `Op` destructively updates one of its inputs; generate a wrapper
        # that appends that input as an extra argument to the vectorized
        # function (presumably the ufunc's output argument — TODO confirm).
        input_idx = op.inplace_pattern[0]
        sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
        input_names = list(sign_obj.parameters.keys())

        # Avoid collisions with names used inside the generated source.
        unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
        input_names = [unique_names(i, force_unique=True) for i in input_names]

        updated_input_name = input_names[input_idx]

        inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}

        inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"

        input_signature_str = ", ".join(input_names)

        if node.inputs[input_idx].ndim > 0:
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
        else:
            # We can't perform in-place updates on Numba scalars, so we need to
            # convert them to NumPy scalars.
            # TODO: We should really prevent the rewrites from creating
            # in-place updates on scalars when the Numba mode is selected (or
            # in general?).
            inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    {updated_input_name}_scalar = np.asarray({updated_input_name})
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""

        inplace_elemwise_fn = compile_function_src(
            inplace_elemwise_src,
            inplace_elemwise_fn_name,
            {**globals(), **inplace_global_env},
        )
        return numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)(
            inplace_elemwise_fn
        )

    return elemwise_fn
@numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis
......@@ -434,15 +485,16 @@ def numba_funcify_CAReduce(op, node, **kwargs):
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer(
careduce_py_fn = create_multiaxis_reducer(
op.scalar_op,
scalar_op_identity,
axes,
ndim,
np_acc_dtype,
np.dtype(node.outputs[0].type.dtype),
input_name=input_name,
)
careduce_fn = jit_compile_reducer(node, careduce_py_fn)
return careduce_fn
......@@ -533,24 +585,31 @@ def numba_funcify_Softmax(op, node, **kwargs):
axis = op.axis
if axis is not None:
reduce_max = create_axis_reducer(
reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum = create_axis_reducer(
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
)
jit_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else:
reduce_max = np.max
reduce_sum = np.sum
@numba_basic.numba_njit
def softmax(x):
def softmax_py_fn(x):
z = reduce_max(x)
e_x = np.exp(x - z)
w = reduce_sum(e_x)
sm = e_x / w
return sm
softmax = jit_compile_reducer(node, softmax_py_fn)
return softmax
......@@ -563,19 +622,25 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
axis = op.axis
if axis is not None:
reduce_sum = create_axis_reducer(
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
)
jit_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_sum = jit_fn(reduce_sum_py)
else:
reduce_sum = np.sum
@numba_basic.numba_njit
def softmax_grad(dy, sm):
def softmax_grad_py_fn(dy, sm):
dy_times_sm = dy * sm
sum_dy_times_sm = reduce_sum(dy_times_sm)
dx = dy_times_sm - sum_dy_times_sm * sm
return dx
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
return softmax_grad
......@@ -588,22 +653,28 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
axis = op.axis
if axis is not None:
reduce_max = create_axis_reducer(
reduce_max_py = create_axis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum = create_axis_reducer(
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
)
jit_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else:
reduce_max = np.max
reduce_sum = np.sum
@numba_basic.numba_njit
def log_softmax(x):
def log_softmax_py_fn(x):
xdev = x - reduce_max(x)
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm
log_softmax = jit_compile_reducer(node, log_softmax_py_fn)
return log_softmax
......@@ -629,9 +700,13 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
# work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = create_multiaxis_reducer(
reduce_max_py_fn = create_multiaxis_reducer(
scalar_maximum, -np.inf, axes, x_ndim, x_dtype
)
reduce_max = jit_compile_reducer(
Apply(node.op, node.inputs, [node.outputs[0].clone()]), reduce_max_py_fn
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
......
......@@ -37,6 +37,7 @@ from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg
from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -1049,94 +1050,132 @@ def test_ARange(start, stop, step, dtype):
@pytest.mark.parametrize(
"careduce_fn, axis, v, keepdims",
"careduce_fn, axis, v",
[
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
at.all,
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(0, 1),
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(1, 0),
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.prod,
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
at.prod,
lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.prod,
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.max,
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
True,
),
],
)
def test_CAReduce(careduce_fn, axis, v, keepdims):
g = careduce_fn(v, axis=axis, keepdims=keepdims)
def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论