提交 9653ade1 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement JAX dispatch for TriGamma

上级 a1fcb77c
......@@ -20,7 +20,17 @@ from pytensor.scalar.basic import (
Second,
Sub,
)
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi
from pytensor.scalar.math import (
Erf,
Erfc,
Erfcinv,
Erfcx,
Erfinv,
Iv,
Log1mexp,
Psi,
TriGamma,
)
def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
......@@ -275,6 +285,14 @@ def jax_funcify_Psi(op, node, **kwargs):
return psi
@jax_funcify.register(TriGamma)
def jax_funcify_TriGamma(op, node, **kwargs):
def tri_gamma(x):
return jax.scipy.special.polygamma(1, x)
return tri_gamma
@jax_funcify.register(Softplus)
def jax_funcify_Softplus(op, **kwargs):
def softplus(x):
......
......@@ -171,6 +171,13 @@ def test_psi():
compare_jax_and_py(fg, [3.0])
def test_tri_gamma():
x = vector("x", dtype="float64")
out = tri_gamma(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [np.array([3.0, 5.0])])
def test_log1mexp():
x = vector("x")
out = log1mexp(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论