提交 ef97287b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Improve CAReduce Numba implementation

上级 9e24b10a
from collections.abc import Callable from collections.abc import Callable
from functools import singledispatch from functools import singledispatch
from numbers import Number from textwrap import dedent, indent
from textwrap import indent
from typing import Any from typing import Any
import numba import numba
...@@ -15,7 +14,6 @@ from pytensor.graph.op import Op ...@@ -15,7 +14,6 @@ from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
create_numba_signature, create_numba_signature,
create_tuple_creator,
numba_funcify, numba_funcify,
numba_njit, numba_njit,
use_optimized_cheap_pass, use_optimized_cheap_pass,
...@@ -26,7 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import ( ...@@ -26,7 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
encode_literals, encode_literals,
store_core_outputs, store_core_outputs,
) )
from pytensor.link.utils import compile_function_src, get_name_for_object from pytensor.link.utils import compile_function_src
from pytensor.scalar.basic import ( from pytensor.scalar.basic import (
AND, AND,
OR, OR,
...@@ -163,40 +161,32 @@ def create_vectorize_func( ...@@ -163,40 +161,32 @@ def create_vectorize_func(
return elemwise_fn return elemwise_fn
def create_axis_reducer( def create_multiaxis_reducer(
scalar_op: Op, scalar_op,
identity: np.ndarray | Number, identity,
axis: int, axes,
ndim: int, ndim,
dtype: numba.types.Type, dtype,
keepdims: bool = False, keepdims: bool = False,
return_scalar=False, ):
) -> numba.core.dispatcher.Dispatcher: r"""Construct a function that reduces multiple axes.
r"""Create Python function that performs a NumPy-like reduction on a given axis.
The functions generated by this function take the following form: The functions generated by this function take the following form:
.. code-block:: python .. code-block:: python
def careduce_axis(x): def careduce_add(x):
res_shape = tuple( # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
shape[i] if i < axis else shape[i + 1] for i in range(ndim - 1) x_shape = x.shape
) res_shape = x_shape[2]
res = np.full(res_shape, identity, dtype=dtype) res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
x_axis_first = x.transpose(reaxis_first)
for m in range(x.shape[axis]):
reduce_fn(res, x_axis_first[m], res)
if keepdims:
return np.expand_dims(res, axis)
else:
return res
for i0 in range(x_shape[0]):
for i1 in range(x_shape[1]):
for i2 in range(x_shape[2]):
res[i2] += x[i0, i1, i2]
This can be removed/replaced when return res
https://github.com/numba/numba/issues/4504 is implemented.
Parameters Parameters
========== ==========
...@@ -204,25 +194,29 @@ def create_axis_reducer( ...@@ -204,25 +194,29 @@ def create_axis_reducer(
The scalar :class:`Op` that performs the desired reduction. The scalar :class:`Op` that performs the desired reduction.
identity: identity:
The identity value for the reduction. The identity value for the reduction.
axis: axes:
The axis to reduce. The axes to reduce.
ndim: ndim:
The number of dimensions of the result. The number of dimensions of the input variable.
dtype: dtype:
The data type of the result. The data type of the result.
keepdims: keepdims: boolean, default False
Determines whether or not the reduced dimension is retained. Whether to keep the reduced dimensions.
Returns Returns
======= =======
A Python function that can be JITed. A Python function that can be JITed.
""" """
# if len(axes) == 1:
# return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
axis = normalize_axis_index(axis, ndim) axes = normalize_axis_tuple(axes, ndim)
if keepdims and len(axes) > 1:
raise NotImplementedError(
"Cannot keep multiple dimensions when reducing multiple axes"
)
reduce_elemwise_fn_name = "careduce_axis" careduce_fn_name = f"careduce_{scalar_op}"
identity = str(identity) identity = str(identity)
if identity == "inf": if identity == "inf":
...@@ -235,162 +229,55 @@ def create_axis_reducer( ...@@ -235,162 +229,55 @@ def create_axis_reducer(
"numba_basic": numba_basic, "numba_basic": numba_basic,
"out_dtype": dtype, "out_dtype": dtype,
} }
complete_reduction = len(axes) == ndim
kept_axis = tuple(i for i in range(ndim) if i not in axes)
res_indices = []
arr_indices = []
for i in range(ndim):
index_label = f"i{i}"
arr_indices.append(index_label)
if i not in axes:
res_indices.append(index_label)
res_indices = ", ".join(res_indices) if res_indices else ()
arr_indices = ", ".join(arr_indices) if arr_indices else ()
inplace_update_stmt = scalar_in_place_fn(
scalar_op, res_indices, "res", f"x[{arr_indices}]"
)
if ndim > 1: res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})"
res_shape_tuple_ctor = create_tuple_creator( if complete_reduction and ndim > 0:
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1 # We accumulate on a scalar, not an array
) res_creator = f"np.asarray({identity}).astype(out_dtype).item()"
global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res")
return_obj = "np.asarray(res)"
res_indices = []
arr_indices = []
count = 0
for i in range(ndim):
if i == axis:
arr_indices.append("i")
else:
res_indices.append(f"idx_arr[{count}]")
arr_indices.append(f"idx_arr[{count}]")
count = count + 1
res_indices = ", ".join(res_indices)
arr_indices = ", ".join(arr_indices)
inplace_update_statement = scalar_in_place_fn(
scalar_op, res_indices, "res", f"x[{arr_indices}]"
)
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res"
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
x_shape = np.shape(x)
res_shape = res_shape_tuple_ctor(x_shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}]
for idx_arr in np.ndindex(res_shape):
for i in range(axis_shape):
{inplace_update_statement}
return {return_expr}
"""
else: else:
inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]") res_creator = (
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2) f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)"
)
return_expr = "res" if keepdims else "res.item()" return_obj = "res"
if not return_scalar:
return_expr = f"np.asarray({return_expr})" if keepdims:
reduce_elemwise_def_src = f""" [axis] = axes
def {reduce_elemwise_fn_name}(x): return_obj = f"np.expand_dims({return_obj}, {axis})"
res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype) careduce_def_src = dedent(
f"""
axis_shape = x.shape[{axis}] def {careduce_fn_name}(x):
x_shape = x.shape
for i in range(axis_shape): res_shape = {res_shape}
{inplace_update_statement} res = {res_creator}
return {return_expr}
""" """
reduce_elemwise_fn_py = compile_function_src(
reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env}
) )
for axis in range(ndim):
return reduce_elemwise_fn_py careduce_def_src += indent(
f"for i{axis} in range(x_shape[{axis}]):\n",
" " * (4 + 4 * axis),
def create_multiaxis_reducer(
scalar_op,
identity,
axes,
ndim,
dtype,
input_name="input",
return_scalar=False,
):
r"""Construct a function that reduces multiple axes.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_maximum(input):
axis_0_res = careduce_axes_fn_0(input)
axis_1_res = careduce_axes_fn_1(axis_0_res)
...
axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
return axis_N_res
The range 0-N is determined by the `axes` argument (i.e. the
axes to be reduced).
Parameters
==========
scalar_op:
The scalar :class:`Op` that performs the desired reduction.
identity:
The identity value for the reduction.
axes:
The axes to reduce.
ndim:
The number of dimensions of the result.
dtype:
The data type of the result.
return_scalar:
If True, return a scalar, otherwise an array.
Returns
=======
A Python function that can be JITed.
"""
if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
axes = normalize_axis_tuple(axes, ndim)
careduce_fn_name = f"careduce_{scalar_op}"
global_env = {}
to_reduce = sorted(axes, reverse=True)
careduce_lines_src = []
var_name = input_name
for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}"
reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype)
reducer_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)(reducer_py_fn)
global_env[careducer_axes_fn_name] = reducer_fn
ndim -= 1
last_var_name = var_name
var_name = f"axis_{i}_res"
careduce_lines_src.append(
f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
) )
careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4) careduce_def_src += "\n\n"
if not return_scalar: careduce_def_src += indent(f"return {return_obj}", " " * 4)
pre_result = "np.asarray"
post_result = ""
else:
pre_result = "np.asarray"
post_result = ".item()"
careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
return {pre_result}({var_name}){post_result}
"""
careduce_fn = compile_function_src( careduce_fn = compile_function_src(
careduce_def_src, careduce_fn_name, {**globals(), **global_env} careduce_def_src, careduce_fn_name, {**globals(), **global_env}
...@@ -545,32 +432,29 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -545,32 +432,29 @@ def numba_funcify_Elemwise(op, node, **kwargs):
@numba_funcify.register(Sum) @numba_funcify.register(Sum)
def numba_funcify_Sum(op, node, **kwargs): def numba_funcify_Sum(op, node, **kwargs):
ndim_input = node.inputs[0].ndim
axes = op.axis axes = op.axis
if axes is None: if axes is None:
axes = list(range(node.inputs[0].ndim)) axes = list(range(node.inputs[0].ndim))
else:
axes = tuple(axes) axes = normalize_axis_tuple(axes, ndim_input)
ndim_input = node.inputs[0].ndim
if hasattr(op, "acc_dtype") and op.acc_dtype is not None: if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
acc_dtype = op.acc_dtype acc_dtype = op.acc_dtype
else: else:
acc_dtype = node.outputs[0].type.dtype acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype) np_acc_dtype = np.dtype(acc_dtype)
out_dtype = np.dtype(node.outputs[0].dtype) out_dtype = np.dtype(node.outputs[0].dtype)
if ndim_input == len(axes): if ndim_input == len(axes):
# Slightly faster than `numba_funcify_CAReduce` for this case
@numba_njit(fastmath=True) @numba_njit(fastmath=config.numba__fastmath)
def impl_sum(array): def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
elif len(axes) == 0: elif len(axes) == 0:
# These cases should be removed by rewrites!
@numba_njit(fastmath=True) @numba_njit(fastmath=config.numba__fastmath)
def impl_sum(array): def impl_sum(array):
return np.asarray(array, dtype=out_dtype) return np.asarray(array, dtype=out_dtype)
...@@ -603,7 +487,6 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -603,7 +487,6 @@ def numba_funcify_CAReduce(op, node, **kwargs):
# Make sure it has the correct dtype # Make sure it has the correct dtype
scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype) scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
careduce_py_fn = create_multiaxis_reducer( careduce_py_fn = create_multiaxis_reducer(
op.scalar_op, op.scalar_op,
...@@ -611,7 +494,6 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -611,7 +494,6 @@ def numba_funcify_CAReduce(op, node, **kwargs):
axes, axes,
ndim, ndim,
np.dtype(node.outputs[0].type.dtype), np.dtype(node.outputs[0].type.dtype),
input_name=input_name,
) )
careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False) careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
...@@ -724,11 +606,11 @@ def numba_funcify_Softmax(op, node, **kwargs): ...@@ -724,11 +606,11 @@ def numba_funcify_Softmax(op, node, **kwargs):
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim) axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer( reduce_max_py = create_multiaxis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
) )
reduce_sum_py = create_axis_reducer( reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit( jit_fn = numba_basic.numba_njit(
...@@ -761,8 +643,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): ...@@ -761,8 +643,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
axis = op.axis axis = op.axis
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, sm_at.ndim) axis = normalize_axis_index(axis, sm_at.ndim)
reduce_sum_py = create_axis_reducer( reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit( jit_fn = numba_basic.numba_njit(
...@@ -793,16 +675,16 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): ...@@ -793,16 +675,16 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None: if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim) axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer( reduce_max_py = create_multiaxis_reducer(
scalar_maximum, scalar_maximum,
-np.inf, -np.inf,
axis, (axis,),
x_at.ndim, x_at.ndim,
x_dtype, x_dtype,
keepdims=True, keepdims=True,
) )
reduce_sum_py = create_axis_reducer( reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
) )
jit_fn = numba_basic.numba_njit( jit_fn = numba_basic.numba_njit(
......
...@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable ...@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad from pytensor.gradient import grad
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
...@@ -23,7 +23,7 @@ from tests.link.numba.test_basic import ( ...@@ -23,7 +23,7 @@ from tests.link.numba.test_basic import (
scalar_my_multi_out, scalar_my_multi_out,
set_test_value, set_test_value,
) )
from tests.tensor.test_elemwise import TestElemwise from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -249,12 +249,12 @@ def test_Dimshuffle_non_contiguous(): ...@@ -249,12 +249,12 @@ def test_Dimshuffle_non_contiguous():
( (
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0, 0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x), lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0, 0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)), set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
), ),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Sum( lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
...@@ -301,6 +301,24 @@ def test_Dimshuffle_non_contiguous(): ...@@ -301,6 +301,24 @@ def test_Dimshuffle_non_contiguous():
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
), ),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(), # Empty axes would normally be rewritten away, but we want to test it still works
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
pt.scalar(), np.array(99.0, dtype=config.floatX)
), # Scalar input would normally be rewritten away, but we want to test it still works
),
( (
lambda x, axis=None, dtype=None, acc_dtype=None: Prod( lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype axis=axis, dtype=dtype, acc_dtype=acc_dtype
...@@ -367,7 +385,7 @@ def test_CAReduce(careduce_fn, axis, v): ...@@ -367,7 +385,7 @@ def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis) g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( fn, _ = compare_numba_and_py(
g_fg, g_fg,
[ [
i.tag.test_value i.tag.test_value
...@@ -375,6 +393,10 @@ def test_CAReduce(careduce_fn, axis, v): ...@@ -375,6 +393,10 @@ def test_CAReduce(careduce_fn, axis, v):
if not isinstance(i, SharedVariable | Constant) if not isinstance(i, SharedVariable | Constant)
], ],
) )
# Confirm CAReduce is in the compiled function
fn.dprint()
[node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, CAReduce)
def test_scalar_Elemwise_Clip(): def test_scalar_Elemwise_Clip():
...@@ -619,10 +641,10 @@ def test_logsumexp_benchmark(size, axis, benchmark): ...@@ -619,10 +641,10 @@ def test_logsumexp_benchmark(size, axis, benchmark):
X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA") X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")
# JIT compile first # JIT compile first
_ = X_lse_fn(X_val) res = X_lse_fn(X_val)
res = benchmark(X_lse_fn, X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res) np.testing.assert_array_almost_equal(res, exp_res)
benchmark(X_lse_fn, X_val)
def test_fused_elemwise_benchmark(benchmark): def test_fused_elemwise_benchmark(benchmark):
...@@ -653,3 +675,19 @@ def test_elemwise_out_type(): ...@@ -653,3 +675,19 @@ def test_elemwise_out_type():
x_val = np.broadcast_to(np.zeros((3,)), (6, 3)) x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
assert func(x_val).shape == (18,) assert func(x_val).shape == (18,)
@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
)
...@@ -983,27 +983,33 @@ class TestVectorize: ...@@ -983,27 +983,33 @@ class TestVectorize:
assert vect_node.inputs[0] is bool_tns assert vect_node.inputs[0] is bool_tns
@pytest.mark.parametrize( def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark):
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_careduce_benchmark(axis, c_contiguous, benchmark):
N = 256 N = 256
x_test = np.random.uniform(size=(N, N, N)) x_test = np.random.uniform(size=(N, N, N))
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1) transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
x = pytensor.shared(x_test, name="x", shape=x_test.shape) x = pytensor.shared(x_test, name="x", shape=x_test.shape)
out = x.transpose(transpose_axis).sum(axis=axis) out = x.transpose(transpose_axis).sum(axis=axis)
fn = pytensor.function([], out) fn = pytensor.function([], out, mode=mode)
np.testing.assert_allclose( np.testing.assert_allclose(
fn(), fn(),
x_test.transpose(transpose_axis).sum(axis=axis), x_test.transpose(transpose_axis).sum(axis=axis),
) )
benchmark(fn) benchmark(fn)
@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_c_careduce_benchmark(axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论