提交 b6494726 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix Numba conversion of Switch

上级 7a6c42e6
...@@ -35,6 +35,7 @@ from aesara.scalar.basic import ( ...@@ -35,6 +35,7 @@ from aesara.scalar.basic import (
Scalar, Scalar,
ScalarOp, ScalarOp,
Second, Second,
Switch,
) )
from aesara.scalar.math import Softplus from aesara.scalar.math import Softplus
from aesara.tensor.basic import ( from aesara.tensor.basic import (
...@@ -371,6 +372,15 @@ def {scalar_op_fn_name}({input_names}): ...@@ -371,6 +372,15 @@ def {scalar_op_fn_name}({input_names}):
return numba.njit(scalar_op_fn) return numba.njit(scalar_op_fn)
@numba_funcify.register(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba.njit
def switch(condition, x, y):
return x if np.all(condition) else y
return switch
@numba_funcify.register(Elemwise) @numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs): def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs) scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
......
...@@ -285,6 +285,12 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -285,6 +285,12 @@ def test_create_numba_signature(v, expected, force_scalar):
[np.random.randn(100).astype(config.floatX) for i in range(4)], [np.random.randn(100).astype(config.floatX) for i in range(4)],
lambda x, y, x1, y1: (x + y) * (x1 + y1) * y, lambda x, y, x1, y1: (x + y) * (x1 + y1) * y,
), ),
(
# This also tests the use of repeated arguments
[aet.matrix(), aet.scalar()],
[np.random.normal(size=(2, 2)).astype(config.floatX), 0.0],
lambda a, b: aet.switch(a, b, a),
),
], ],
) )
def test_Elemwise(inputs, input_vals, output_fn): def test_Elemwise(inputs, input_vals, output_fn):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论