提交 643c9734 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a Numba OpFromGraph implementation

上级 308969cf
......@@ -17,6 +17,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box
from aesara import config
from aesara.compile.builders import OpFromGraph
from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply, NoParams
from aesara.graph.fg import FunctionGraph
......@@ -374,6 +375,25 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
return perform
@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
@numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]
else:
@numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)
return opfromgraph
@numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph(
fgraph,
......
......@@ -12,6 +12,7 @@ import aesara.scalar.math as aesm
import aesara.tensor as at
import aesara.tensor.math as aem
from aesara import config, shared
from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.compile.ops import ViewOp
......@@ -1003,3 +1004,18 @@ def test_scalar_return_value_conversion():
mode=numba_mode,
)
assert isinstance(x_fn(1.0), np.ndarray)
def test_OpFromGraph():
x, y, z = at.matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论