Commit aff7183b authored by kc611, committed by Brandon T. Willard

Added NumbaLinker

Parent e212ff03
from functools import singledispatch
import numba
from aesara.graph.fg import FunctionGraph
from aesara.link.utils import fgraph_to_python
from aesara.scalar.basic import Add, Composite, Mul
from aesara.tensor.elemwise import Elemwise
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
    """Convert `data` into a form usable by Numba-compiled functions.

    This default implementation returns ``data`` unchanged; type-specific
    conversions are added via ``numba_typify.register``.  The ``dtype``
    argument is unused here but kept for registered overloads.
    """
    return data
@singledispatch
def numba_funcify(op, **kwargs):
    """Create a Numba compatible function from an Aesara `Op`.

    Raises
    ------
    NotImplementedError
        If no conversion has been registered for ``type(op)``.
    """
    raise NotImplementedError(f"No Numba conversion for the given `Op`: {op}")
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
    fgraph,
    order=None,
    input_storage=None,
    output_storage=None,
    storage_map=None,
    **kwargs,
):
    """Lower an entire `FunctionGraph` to a single Python function.

    Delegates to the generic graph-to-Python converter, using the
    ``numba_funcify``/``numba_typify`` dispatchers for per-`Op` conversion.
    """
    storage_args = (order, input_storage, output_storage, storage_map)
    return fgraph_to_python(
        fgraph,
        numba_funcify,
        numba_typify,
        *storage_args,
        fgraph_name="numba_funcified_fgraph",
        **kwargs,
    )
# TODO: Generalize Add and Mul
@numba_funcify.register(Add)
def numba_funcify_ScalarAdd(op, **kwargs):
    """Create a Numba-jitted binary addition for the scalar `Add` `Op`."""

    @numba.njit
    def add(x, y):
        # The previous version assigned a dead `result = 0` before
        # immediately overwriting it; returning the sum directly is
        # equivalent and clearer.
        return x + y

    return add
@numba_funcify.register(Mul)
def numba_funcify_ScalarMul(op, **kwargs):
    """Create a Numba-jitted ternary multiplication for the scalar `Mul` `Op`."""

    @numba.njit
    def mul(x, y, z):
        # Hard-coded to exactly three inputs — TODO: generalize to
        # arbitrary arity.
        return x * y * z

    return mul
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, **kwargs):
    """Convert an `Elemwise` `Op` by converting its underlying scalar `Op`.

    TODO: Vectorize the scalar implementation over array inputs.
    """
    return numba_funcify(op.scalar_op)
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, **kwargs):
    """Create a Numba-jitted function for a scalar `Composite` `Op`.

    The `Composite`'s internal `FunctionGraph` is lowered to Python,
    JIT-compiled, and wrapped so that only its first output is returned.

    The previous ``vectorize=True`` parameter was never used, so it has
    been removed; any caller still passing it is absorbed harmlessly by
    ``**kwargs``.
    """
    # The lowered graph function returns a tuple of outputs; only the
    # first output of the scalar `Composite` is used here.
    numba_impl = numba.njit(numba_funcify(op.fgraph))

    @numba.njit
    def composite(*args):
        return numba_impl(*args)[0]

    return composite
import numba
from aesara.link.basic import JITLinker
class NumbaLinker(JITLinker):
    """A `Linker` that JIT-compiles NumPy-based operations using Numba."""

    def fgraph_convert(
        self, fgraph, order, input_storage, output_storage, storage_map, **kwargs
    ):
        """Convert `fgraph` into a Numba-compatible Python function."""
        # NOTE(review): import deferred to call time — presumably to avoid
        # a circular import at module load; confirm before moving it.
        from aesara.link.numba.dispatch import numba_funcify

        return numba_funcify(
            fgraph, order, input_storage, output_storage, storage_map, **kwargs
        )

    def jit_compile(self, fn):
        """JIT-compile the converted graph function with Numba."""
        return numba.njit(fn)

    def create_thunk_inputs(self, storage_map):
        """Collect the storage cells for each of the graph's inputs."""
        # TODO: When RandomVariable conversion is implemented,
        # do RandomState typification over here.
        return [storage_map[n] for n in self.fgraph.inputs]
......@@ -13,6 +13,7 @@ sympy
versioneer
jax; python_version > '3.6'
jaxlib; python_version > '3.6'
numba
diff-cover
pre-commit
isort
......
import numpy as np
import aesara
import aesara.tensor as aet
from aesara.compile.mode import Mode
from aesara.graph.optdb import Query
from aesara.link.numba.linker import NumbaLinker
# from aesara.graph.fg import FunctionGraph
# Rewrite query: apply elemwise fusion, but exclude rewrites that need a
# C compiler ("cxx_only") or BLAS ("BlasOpt").
opts = Query(include=["fusion"], exclude=["cxx_only", "BlasOpt"])
# Mode compiling through the Numba linker vs. the reference Python mode.
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
def test_composite():
    """Check that the Numba backend matches the Python backend on a fused graph.

    ``(x + y)`` is reused so the "fusion" rewrite can produce a single
    `Composite` scalar `Op` inside an `Elemwise`.
    """
    y = aet.vector("y")
    x = aet.vector("x")
    z = (x + y) * (x + y) * y

    func = aesara.function([x, y], [z], mode=py_mode)
    numba_fn = aesara.function([x, y], [z], mode=numba_mode)

    # Seed the generator so the comparison is reproducible; the legacy
    # unseeded `np.random.randn` made the test nondeterministic.
    rng = np.random.default_rng(20349)
    x_val = rng.standard_normal(1000)
    y_val = rng.standard_normal(1000)

    res = func(x_val, y_val)  # result from Python-mode compilation
    numba_res = numba_fn(x_val, y_val)  # result from the Numba-linked graph
    assert np.array_equal(res, numba_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论