提交 a7902a17 authored 作者: Kyle Caron's avatar Kyle Caron 提交者: Ricardo Vieira

jax implementation of log1mexp op

上级 2ccd9cca
......@@ -17,7 +17,7 @@ from aesara.link.utils import fgraph_to_python
from aesara.raise_op import CheckAndRaise
from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.math import Erf, Erfc, Erfinv, Psi
from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
from aesara.scan.op import Scan
from aesara.scan.utils import ScanArgs
from aesara.tensor.basic import (
......@@ -1119,6 +1119,16 @@ def jax_funcify_Erfc(op, **kwargs):
return erfc
@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
return jnp.where(
x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
)
return log1mexp
# Commented out because jax.scipy does not have erfcx,
# but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcx.
......
......@@ -32,7 +32,7 @@ from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as at_all
from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log
from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log, log1mexp
from aesara.tensor.math import max as at_max
from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus
from aesara.tensor.math import sum as at_sum
......@@ -1394,3 +1394,11 @@ def test_psi():
out = psi(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0])
def test_log1mexp():
x = vector("x")
out = log1mexp(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论