提交 fa40b9bc authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Add Numba support for IfElse Op

上级 c2e3dbb1
......@@ -16,6 +16,7 @@ from aesara.compile.ops import DeepCopyOp
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.ifelse import IfElse
from aesara.link.utils import (
compile_function_src,
fgraph_to_python,
......@@ -682,3 +683,32 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
if n_outs > 1:
@numba.njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]
return res
else:
@numba.njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]
return res[0]
return ifelse
......@@ -22,9 +22,10 @@ from aesara.compile.ops import ViewOp, deep_copy_op
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.op import Op, get_test_value
from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
from aesara.ifelse import ifelse
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
......@@ -3076,3 +3077,63 @@ def test_scan_while():
np.array(45).astype(config.floatX),
]
compare_numba_and_py(out_fg, test_input_vals)
@pytest.mark.parametrize(
"inputs, cond_fn, true_vals, false_vals",
[
([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]),
(
[set_test_value(aet.dscalar(), np.array(0.2, dtype=np.float64))],
lambda x: x < 0.5,
np.r_[1, 2, 3],
np.r_[-1, -2, -3],
),
(
[
set_test_value(aet.dscalar(), np.array(0.3, dtype=np.float64)),
set_test_value(aet.dscalar(), np.array(0.5, dtype=np.float64)),
],
lambda x, y: x > y,
x,
y,
),
(
[
set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
],
lambda x, y: aet.all(x > y),
x,
y,
),
(
[
set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
],
lambda x, y: aet.all(x > y),
[x, 2 * x],
[y, 3 * y],
),
(
[
set_test_value(aet.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
set_test_value(aet.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
],
lambda x, y: aet.all(x > y),
[x, 2 * x],
[y, 3 * y],
),
],
)
def test_numba_ifelse(inputs, cond_fn, true_vals, false_vals):
out = ifelse(cond_fn(*inputs), true_vals, false_vals)
if not isinstance(out, list):
out = [out]
out_fg = FunctionGraph(inputs, out)
compare_numba_and_py(out_fg, [get_test_value(i) for i in out_fg.inputs])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论