提交 859b4240 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a NUMBA compilation mode

上级 fc7922c0
...@@ -22,6 +22,7 @@ from aesara.compile.mode import ( ...@@ -22,6 +22,7 @@ from aesara.compile.mode import (
FAST_COMPILE, FAST_COMPILE,
FAST_RUN, FAST_RUN,
JAX, JAX,
NUMBA,
OPT_FAST_COMPILE, OPT_FAST_COMPILE,
OPT_FAST_RUN, OPT_FAST_RUN,
OPT_FAST_RUN_STABLE, OPT_FAST_RUN_STABLE,
......
...@@ -20,7 +20,8 @@ from aesara.graph.opt import ( ...@@ -20,7 +20,8 @@ from aesara.graph.opt import (
from aesara.graph.optdb import EquilibriumDB, LocalGroupDB, Query, SequenceDB, TopoDB from aesara.graph.optdb import EquilibriumDB, LocalGroupDB, Query, SequenceDB, TopoDB
from aesara.link.basic import PerformLinker from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.link.jax import JAXLinker from aesara.link.jax.linker import JAXLinker
from aesara.link.numba.linker import NumbaLinker
from aesara.link.vm import VMLinker from aesara.link.vm import VMLinker
...@@ -40,6 +41,7 @@ predefined_linkers = { ...@@ -40,6 +41,7 @@ predefined_linkers = {
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False), "vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True), "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(), "jax": JAXLinker(),
"numba": NumbaLinker(),
} }
...@@ -420,12 +422,16 @@ else: ...@@ -420,12 +422,16 @@ else:
FAST_RUN = Mode("vm", "fast_run") FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(JAXLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])) JAX = Mode(JAXLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]))
NUMBA = Mode(
NumbaLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
)
predefined_modes = { predefined_modes = {
"FAST_COMPILE": FAST_COMPILE, "FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN, "FAST_RUN": FAST_RUN,
"JAX": JAX, "JAX": JAX,
"NUMBA": NUMBA,
} }
instantiated_default_mode = None instantiated_default_mode = None
......
...@@ -107,6 +107,7 @@ def _filter_mode(val): ...@@ -107,6 +107,7 @@ def _filter_mode(val):
"FAST_COMPILE", "FAST_COMPILE",
"DEBUG_MODE", "DEBUG_MODE",
"JAX", "JAX",
"NUMBA",
] ]
if val in str_options: if val in str_options:
return val return val
......
from aesara.link.numba.linker import NumbaLinker
import numba
from aesara.link.basic import JITLinker from aesara.link.basic import JITLinker
...@@ -12,6 +10,8 @@ class NumbaLinker(JITLinker): ...@@ -12,6 +10,8 @@ class NumbaLinker(JITLinker):
return numba_funcify(fgraph, **kwargs) return numba_funcify(fgraph, **kwargs)
def jit_compile(self, fn): def jit_compile(self, fn):
import numba
jitted_fn = numba.njit(fn) jitted_fn = numba.njit(fn)
return jitted_fn return jitted_fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论