提交 881e08bd authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Brandon T. Willard

Add numba implementation for gammaln

上级 77933bef
...@@ -27,7 +27,7 @@ from aesara.scalar.basic import ( ...@@ -27,7 +27,7 @@ from aesara.scalar.basic import (
Second, Second,
Switch, Switch,
) )
from aesara.scalar.math import Sigmoid from aesara.scalar.math import Sigmoid, GammaLn
@numba_funcify.register(ScalarOp) @numba_funcify.register(ScalarOp)
...@@ -251,3 +251,14 @@ def numba_funcify_Sigmoid(op, node, **kwargs): ...@@ -251,3 +251,14 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
return 1 / (1 + np.exp(-x)) return 1 / (1 + np.exp(-x))
return sigmoid return sigmoid
@numba_funcify.register(GammaLn)
def numba_funcify_Sigmoid(op, node, **kwargs):
import math
@numba_basic.numba_njit(inline="always", fastmath=config.numba__fastmath)
def gammaln(x):
return math.lgamma(x)
return gammaln
...@@ -323,8 +323,8 @@ def test_box_unbox(input, wrapper_fn, check_fn): ...@@ -323,8 +323,8 @@ def test_box_unbox(input, wrapper_fn, check_fn):
"inputs, input_vals, output_fn, exc", "inputs, input_vals, output_fn, exc",
[ [
( (
[at.lvector()], [at.vector()],
[rng.poisson(10, size=100).astype(np.int64)], [rng.uniform(size=100).astype(config.floatX)],
lambda x: at.gammaln(x), lambda x: at.gammaln(x),
None, None,
), ),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论