提交 a7902a17 authored 作者: Kyle Caron's avatar Kyle Caron 提交者: Ricardo Vieira

jax implementation of log1mexp op

上级 2ccd9cca
...@@ -17,7 +17,7 @@ from aesara.link.utils import fgraph_to_python ...@@ -17,7 +17,7 @@ from aesara.link.utils import fgraph_to_python
from aesara.raise_op import CheckAndRaise from aesara.raise_op import CheckAndRaise
from aesara.scalar import Softplus from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.math import Erf, Erfc, Erfinv, Psi from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import ScanArgs from aesara.scan.utils import ScanArgs
from aesara.tensor.basic import ( from aesara.tensor.basic import (
...@@ -1119,6 +1119,16 @@ def jax_funcify_Erfc(op, **kwargs): ...@@ -1119,6 +1119,16 @@ def jax_funcify_Erfc(op, **kwargs):
return erfc return erfc
@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
return jnp.where(
x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
)
return log1mexp
# Commented out because jax.scipy does not have erfcx, # Commented out because jax.scipy does not have erfcx,
# but leaving the implementation in here just in case we ever see # but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcx. # a JAX implementation of Erfcx.
......
...@@ -32,7 +32,7 @@ from aesara.tensor import subtensor as at_subtensor ...@@ -32,7 +32,7 @@ from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import MaxAndArgmax from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as at_all from aesara.tensor.math import all as at_all
from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log, log1mexp
from aesara.tensor.math import max as at_max from aesara.tensor.math import max as at_max
from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus from aesara.tensor.math import maximum, prod, psi, sigmoid, softplus
from aesara.tensor.math import sum as at_sum from aesara.tensor.math import sum as at_sum
...@@ -1394,3 +1394,11 @@ def test_psi(): ...@@ -1394,3 +1394,11 @@ def test_psi():
out = psi(x) out = psi(x)
fg = FunctionGraph([x], [out]) fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0]) compare_jax_and_py(fg, [3.0])
def test_log1mexp():
x = vector("x")
out = log1mexp(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论