提交 643c9734 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a Numba OpFromGraph implementation

上级 308969cf
...@@ -17,6 +17,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 ...@@ -17,6 +17,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box from numba.extending import box
from aesara import config from aesara import config
from aesara.compile.builders import OpFromGraph
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply, NoParams from aesara.graph.basic import Apply, NoParams
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -374,6 +375,25 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -374,6 +375,25 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
return perform return perform
@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
@numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)[0]
else:
@numba_njit
def opfromgraph(*inputs):
return fgraph_fn(*inputs)
return opfromgraph
@numba_funcify.register(FunctionGraph) @numba_funcify.register(FunctionGraph)
def numba_funcify_FunctionGraph( def numba_funcify_FunctionGraph(
fgraph, fgraph,
......
...@@ -12,6 +12,7 @@ import aesara.scalar.math as aesm ...@@ -12,6 +12,7 @@ import aesara.scalar.math as aesm
import aesara.tensor as at import aesara.tensor as at
import aesara.tensor.math as aem import aesara.tensor.math as aem
from aesara import config, shared from aesara import config, shared
from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp
...@@ -1003,3 +1004,18 @@ def test_scalar_return_value_conversion(): ...@@ -1003,3 +1004,18 @@ def test_scalar_return_value_conversion():
mode=numba_mode, mode=numba_mode,
) )
assert isinstance(x_fn(1.0), np.ndarray) assert isinstance(x_fn(1.0), np.ndarray)
def test_OpFromGraph():
x, y, z = at.matrices("xyz")
ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
o1, o2 = ofg_2(y, z)
out = ofg_1(x, o1) + o2
xv = np.ones((2, 2), dtype=config.floatX)
yv = np.ones((2, 2), dtype=config.floatX) * 3
zv = np.ones((2, 2), dtype=config.floatX) * 5
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论