提交 3ef5ce79 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Add numba impl for CheckAndRaise

上级 0933d203
...@@ -19,6 +19,7 @@ from pytensor.tensor.extra_ops import ( ...@@ -19,6 +19,7 @@ from pytensor.tensor.extra_ops import (
Unique, Unique,
UnravelIndex, UnravelIndex,
) )
from aesara.raise_op import CheckAndRaise
@numba_funcify.register(Bartlett) @numba_funcify.register(Bartlett)
...@@ -372,3 +373,18 @@ def numba_funcify_BroadcastTo(op, node, **kwargs): ...@@ -372,3 +373,18 @@ def numba_funcify_BroadcastTo(op, node, **kwargs):
return np.broadcast_to(x, scalars_shape) return np.broadcast_to(x, scalars_shape)
return broadcast_to return broadcast_to
@numba_funcify.register(CheckAndRaise)
def numba_funcify_CheckAndRaise(op, node, **kwargs):
error = op.exc_type
msg = op.msg
@numba_basic.numba_njit
def check_and_raise(x, *conditions):
for cond in conditions:
if not cond:
raise error(msg)
return x
return check_and_raise
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论