提交 3b552872 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Generalize Numba's NumPy and SciPy ScalarOp dispatches

上级 fe9f2580
import ast import ast
from functools import singledispatch from functools import reduce, singledispatch
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import numba import numba
import numpy as np import numpy as np
import scipy
import scipy.special
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -51,16 +53,24 @@ def numba_funcify_FunctionGraph( ...@@ -51,16 +53,24 @@ def numba_funcify_FunctionGraph(
@numba_funcify.register(ScalarOp) @numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, **kwargs): def numba_funcify_ScalarOp(op, **kwargs):
numpy_func = getattr(np, op.nfunc_spec[0]) scalar_func_name = op.nfunc_spec[0]
if scalar_func_name.startswith("scipy."):
func_package = scipy
scalar_func_name = scalar_func_name.split(".", 1)[-1]
else:
func_package = np
if "." in scalar_func_name:
scalar_func = reduce(getattr, [scipy] + scalar_func_name.split("."))
else:
scalar_func = getattr(func_package, scalar_func_name)
@numba.njit @numba.njit
def scalar_func(*args): def scalar_op(*args):
result = args[0] return scalar_func(*args)
for arg in args[1:]:
result = numpy_func(arg, result)
return result
return scalar_func return scalar_op
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论