Commit 33998b20, authored by kc611, committed by Brandon T. Willard

Added more optimizations to the Numba cheap pass-manager

This only applies to reduction `Op`s (e.g. `CAReduce`).
Parent commit: 66760618
import operator import operator
import warnings import warnings
from contextlib import contextmanager
from functools import singledispatch from functools import singledispatch
import numba import numba
...@@ -57,14 +58,31 @@ def numba_vectorize(*args, **kwargs): ...@@ -57,14 +58,31 @@ def numba_vectorize(*args, **kwargs):
def get_numba_type( def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False aesara_type: Type,
layout: str = "A",
force_scalar: bool = False,
reduce_to_scalar: bool = False,
) -> numba.types.Type: ) -> numba.types.Type:
"""Create a Numba type object for a ``Type``.""" r"""Create a Numba type object for a :class:`Type`.
Parameters
----------
aesara_type
The :class:`Type` to convert.
layout
The :class:`numpy.ndarray` layout to use.
force_scalar
Ignore dimension information and return the corresponding Numba scalar types.
reduce_to_scalar
Return Numba scalars for zero dimensional :class:`TensorType`\s.
"""
if isinstance(aesara_type, TensorType): if isinstance(aesara_type, TensorType):
dtype = aesara_type.numpy_dtype dtype = aesara_type.numpy_dtype
numba_dtype = numba.from_dtype(dtype) numba_dtype = numba.from_dtype(dtype)
if force_scalar: if force_scalar or (
reduce_to_scalar and getattr(aesara_type, "ndim", None) == 0
):
return numba_dtype return numba_dtype
return numba.types.Array(numba_dtype, aesara_type.ndim, layout) return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar): elif isinstance(aesara_type, Scalar):
...@@ -75,15 +93,25 @@ def get_numba_type( ...@@ -75,15 +93,25 @@ def get_numba_type(
raise NotImplementedError(f"Numba type not implemented for {aesara_type}") raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
def create_numba_signature(node: Apply, force_scalar: bool = False) -> numba.types.Type: def create_numba_signature(
node: Apply, force_scalar: bool = False, reduce_to_scalar: bool = False
) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node.""" """Create a Numba type for the signature of an ``Apply`` node."""
input_types = [] input_types = []
for inp in node.inputs: for inp in node.inputs:
input_types.append(get_numba_type(inp.type, force_scalar=force_scalar)) input_types.append(
get_numba_type(
inp.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
output_types = [] output_types = []
for out in node.outputs: for out in node.outputs:
output_types.append(get_numba_type(out.type, force_scalar=force_scalar)) output_types.append(
get_numba_type(
out.type, force_scalar=force_scalar, reduce_to_scalar=reduce_to_scalar
)
)
if len(output_types) > 1: if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types) return numba.types.Tuple(output_types)(*input_types)
...@@ -263,6 +291,23 @@ def create_arg_string(x): ...@@ -263,6 +291,23 @@ def create_arg_string(x):
return args return args
@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
    """Temporarily swap Numba's "cheap" module pass-manager for a stronger one.

    Inside the ``with`` block, the CPU target's cheap LLVM pass-manager is
    replaced by one built at opt-level 3 with loop and SLP vectorization
    enabled.  The original manager is restored on exit, even when the body
    raises.

    NOTE(review): this reaches into Numba private internals
    (``_internal_codegen``, ``_mpm_cheap``, ``_module_pass_manager``) and may
    break across Numba versions — confirm against the pinned Numba release.
    """
    from numba.core.registry import cpu_target

    codegen = cpu_target.target_context._internal_codegen
    previous_pm = codegen._mpm_cheap
    # Install a more aggressive replacement into the "cheap" slot.
    codegen._mpm_cheap = codegen._module_pass_manager(
        loop_vectorize=True,
        slp_vectorize=True,
        opt=3,
        cost="cheap",
    )
    try:
        yield
    finally:
        # Always put the original pass-manager back.
        codegen._mpm_cheap = previous_pm
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
    """Convert ``data`` to a Numba-compatible value.

    This is the generic :func:`functools.singledispatch` fallback: it returns
    ``data`` unchanged.  Type-specific conversions are registered elsewhere
    via ``numba_typify.register``; ``dtype`` and the extra keyword arguments
    exist for those registered overloads and are ignored here.
    """
    return data
......
...@@ -37,6 +37,7 @@ from aesara.tensor import elemwise as at_elemwise ...@@ -37,6 +37,7 @@ from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import extra_ops, nlinalg, slinalg from aesara.tensor import extra_ops, nlinalg, slinalg
from aesara.tensor import subtensor as at_subtensor from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...@@ -1049,94 +1050,132 @@ def test_ARange(start, stop, step, dtype): ...@@ -1049,94 +1050,132 @@ def test_ARange(start, stop, step, dtype):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"careduce_fn, axis, v, keepdims", "careduce_fn, axis, v",
[ [
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0, 0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
), ),
( (
at.all, lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0, 0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0, 0,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(0, 1), (0, 1),
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(1, 0), (1, 0),
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None, None,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.sum, lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1, 1,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.prod, lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0, 0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)), set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
False,
), ),
( (
at.prod, lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(at.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0, 0,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.prod, lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1, 1,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
False,
), ),
( (
at.max, lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None, None,
set_test_value( set_test_value(
at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2)) at.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
), ),
True,
), ),
], ],
) )
def test_CAReduce(careduce_fn, axis, v, keepdims): def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis, keepdims=keepdims) g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py( compare_numba_and_py(
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment.