提交 acd23ab6 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Generalize Numba conversion of Scalar Ops

上级 aff7183b
...@@ -4,7 +4,7 @@ import numba ...@@ -4,7 +4,7 @@ import numba
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python
from aesara.scalar.basic import Add, Composite, Mul from aesara.scalar.basic import Composite, ScalarOp
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
...@@ -41,26 +41,20 @@ def numba_funcify_FunctionGraph( ...@@ -41,26 +41,20 @@ def numba_funcify_FunctionGraph(
) )
# TODO: Generalize Add and Mul @numba_funcify.register(ScalarOp)
@numba_funcify.register(Add) def numba_funcify_ScalarOp(op, **kwargs):
def numba_funcify_ScalarAdd(op, **kwargs): import numpy as np
@numba.njit
def add(x, y):
result = 0
result = x + y
return result
return add
numpy_func = getattr(np, op.nfunc_spec[0])
@numba_funcify.register(Mul)
def numba_funcify_ScalarMul(op, **kwargs):
@numba.njit @numba.njit
def mul(x, y, z): def scalar_func(*args):
result = x * y * z result = args[0]
for arg in args[1:]:
result = numpy_func(arg, result)
return result return result
return mul return scalar_func
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
......
import numpy as np import numpy as np
import aesara import aesara
import aesara.scalar.basic as aes
import aesara.tensor as aet import aesara.tensor as aet
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.graph.optdb import Query from aesara.graph.optdb import Query
...@@ -15,7 +16,7 @@ numba_mode = Mode(NumbaLinker(), opts) ...@@ -15,7 +16,7 @@ numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
def test_composite(): def test_Composite():
y = aet.vector("y") y = aet.vector("y")
x = aet.vector("x") x = aet.vector("x")
...@@ -24,6 +25,10 @@ def test_composite(): ...@@ -24,6 +25,10 @@ def test_composite():
func = aesara.function([x, y], [z], mode=py_mode) func = aesara.function([x, y], [z], mode=py_mode)
numba_fn = aesara.function([x, y], [z], mode=numba_mode) numba_fn = aesara.function([x, y], [z], mode=numba_mode)
# Make sure the graph had a `Composite` `Op` in it
composite_op = numba_fn.maker.fgraph.outputs[0].owner.op.scalar_op
assert isinstance(composite_op, aes.Composite)
x_val = np.random.randn(1000) x_val = np.random.randn(1000)
y_val = np.random.randn(1000) y_val = np.random.randn(1000)
...@@ -32,18 +37,25 @@ def test_composite(): ...@@ -32,18 +37,25 @@ def test_composite():
assert np.array_equal(res, numba_res) assert np.array_equal(res, numba_res)
# y1 = aet.vector("y1") y1 = aet.vector("y1")
# x1 = aet.vector("x1") x1 = aet.vector("x1")
z = (x + y) * (x1 + y1) * y
# z = (x + y) * (x1 + y1) * y x1_val = np.random.randn(1000)
y1_val = np.random.randn(1000)
# x1_val = np.random.randn(1000) func = aesara.function([x, y, x1, y1], [z], mode=py_mode)
# y1_val = np.random.randn(1000) numba_fn = aesara.function([x, y, x1, y1], [z], mode=numba_mode)
# func = aesara.function([x, y, x1, y1], [z], mode=mode) composite_op = numba_fn.maker.fgraph.outputs[0].owner.op.scalar_op
# numba_fn = compile_graph(func.maker.fgraph, debug=True) assert isinstance(composite_op, aes.Composite)
# res = func(x_val, y_val, x1_val, y1_val) # Answer from python mode compilation of FunctionGraph res = func(
# numba_res = numba_fn(x_val, y_val,x1_val,y1_val) # Answer from Numba converted FunctionGraph x_val, y_val, x1_val, y1_val
) # Answer from python mode compilation of FunctionGraph
numba_res = numba_fn(
x_val, y_val, x1_val, y1_val
) # Answer from Numba converted FunctionGraph
# assert np.array_equal(res, numba_res) assert np.array_equal(res, numba_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论