Commit ef97287b authored by Ricardo Vieira, committed by Ricardo Vieira

Improve CAReduce Numba implementation

Parent 9e24b10a
from collections.abc import Callable
from functools import singledispatch
from numbers import Number
from textwrap import indent
from textwrap import dedent, indent
from typing import Any
import numba
@@ -15,7 +14,6 @@ from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
create_tuple_creator,
numba_funcify,
numba_njit,
use_optimized_cheap_pass,
@@ -26,7 +24,7 @@ from pytensor.link.numba.dispatch.vectorize_codegen import (
encode_literals,
store_core_outputs,
)
from pytensor.link.utils import compile_function_src, get_name_for_object
from pytensor.link.utils import compile_function_src
from pytensor.scalar.basic import (
AND,
OR,
@@ -163,40 +161,32 @@ def create_vectorize_func(
return elemwise_fn
def create_axis_reducer(
scalar_op: Op,
identity: np.ndarray | Number,
axis: int,
ndim: int,
dtype: numba.types.Type,
def create_multiaxis_reducer(
scalar_op,
identity,
axes,
ndim,
dtype,
keepdims: bool = False,
return_scalar=False,
) -> numba.core.dispatcher.Dispatcher:
r"""Create Python function that performs a NumPy-like reduction on a given axis.
):
r"""Construct a function that reduces multiple axes.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_axis(x):
    res_shape = tuple(
        shape[i] if i < axis else shape[i + 1] for i in range(ndim - 1)
    )
    res = np.full(res_shape, identity, dtype=dtype)
    x_axis_first = x.transpose(reaxis_first)
    for m in range(x.shape[axis]):
        reduce_fn(res, x_axis_first[m], res)
    if keepdims:
        return np.expand_dims(res, axis)
    else:
        return res

def careduce_add(x):
    # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
    x_shape = x.shape
    res_shape = x_shape[2]
    res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
    for i0 in range(x_shape[0]):
        for i1 in range(x_shape[1]):
            for i2 in range(x_shape[2]):
                res[i2] += x[i0, i1, i2]
    return res

This can be removed/replaced when
https://github.com/numba/numba/issues/4504 is implemented.
Parameters
==========
@@ -204,25 +194,29 @@ def create_axis_reducer(
The scalar :class:`Op` that performs the desired reduction.
identity:
The identity value for the reduction.
axis:
The axis to reduce.
axes:
The axes to reduce.
ndim:
The number of dimensions of the result.
The number of dimensions of the input variable.
dtype:
The data type of the result.
keepdims:
Determines whether or not the reduced dimension is retained.
keepdims: boolean, default False
Whether to keep the reduced dimensions.
Returns
=======
A Python function that can be JITed.
"""
# if len(axes) == 1:
# return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
axis = normalize_axis_index(axis, ndim)
axes = normalize_axis_tuple(axes, ndim)
if keepdims and len(axes) > 1:
raise NotImplementedError(
"Cannot keep multiple dimensions when reducing multiple axes"
)
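    # Illustration (not in the diff): keepdims is handled further down with a single
    # np.expand_dims call, so e.g. axes == (0, 1) with keepdims=True would need the
    # result re-expanded at two positions, which this helper does not generate. The
    # softmax helpers below only ever pass a single axis with keepdims=True.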
reduce_elemwise_fn_name = "careduce_axis"
careduce_fn_name = f"careduce_{scalar_op}"
identity = str(identity)
if identity == "inf":
@@ -235,162 +229,55 @@
"numba_basic": numba_basic,
"out_dtype": dtype,
}
complete_reduction = len(axes) == ndim
kept_axis = tuple(i for i in range(ndim) if i not in axes)
res_indices = []
arr_indices = []
for i in range(ndim):
index_label = f"i{i}"
arr_indices.append(index_label)
if i not in axes:
res_indices.append(index_label)
res_indices = ", ".join(res_indices) if res_indices else ()
arr_indices = ", ".join(arr_indices) if arr_indices else ()
inplace_update_stmt = scalar_in_place_fn(
scalar_op, res_indices, "res", f"x[{arr_indices}]"
)
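    # Illustration (assumed example values): for ndim == 3 and axes == (0, 1) this
    # builds arr_indices == "i0, i1, i2" and res_indices == "i2", so the in-place
    # update is the "res[i2] += x[i0, i1, i2]" line from the docstring example; for
    # a complete reduction res_indices is empty, the update targets res[()], and the
    # scalar-accumulator branch below rewrites that to plain "res".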
if ndim > 1:
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor
res_indices = []
arr_indices = []
count = 0
for i in range(ndim):
if i == axis:
arr_indices.append("i")
else:
res_indices.append(f"idx_arr[{count}]")
arr_indices.append(f"idx_arr[{count}]")
count = count + 1
res_indices = ", ".join(res_indices)
arr_indices = ", ".join(arr_indices)
inplace_update_statement = scalar_in_place_fn(
scalar_op, res_indices, "res", f"x[{arr_indices}]"
)
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3)
return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res"
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
x_shape = np.shape(x)
res_shape = res_shape_tuple_ctor(x_shape)
res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}]
for idx_arr in np.ndindex(res_shape):
for i in range(axis_shape):
{inplace_update_statement}
return {return_expr}
"""
res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})"
if complete_reduction and ndim > 0:
# We accumulate on a scalar, not an array
res_creator = f"np.asarray({identity}).astype(out_dtype).item()"
inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res")
return_obj = "np.asarray(res)"
else:
inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]")
inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2)
return_expr = "res" if keepdims else "res.item()"
if not return_scalar:
return_expr = f"np.asarray({return_expr})"
reduce_elemwise_def_src = f"""
def {reduce_elemwise_fn_name}(x):
res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype)
axis_shape = x.shape[{axis}]
for i in range(axis_shape):
{inplace_update_statement}
return {return_expr}
res_creator = (
f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)"
)
return_obj = "res"
if keepdims:
[axis] = axes
return_obj = f"np.expand_dims({return_obj}, {axis})"
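        # Illustration (assumed shapes): with axes == (1,) and keepdims=True a (3, 4)
        # input reduces to shape (3,), and np.expand_dims(res, 1) restores the
        # broadcastable (3, 1) result that the softmax/logsumexp reducers below expect.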
careduce_def_src = dedent(
f"""
def {careduce_fn_name}(x):
x_shape = x.shape
res_shape = {res_shape}
res = {res_creator}
"""
reduce_elemwise_fn_py = compile_function_src(
reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env}
)
return reduce_elemwise_fn_py
def create_multiaxis_reducer(
scalar_op,
identity,
axes,
ndim,
dtype,
input_name="input",
return_scalar=False,
):
r"""Construct a function that reduces multiple axes.
The functions generated by this function take the following form:
.. code-block:: python
def careduce_maximum(input):
axis_0_res = careduce_axes_fn_0(input)
axis_1_res = careduce_axes_fn_1(axis_0_res)
...
axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
return axis_N_res
The range 0-N is determined by the `axes` argument (i.e. the
axes to be reduced).
Parameters
==========
scalar_op:
The scalar :class:`Op` that performs the desired reduction.
identity:
The identity value for the reduction.
axes:
The axes to reduce.
ndim:
The number of dimensions of the result.
dtype:
The data type of the result.
return_scalar:
If True, return a scalar, otherwise an array.
Returns
=======
A Python function that can be JITed.
"""
if len(axes) == 1:
return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
axes = normalize_axis_tuple(axes, ndim)
careduce_fn_name = f"careduce_{scalar_op}"
global_env = {}
to_reduce = sorted(axes, reverse=True)
careduce_lines_src = []
var_name = input_name
for i, axis in enumerate(to_reduce):
careducer_axes_fn_name = f"careduce_axes_fn_{i}"
reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype)
reducer_fn = numba_basic.numba_njit(
boundscheck=False, fastmath=config.numba__fastmath
)(reducer_py_fn)
global_env[careducer_axes_fn_name] = reducer_fn
ndim -= 1
last_var_name = var_name
var_name = f"axis_{i}_res"
careduce_lines_src.append(
f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
for axis in range(ndim):
careduce_def_src += indent(
f"for i{axis} in range(x_shape[{axis}]):\n",
" " * (4 + 4 * axis),
)
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
if not return_scalar:
pre_result = "np.asarray"
post_result = ""
else:
pre_result = "np.asarray"
post_result = ".item()"
careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
return {pre_result}({var_name}){post_result}
"""
careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
careduce_def_src += "\n\n"
careduce_def_src += indent(f"return {return_obj}", " " * 4)
careduce_fn = compile_function_src(
careduce_def_src, careduce_fn_name, {**globals(), **global_env}
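A minimal usage sketch of the new reducer builder (not part of the diff; it assumes the signature shown above, reuses `add_as` from this module, and JITs with `numba_basic.numba_njit` directly instead of the `jit_compile_reducer` helper used by the actual dispatcher):

# Hedged sketch: build a sum reducer over axes (0, 1) of a 3-d float64 input
# and check it against NumPy.
careduce_py = create_multiaxis_reducer(
    add_as, 0.0, axes=(0, 1), ndim=3, dtype=np.dtype("float64")
)
careduce_fn = numba_basic.numba_njit(careduce_py)

x = np.random.normal(size=(4, 5, 6))
np.testing.assert_allclose(careduce_fn(x), x.sum(axis=(0, 1)))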
@@ -545,32 +432,29 @@ def numba_funcify_Elemwise(op, node, **kwargs):
@numba_funcify.register(Sum)
def numba_funcify_Sum(op, node, **kwargs):
ndim_input = node.inputs[0].ndim
axes = op.axis
if axes is None:
axes = list(range(node.inputs[0].ndim))
axes = tuple(axes)
ndim_input = node.inputs[0].ndim
else:
axes = normalize_axis_tuple(axes, ndim_input)
if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
acc_dtype = op.acc_dtype
else:
acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype)
out_dtype = np.dtype(node.outputs[0].dtype)
if ndim_input == len(axes):
@numba_njit(fastmath=True)
# Slightly faster than `numba_funcify_CAReduce` for this case
@numba_njit(fastmath=config.numba__fastmath)
def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
elif len(axes) == 0:
@numba_njit(fastmath=True)
# These cases should be removed by rewrites!
@numba_njit(fastmath=config.numba__fastmath)
def impl_sum(array):
return np.asarray(array, dtype=out_dtype)
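    # Dtype note (illustrative, not part of the diff): for float32 inputs PyTensor's
    # Sum typically accumulates in float64 (acc_dtype) and only casts back to the
    # output dtype at the end, which is why the full-reduction path above sums as
    # np_acc_dtype before calling .astype(out_dtype).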
@@ -603,7 +487,6 @@ def numba_funcify_CAReduce(op, node, **kwargs):
# Make sure it has the correct dtype
scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
careduce_py_fn = create_multiaxis_reducer(
op.scalar_op,
@@ -611,7 +494,6 @@ def numba_funcify_CAReduce(op, node, **kwargs):
axes,
ndim,
np.dtype(node.outputs[0].type.dtype),
input_name=input_name,
)
careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
@@ -724,11 +606,11 @@ def numba_funcify_Softmax(op, node, **kwargs):
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer(
reduce_max_py = create_multiaxis_reducer(
scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
)
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
)
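    # Rough NumPy equivalent of what the two keepdims=True reducers provide for the
    # softmax body (illustration only, not the generated Numba code):
    #   x_max = np.max(x, axis=axis, keepdims=True)      # length-1 axis is kept
    #   e_x = np.exp(x - x_max)                          # broadcasts against x
    #   sm = e_x / np.sum(e_x, axis=axis, keepdims=True)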
jit_fn = numba_basic.numba_njit(
@@ -761,8 +643,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
axis = op.axis
if axis is not None:
axis = normalize_axis_index(axis, sm_at.ndim)
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
)
jit_fn = numba_basic.numba_njit(
@@ -793,16 +675,16 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
if axis is not None:
axis = normalize_axis_index(axis, x_at.ndim)
reduce_max_py = create_axis_reducer(
reduce_max_py = create_multiaxis_reducer(
scalar_maximum,
-np.inf,
axis,
(axis,),
x_at.ndim,
x_dtype,
keepdims=True,
)
reduce_sum_py = create_axis_reducer(
add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
reduce_sum_py = create_multiaxis_reducer(
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
)
jit_fn = numba_basic.numba_njit(
@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
@@ -23,7 +23,7 @@ from tests.link.numba.test_basic import (
scalar_my_multi_out,
set_test_value,
)
from tests.tensor.test_elemwise import TestElemwise
from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester
rng = np.random.default_rng(42849)
@@ -249,12 +249,12 @@ def test_Dimshuffle_non_contiguous():
(
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
@@ -301,6 +301,24 @@ def test_Dimshuffle_non_contiguous():
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(), # Empty axes would normally be rewritten away, but we want to test it still works
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
pt.scalar(), np.array(99.0, dtype=config.floatX)
), # Scalar input would normally be rewritten away, but we want to test it still works
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
@@ -367,7 +385,7 @@ def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
fn, _ = compare_numba_and_py(
g_fg,
[
i.tag.test_value
@@ -375,6 +393,10 @@ def test_CAReduce(careduce_fn, axis, v):
if not isinstance(i, SharedVariable | Constant)
],
)
# Confirm CAReduce is in the compiled function
fn.dprint()
[node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, CAReduce)
def test_scalar_Elemwise_Clip():
@@ -619,10 +641,10 @@ def test_logsumexp_benchmark(size, axis, benchmark):
X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")
# JIT compile first
_ = X_lse_fn(X_val)
res = benchmark(X_lse_fn, X_val)
res = X_lse_fn(X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
benchmark(X_lse_fn, X_val)
def test_fused_elemwise_benchmark(benchmark):
@@ -653,3 +675,19 @@ def test_elemwise_out_type():
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
assert func(x_val).shape == (18,)
@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
)
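# Illustration (assumed invocation; --benchmark-only comes from pytest-benchmark):
# these benchmarks can be run in isolation with e.g.
#   pytest -k careduce_benchmark --benchmark-only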
@@ -983,27 +983,33 @@ class TestVectorize:
assert vect_node.inputs[0] is bool_tns
@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_careduce_benchmark(axis, c_contiguous, benchmark):
def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark):
N = 256
x_test = np.random.uniform(size=(N, N, N))
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
x = pytensor.shared(x_test, name="x", shape=x_test.shape)
out = x.transpose(transpose_axis).sum(axis=axis)
fn = pytensor.function([], out)
fn = pytensor.function([], out, mode=mode)
np.testing.assert_allclose(
fn(),
x_test.transpose(transpose_axis).sum(axis=axis),
)
benchmark(fn)
@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_c_careduce_benchmark(axis, c_contiguous, benchmark):
return careduce_benchmark_tester(
axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark
)