提交 67d7f140 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a generic Numba conversion function for Op.perform

上级 8f336639
import operator
import warnings
from functools import reduce, singledispatch
from textwrap import indent
......@@ -171,7 +172,32 @@ def numba_typify(data, dtype=None, **kwargs):
@singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
raise NotImplementedError(f"No Numba conversion for the given `Op`: {op}")
warnings.warn(
(f"Numba will use object mode to run {op}'s perform method"),
UserWarning,
)
n_outputs = len(node.outputs)
if n_outputs > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def perform(*inputs):
with numba.objmode(ret=ret_sig):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
outputs = tuple([o[0] for o in outputs])
if n_outputs == 1:
ret = outputs[0]
else:
ret = outputs
return ret
return perform
@numba_funcify.register(FunctionGraph)
......
......@@ -45,6 +45,26 @@ class MyOp(Op):
pass
class MySingleOut(Op):
def make_node(self, a, b):
return Apply(self, [a, b], [a.type()])
def perform(self, node, inputs, outputs):
res = (inputs[0] + inputs[1]).astype(inputs[0][0].dtype)
outputs[0][0] = res
class MyMultiOut(Op):
def make_node(self, a, b):
return Apply(self, [a, b], [a.type(), b.type()])
def perform(self, node, inputs, outputs):
res1 = 2 * inputs[0]
res2 = 2 * inputs[1]
outputs[0][0] = res1
outputs[1][0] = res2
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
......@@ -1082,3 +1102,49 @@ def test_Eye(n, m, k, dtype):
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"inputs, op, exc",
[
(
[
set_test_value(
aet.matrix(), np.random.random(size=(2, 3)).astype(config.floatX)
),
set_test_value(aet.lmatrix(), np.random.poisson(size=(2, 3))),
],
MySingleOut,
UserWarning,
),
(
[
set_test_value(
aet.matrix(), np.random.random(size=(2, 3)).astype(config.floatX)
),
set_test_value(aet.lmatrix(), np.random.poisson(size=(2, 3))),
],
MyMultiOut,
UserWarning,
),
],
)
def test_perform(inputs, op, exc):
g = op()(*inputs)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论