提交 acd23ab6 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Generalize Numba conversion of Scalar Ops

上级 aff7183b
......@@ -4,7 +4,7 @@ import numba
from aesara.graph.fg import FunctionGraph
from aesara.link.utils import fgraph_to_python
from aesara.scalar.basic import Add, Composite, Mul
from aesara.scalar.basic import Composite, ScalarOp
from aesara.tensor.elemwise import Elemwise
......@@ -41,26 +41,20 @@ def numba_funcify_FunctionGraph(
)
# TODO: Generalize Add and Mul
@numba_funcify.register(Add)
def numba_funcify_ScalarAdd(op, **kwargs):
@numba.njit
def add(x, y):
result = 0
result = x + y
return result
return add
@numba_funcify.register(ScalarOp)
def numba_funcify_ScalarOp(op, **kwargs):
import numpy as np
numpy_func = getattr(np, op.nfunc_spec[0])
@numba_funcify.register(Mul)
def numba_funcify_ScalarMul(op, **kwargs):
@numba.njit
def mul(x, y, z):
result = x * y * z
def scalar_func(*args):
result = args[0]
for arg in args[1:]:
result = numpy_func(arg, result)
return result
return mul
return scalar_func
@numba_funcify.register(Elemwise)
......
import numpy as np
import aesara
import aesara.scalar.basic as aes
import aesara.tensor as aet
from aesara.compile.mode import Mode
from aesara.graph.optdb import Query
......@@ -15,7 +16,7 @@ numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
def test_composite():
def test_Composite():
y = aet.vector("y")
x = aet.vector("x")
......@@ -24,6 +25,10 @@ def test_composite():
func = aesara.function([x, y], [z], mode=py_mode)
numba_fn = aesara.function([x, y], [z], mode=numba_mode)
# Make sure the graph had a `Composite` `Op` in it
composite_op = numba_fn.maker.fgraph.outputs[0].owner.op.scalar_op
assert isinstance(composite_op, aes.Composite)
x_val = np.random.randn(1000)
y_val = np.random.randn(1000)
......@@ -32,18 +37,25 @@ def test_composite():
assert np.array_equal(res, numba_res)
# y1 = aet.vector("y1")
# x1 = aet.vector("x1")
y1 = aet.vector("y1")
x1 = aet.vector("x1")
z = (x + y) * (x1 + y1) * y
# z = (x + y) * (x1 + y1) * y
x1_val = np.random.randn(1000)
y1_val = np.random.randn(1000)
# x1_val = np.random.randn(1000)
# y1_val = np.random.randn(1000)
func = aesara.function([x, y, x1, y1], [z], mode=py_mode)
numba_fn = aesara.function([x, y, x1, y1], [z], mode=numba_mode)
# func = aesara.function([x, y, x1, y1], [z], mode=mode)
# numba_fn = compile_graph(func.maker.fgraph, debug=True)
composite_op = numba_fn.maker.fgraph.outputs[0].owner.op.scalar_op
assert isinstance(composite_op, aes.Composite)
# res = func(x_val, y_val, x1_val, y1_val) # Answer from python mode compilation of FunctionGraph
# numba_res = numba_fn(x_val, y_val,x1_val,y1_val) # Answer from Numba converted FunctionGraph
res = func(
x_val, y_val, x1_val, y1_val
) # Answer from python mode compilation of FunctionGraph
numba_res = numba_fn(
x_val, y_val, x1_val, y1_val
) # Answer from Numba converted FunctionGraph
# assert np.array_equal(res, numba_res)
assert np.array_equal(res, numba_res)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论