提交 fa40b9bc authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Add Numba support for IfElse Op

上级 c2e3dbb1
...@@ -16,6 +16,7 @@ from aesara.compile.ops import DeepCopyOp ...@@ -16,6 +16,7 @@ from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.ifelse import IfElse
from aesara.link.utils import ( from aesara.link.utils import (
compile_function_src, compile_function_src,
fgraph_to_python, fgraph_to_python,
...@@ -682,3 +683,32 @@ def numba_funcify_BatchedDot(op, node, **kwargs): ...@@ -682,3 +683,32 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because # NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM # they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba # optimizations are apparently already performed by Numba
@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
    """Create a Numba implementation for an `IfElse` `Op`.

    The returned jitted function takes the condition followed by the
    ``n_outs`` true-branch values and then the ``n_outs`` false-branch
    values, and selects one half of ``args`` based on the condition.
    """
    n_outs = op.n_outs

    if n_outs > 1:
        # Multiple outputs: return the selected half of ``args`` as a tuple.
        @numba.njit
        def ifelse(cond, *args):
            if cond:
                branch = args[:n_outs]
            else:
                branch = args[n_outs:]
            return branch

    else:
        # Single output: unwrap the one-element tuple slice.
        @numba.njit
        def ifelse(cond, *args):
            if cond:
                branch = args[:n_outs]
            else:
                branch = args[n_outs:]
            return branch[0]

    return ifelse
...@@ -22,9 +22,10 @@ from aesara.compile.ops import ViewOp, deep_copy_op ...@@ -22,9 +22,10 @@ from aesara.compile.ops import ViewOp, deep_copy_op
from aesara.compile.sharedvalue import SharedVariable from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.ifelse import ifelse
from aesara.link.numba.dispatch import basic as numba_basic from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
...@@ -3076,3 +3077,63 @@ def test_scan_while(): ...@@ -3076,3 +3077,63 @@ def test_scan_while():
np.array(45).astype(config.floatX), np.array(45).astype(config.floatX),
] ]
compare_numba_and_py(out_fg, test_input_vals) compare_numba_and_py(out_fg, test_input_vals)
@pytest.mark.parametrize(
    "inputs, cond_fn, true_fn, false_fn",
    [
        # No inputs; constant condition and constant branch values.
        (
            [],
            lambda: np.array(True),
            lambda: np.r_[1, 2, 3],
            lambda: np.r_[-1, -2, -3],
        ),
        # Scalar condition on a single input; constant branch values.
        (
            [set_test_value(aet.dscalar(), np.array(0.2, dtype=np.float64))],
            lambda x: x < 0.5,
            lambda x: np.r_[1, 2, 3],
            lambda x: np.r_[-1, -2, -3],
        ),
        # Scalar inputs selected directly by the condition.
        (
            [
                set_test_value(aet.dscalar(), np.array(0.3, dtype=np.float64)),
                set_test_value(aet.dscalar(), np.array(0.5, dtype=np.float64)),
            ],
            lambda x, y: x > y,
            lambda x, y: x,
            lambda x, y: y,
        ),
        # Vector inputs with an element-wise `all` condition (False here).
        (
            [
                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: aet.all(x > y),
            lambda x, y: x,
            lambda x, y: y,
        ),
        # Multiple outputs; condition is False for these test values.
        (
            [
                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: aet.all(x > y),
            lambda x, y: [x, 2 * x],
            lambda x, y: [y, 3 * y],
        ),
        # Multiple outputs; condition is True for these test values.
        (
            [
                set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
                set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
            ],
            lambda x, y: aet.all(x > y),
            lambda x, y: [x, 2 * x],
            lambda x, y: [y, 3 * y],
        ),
    ],
)
def test_numba_ifelse(inputs, cond_fn, true_fn, false_fn):
    """Compile an `IfElse` graph with Numba and compare against Python mode.

    The branch values are built from the case's own `inputs` variables via
    `true_fn`/`false_fn`, evaluated inside the test.  The original version
    referenced free names ``x``/``y`` directly in the parametrize list,
    which are undefined at decoration time (and, even if bound at module
    level, would not be among ``inputs`` and so would make
    `FunctionGraph(inputs, out)` fail with missing inputs).
    """
    out = ifelse(cond_fn(*inputs), true_fn(*inputs), false_fn(*inputs))
    if not isinstance(out, list):
        out = [out]
    out_fg = FunctionGraph(inputs, out)
    compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论