提交 859b4240 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Create a NUMBA compilation mode

上级 fc7922c0
......@@ -22,6 +22,7 @@ from aesara.compile.mode import (
FAST_COMPILE,
FAST_RUN,
JAX,
NUMBA,
OPT_FAST_COMPILE,
OPT_FAST_RUN,
OPT_FAST_RUN_STABLE,
......
......@@ -20,7 +20,8 @@ from aesara.graph.opt import (
from aesara.graph.optdb import EquilibriumDB, LocalGroupDB, Query, SequenceDB, TopoDB
from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.link.jax import JAXLinker
from aesara.link.jax.linker import JAXLinker
from aesara.link.numba.linker import NumbaLinker
from aesara.link.vm import VMLinker
......@@ -40,6 +41,7 @@ predefined_linkers = {
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
"numba": NumbaLinker(),
}
......@@ -420,12 +422,16 @@ else:
FAST_RUN = Mode("vm", "fast_run")
JAX = Mode(JAXLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]))
NUMBA = Mode(
NumbaLinker(), Query(include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
)
predefined_modes = {
"FAST_COMPILE": FAST_COMPILE,
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
}
instantiated_default_mode = None
......
......@@ -107,6 +107,7 @@ def _filter_mode(val):
"FAST_COMPILE",
"DEBUG_MODE",
"JAX",
"NUMBA",
]
if val in str_options:
return val
......
from aesara.link.numba.linker import NumbaLinker
import numba
from aesara.link.basic import JITLinker
......@@ -12,6 +10,8 @@ class NumbaLinker(JITLinker):
return numba_funcify(fgraph, **kwargs)
def jit_compile(self, fn):
import numba
jitted_fn = numba.njit(fn)
return jitted_fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论