提交 c80dd0e6 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Add jax dispatch to Psi Op

上级 6157b651
...@@ -16,7 +16,7 @@ from aesara.ifelse import IfElse ...@@ -16,7 +16,7 @@ from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python
from aesara.scalar import Softplus from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.math import Erf, Erfc, Erfinv from aesara.scalar.math import Erf, Erfc, Erfinv, Psi
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import ScanArgs from aesara.scan.utils import ScanArgs
from aesara.tensor.basic import ( from aesara.tensor.basic import (
...@@ -1087,3 +1087,11 @@ def jax_funcify_Erfinv(op, **kwargs): ...@@ -1087,3 +1087,11 @@ def jax_funcify_Erfinv(op, **kwargs):
# def erfcinv(x): # def erfcinv(x):
# return jax.scipy.special.erfcinv(x) # return jax.scipy.special.erfcinv(x)
# return erfcinv # return erfcinv
@jax_funcify.register(Psi)
def jax_funcify_Psi(op, node, **kwargs):
def psi(x):
return jax.scipy.special.digamma(x)
return psi
...@@ -32,7 +32,7 @@ from aesara.tensor.math import MaxAndArgmax ...@@ -32,7 +32,7 @@ from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log
from aesara.tensor.math import max as aet_max from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod, sigmoid, softplus from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.basic import RandomVariable, normal from aesara.tensor.random.basic import RandomVariable, normal
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
...@@ -1274,3 +1274,10 @@ def test_erfinv(): ...@@ -1274,3 +1274,10 @@ def test_erfinv():
fg = FunctionGraph([x], [out]) fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0]) compare_jax_and_py(fg, [1.0])
def test_psi():
x = scalar("x")
out = psi(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论