提交 8335bbfe authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Aesara-to-Numba signature conversion functions

上级 d7a8c825
......@@ -13,10 +13,19 @@ from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.link.utils import compile_function_src, fgraph_to_python
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.basic import (
Cast,
Clip,
Composite,
Identity,
Scalar,
ScalarOp,
Second,
)
from aesara.tensor.basic import (
Alloc,
AllocDiag,
......@@ -37,9 +46,45 @@ from aesara.tensor.subtensor import (
IncSubtensor,
Subtensor,
)
from aesara.tensor.type import TensorType
from aesara.tensor.type_other import MakeSlice
def get_numba_type(
aesara_type: Type, layout: str = "A", force_scalar: bool = False
) -> numba.types.Type:
"""Create a Numba type object for a ``Type``."""
if isinstance(aesara_type, TensorType) and not force_scalar:
dtype = aesara_type.numpy_dtype
numba_dtype = numba.np.numpy_support.from_dtype(dtype)
return numba.types.Array(numba_dtype, aesara_type.ndim, layout)
elif isinstance(aesara_type, Scalar) or force_scalar:
dtype = np.dtype(aesara_type.dtype)
numba_dtype = numba.np.numpy_support.from_dtype(dtype)
return numba_dtype
else:
raise NotImplementedError(f"Numba type not implemented for {aesara_type}")
def create_numba_signature(node: Apply, force_scalar: bool = False) -> numba.types.Type:
"""Create a Numba type for the signature of an ``Apply`` node."""
input_types = []
for inp in node.inputs:
input_types.append(get_numba_type(inp.type, force_scalar=force_scalar))
output_types = []
for out in node.outputs:
output_types.append(get_numba_type(out.type, force_scalar=force_scalar))
if len(output_types) > 1:
return numba.types.Tuple(output_types)(*input_types)
elif len(output_types) == 1:
return output_types[0](*input_types)
else:
return numba.types.void(*input_types)
def slice_new(self, start, stop, step):
fnty = llvm_Type.function(self.pyobj, [self.pyobj, self.pyobj, self.pyobj])
fn = self._get_function(fnty, name="PySlice_New")
......
import contextlib
from unittest import mock
import numba
import numpy as np
import pytest
......@@ -13,9 +14,12 @@ from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.compile.ops import ViewOp
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant
from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.optdb import Query
from aesara.graph.type import Type
from aesara.link.numba.dispatch import create_numba_signature, get_numba_type
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.tensor import elemwise as aet_elemwise
......@@ -24,6 +28,22 @@ from aesara.tensor.elemwise import Elemwise
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
class MyType(Type):
def filter(self, data):
return data
def __eq__(self, other):
return isinstance(other, MyType)
def __hash__(self):
return hash(MyType)
class MyOp(Op):
def perform(self, *args):
pass
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
......@@ -126,6 +146,71 @@ def compare_numba_and_py(
return numba_res
@pytest.mark.parametrize(
"v, expected, force_scalar, not_implemented",
[
(MyType(), None, False, True),
(aes.float32, numba.types.float32, False, False),
(aet.fscalar, numba.types.Array(numba.types.float32, 0, "A"), False, False),
(aet.fscalar, numba.types.float32, True, False),
(aet.lvector, numba.types.int64[:], False, False),
(aet.dmatrix, numba.types.float64[:, :], False, False),
(aet.dmatrix, numba.types.float64, True, False),
],
)
def test_get_numba_type(v, expected, force_scalar, not_implemented):
cm = (
contextlib.suppress()
if not not_implemented
else pytest.raises(NotImplementedError)
)
with cm:
res = get_numba_type(v, force_scalar=force_scalar)
assert res == expected
@pytest.mark.parametrize(
"v, expected, force_scalar",
[
(Apply(MyOp(), [], []), numba.types.void(), False),
(Apply(MyOp(), [], []), numba.types.void(), True),
(
Apply(MyOp(), [aet.lvector()], []),
numba.types.void(numba.types.int64[:]),
False,
),
(Apply(MyOp(), [aet.lvector()], []), numba.types.void(numba.types.int64), True),
(
Apply(MyOp(), [aet.dmatrix(), aes.float32()], [aet.dmatrix()]),
numba.types.float64[:, :](numba.types.float64[:, :], numba.types.float32),
False,
),
(
Apply(MyOp(), [aet.dmatrix(), aes.float32()], [aet.dmatrix()]),
numba.types.float64(numba.types.float64, numba.types.float32),
True,
),
(
Apply(MyOp(), [aet.dmatrix(), aes.float32()], [aet.dmatrix(), aes.int32()]),
numba.types.Tuple([numba.types.float64[:, :], numba.types.int32])(
numba.types.float64[:, :], numba.types.float32
),
False,
),
(
Apply(MyOp(), [aet.dmatrix(), aes.float32()], [aet.dmatrix(), aes.int32()]),
numba.types.Tuple([numba.types.float64, numba.types.int32])(
numba.types.float64, numba.types.float32
),
True,
),
],
)
def test_create_numba_signature(v, expected, force_scalar):
res = create_numba_signature(v, force_scalar=force_scalar)
assert res == expected
@pytest.mark.parametrize(
"inputs, input_vals, output_fn",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论