提交 67d7f140 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a generic Numba conversion function for Op.perform

上级 8f336639
import operator import operator
import warnings
from functools import reduce, singledispatch from functools import reduce, singledispatch
from textwrap import indent from textwrap import indent
...@@ -171,7 +172,32 @@ def numba_typify(data, dtype=None, **kwargs): ...@@ -171,7 +172,32 @@ def numba_typify(data, dtype=None, **kwargs):
@singledispatch @singledispatch
def numba_funcify(op, node=None, storage_map=None, **kwargs): def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`.""" """Create a Numba compatible function from an Aesara `Op`."""
raise NotImplementedError(f"No Numba conversion for the given `Op`: {op}")
warnings.warn(
(f"Numba will use object mode to run {op}'s perform method"),
UserWarning,
)
n_outputs = len(node.outputs)
if n_outputs > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def perform(*inputs):
with numba.objmode(ret=ret_sig):
outputs = [[None] for i in range(n_outputs)]
op.perform(node, inputs, outputs)
outputs = tuple([o[0] for o in outputs])
if n_outputs == 1:
ret = outputs[0]
else:
ret = outputs
return ret
return perform
@numba_funcify.register(FunctionGraph) @numba_funcify.register(FunctionGraph)
......
...@@ -45,6 +45,26 @@ class MyOp(Op): ...@@ -45,6 +45,26 @@ class MyOp(Op):
pass pass
class MySingleOut(Op):
def make_node(self, a, b):
return Apply(self, [a, b], [a.type()])
def perform(self, node, inputs, outputs):
res = (inputs[0] + inputs[1]).astype(inputs[0][0].dtype)
outputs[0][0] = res
class MyMultiOut(Op):
def make_node(self, a, b):
return Apply(self, [a, b], [a.type(), b.type()])
def perform(self, node, inputs, outputs):
res1 = 2 * inputs[0]
res2 = 2 * inputs[1]
outputs[0][0] = res1
outputs[1][0] = res2
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
...@@ -1082,3 +1102,49 @@ def test_Eye(n, m, k, dtype): ...@@ -1082,3 +1102,49 @@ def test_Eye(n, m, k, dtype):
if not isinstance(i, (SharedVariable, Constant)) if not isinstance(i, (SharedVariable, Constant))
], ],
) )
@pytest.mark.parametrize(
"inputs, op, exc",
[
(
[
set_test_value(
aet.matrix(), np.random.random(size=(2, 3)).astype(config.floatX)
),
set_test_value(aet.lmatrix(), np.random.poisson(size=(2, 3))),
],
MySingleOut,
UserWarning,
),
(
[
set_test_value(
aet.matrix(), np.random.random(size=(2, 3)).astype(config.floatX)
),
set_test_value(aet.lmatrix(), np.random.poisson(size=(2, 3))),
],
MyMultiOut,
UserWarning,
),
],
)
def test_perform(inputs, op, exc):
g = op()(*inputs)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论