提交 78c94889 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Numba conversion of N-ary Sum and Prod scalar Ops

上级 d03acf4e
...@@ -3,7 +3,7 @@ import warnings ...@@ -3,7 +3,7 @@ import warnings
from functools import reduce, singledispatch from functools import reduce, singledispatch
from numbers import Number from numbers import Number
from textwrap import indent from textwrap import indent
from typing import Union from typing import List, Union
import numba import numba
import numpy as np import numpy as np
...@@ -18,7 +18,7 @@ from numba.np.unsafe.ndarray import to_fixed_tuple ...@@ -18,7 +18,7 @@ from numba.np.unsafe.ndarray import to_fixed_tuple
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
from aesara.compile.ops import DeepCopyOp, ViewOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.utils import ( from aesara.link.utils import (
...@@ -28,10 +28,12 @@ from aesara.link.utils import ( ...@@ -28,10 +28,12 @@ from aesara.link.utils import (
unique_name_generator, unique_name_generator,
) )
from aesara.scalar.basic import ( from aesara.scalar.basic import (
Add,
Cast, Cast,
Clip, Clip,
Composite, Composite,
Identity, Identity,
Mul,
Scalar, Scalar,
ScalarOp, ScalarOp,
Second, Second,
...@@ -381,6 +383,42 @@ def numba_funcify_Switch(op, node, **kwargs): ...@@ -381,6 +383,42 @@ def numba_funcify_Switch(op, node, **kwargs):
return switch return switch
def binary_to_nary_func(inputs: List[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function."""
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
input_names = [unique_names(v, force_unique=True) for v in inputs]
input_signature = ", ".join(input_names)
output_expr = binary_op.join(input_names)
nary_src = f"""
def {binary_op_name}({input_signature}):
return {output_expr}
"""
nary_fn = compile_function_src(nary_src, binary_op_name)
return nary_fn
@numba_funcify.register(Add)
def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")
return numba.njit(signature)(nary_add_fn)
@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")
return numba.njit(signature)(nary_mul_fn)
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs): def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs) scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
......
...@@ -301,6 +301,24 @@ def test_Elemwise(inputs, input_vals, output_fn): ...@@ -301,6 +301,24 @@ def test_Elemwise(inputs, input_vals, output_fn):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inputs, input_values, scalar_fn", "inputs, input_values, scalar_fn",
[ [
(
[aet.scalar("x"), aet.scalar("y"), aet.scalar("z")],
[
np.array(10, dtype=config.floatX),
np.array(20, dtype=config.floatX),
np.array(30, dtype=config.floatX),
],
lambda x, y, z: aes.add(x, y, z),
),
(
[aet.scalar("x"), aet.scalar("y"), aet.scalar("z")],
[
np.array(10, dtype=config.floatX),
np.array(20, dtype=config.floatX),
np.array(30, dtype=config.floatX),
],
lambda x, y, z: aes.mul(x, y, z),
),
( (
[aet.scalar("x"), aet.scalar("y")], [aet.scalar("x"), aet.scalar("y")],
[ [
...@@ -312,9 +330,8 @@ def test_Elemwise(inputs, input_vals, output_fn): ...@@ -312,9 +330,8 @@ def test_Elemwise(inputs, input_vals, output_fn):
], ],
) )
def test_numba_Composite(inputs, input_values, scalar_fn): def test_numba_Composite(inputs, input_values, scalar_fn):
x_s = aes.float64("x") composite_inputs = [aes.float64(i.name) for i in inputs]
y_s = aes.float64("y") comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)]))
comp_op = Elemwise(Composite([x_s, y_s], [scalar_fn(x_s, y_s)]))
out_fg = FunctionGraph(inputs, [comp_op(*inputs)]) out_fg = FunctionGraph(inputs, [comp_op(*inputs)])
compare_numba_and_py(out_fg, input_values) compare_numba_and_py(out_fg, input_values)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论