Commit 33998b20, authored by kc611, committed by Brandon T. Willard

Added more optimizations to the Numba cheap pass-manager

This only applies to reduction `Op`s (e.g. `CAReduce`).
上级 66760618
import operator
import warnings
from contextlib import contextmanager
from functools import singledispatch
import numba
......@@ -57,14 +58,31 @@ def numba_vectorize(*args, **kwargs):
def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False
aesara_type: Type,
layout: str = "A",
force_scalar: bool = False,
reduce_to_scalar: bool = False,
) -> numba.types.Type:
"""Create a Numba type object for a ``Type``."""
r"""Create a Numba type object for a :class:`Type`.
Parameters
----------
aesara_type
The :class:`Type` to convert.
layout
The :class:`numpy.ndarray` layout to use.
force_scalar
Ignore dimension information and return the corresponding Numba scalar types.
reduce_to_scalar
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""
if isinstance(aesara_type, TensorType):
dtype = aesara_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype)
if force_scalar:
if force_scalar or (
reduce_to_scalar and getattr(aesara_type, "ndim", None) == 0
):
return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar):
......@@ -75,15 +93,25 @@ def get_numba_type(
raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
def create_numba_signature(node: Apply, force_scalar: bool = False) -> numba.types.Type:
def create_numba_signature(
node: Apply, force_scalar: bool = False, reduce_to_scalar: bool = False
) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node."""
input_types = []
for inp in node.inputs:
input_types.append(get_numba_type(inp.type, force_scalar=force_scalar))
input_types.append(
get_numba_type(
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
output_types = []
for out in node.outputs:
output_types.append(get_numba_type(out.type, force_scalar=force_scalar))
output_types.append(
get_numba_type(
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types)
......@@ -263,6 +291,23 @@ def create_arg_string(x):
return args
@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
    """Temporarily replace the cheap optimization pass with a better one.

    Numba's ``cheap`` module pass manager runs with minimal LLVM
    optimization.  Inside this context it is swapped for one built with
    ``opt=3`` plus loop and SLP vectorization, and the original manager is
    restored on exit — even when the body raises.

    NOTE(review): this reaches into private Numba internals
    (``_internal_codegen``, ``_mpm_cheap``, ``_module_pass_manager``), so it
    is tied to the installed Numba version.
    """
    # Imported lazily so merely importing this module doesn't require the
    # target machinery to be initialized.
    from numba.core.registry import cpu_target

    codegen = cpu_target.target_context._internal_codegen
    previous_pm = codegen._mpm_cheap
    optimized_pm = codegen._module_pass_manager(
        loop_vectorize=True,
        slp_vectorize=True,
        opt=3,
        cost="cheap",
    )
    codegen._mpm_cheap = optimized_pm
    try:
        yield
    finally:
        # Always restore the original pass manager, even on error.
        codegen._mpm_cheap = previous_pm
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
    """Convert `data` into a Numba-compatible value.

    This is the `singledispatch` fallback: it returns `data` unchanged.
    Type-specific conversions are added via ``numba_typify.register``.
    """
    return data
......
......@@ -37,6 +37,7 @@ from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg
from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
......@@ -1049,94 +1050,132 @@ def test_ARange(start, stop, step, dtype):
@pytest.mark.parametrize(
"careduce_fn, axis, v, keepdims",
"careduce_fn, axis, v",
[
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
at.all,
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(0, 1),
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(1, 0),
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.sum,
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.prod,
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
),
(
at.prod,
lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.prod,
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
False,
),
(
at.max,
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
True,
),
],
)
def test_CAReduce(careduce_fn, axis, v, keepdims):
g = careduce_fn(v, axis=axis, keepdims=keepdims)
def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
......
Markdown 格式
0%
您正在将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论