提交 b6494726 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Numba conversion of Switch

上级 7a6c42e6
......@@ -35,6 +35,7 @@ from aesara.scalar.basic import (
Scalar,
ScalarOp,
Second,
Switch,
)
from aesara.scalar.math import Softplus
from aesara.tensor.basic import (
......@@ -371,6 +372,15 @@ def {scalar_op_fn_name}({input_names}):
return numba.njit(scalar_op_fn)
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba.njit
def switch(condition, x, y):
return x if np.all(condition) else y
return switch
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
......
......@@ -285,6 +285,12 @@ def test_create_numba_signature(v, expected, force_scalar):
[np.random.randn(100).astype(config.floatX) for i in range(4)],
lambda x, y, x1, y1: (x + y) * (x1 + y1) * y,
),
(
# This also tests the use of repeated arguments
[aet.matrix(), aet.scalar()],
[np.random.normal(size=(2, 2)).astype(config.floatX), 0.0],
lambda a, b: aet.switch(a, b, a),
),
],
)
def test_Elemwise(inputs, input_vals, output_fn):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论