提交 8763981c authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Brandon T. Willard

Add numba log1mexp erf erfc

上级 881e08bd
import math
from functools import reduce from functools import reduce
from typing import List from typing import List
...@@ -27,7 +28,7 @@ from aesara.scalar.basic import ( ...@@ -27,7 +28,7 @@ from aesara.scalar.basic import (
Second, Second,
Switch, Switch,
) )
from aesara.scalar.math import Sigmoid, GammaLn from aesara.scalar.math import Erf, Erfc, GammaLn, Log1mexp, Sigmoid
@numba_funcify.register(ScalarOp) @numba_funcify.register(ScalarOp)
...@@ -254,11 +255,39 @@ def numba_funcify_Sigmoid(op, node, **kwargs): ...@@ -254,11 +255,39 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
@numba_funcify.register(GammaLn) @numba_funcify.register(GammaLn)
def numba_funcify_Sigmoid(op, node, **kwargs): def numba_funcify_GammaLn(op, node, **kwargs):
import math
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath) @numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)
def gammaln(x): def gammaln(x):
return math.lgamma(x) return math.lgamma(x)
return gammaln return gammaln
@numba_funcify.register(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
else:
return np.log(-np.expm1(x))
return logp1mexp
@numba_funcify.register(Erf)
def numba_funcify_Erf(op, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)
def erf(x):
return math.erf(x)
return erf
@numba_funcify.register(Erfc)
def numba_funcify_Erfc(op, **kwargs):
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)
def erfc(x):
return math.erfc(x)
return erfc
...@@ -334,6 +334,24 @@ def test_box_unbox(input, wrapper_fn, check_fn): ...@@ -334,6 +334,24 @@ def test_box_unbox(input, wrapper_fn, check_fn):
lambda x: at.sigmoid(x), lambda x: at.sigmoid(x),
None, None,
), ),
(
[at.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: at.log1mexp(x),
None,
),
(
[at.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: at.erf(x),
None,
),
(
[at.vector()],
[rng.standard_normal(100).astype(config.floatX)],
lambda x: at.erfc(x),
None,
),
( (
[at.vector() for i in range(4)], [at.vector() for i in range(4)],
[rng.standard_normal(100).astype(config.floatX) for i in range(4)], [rng.standard_normal(100).astype(config.floatX) for i in range(4)],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论