提交 2dbeb781 authored 作者: Eric Ma's avatar Eric Ma 提交者: Ricardo Vieira

Add JAX implementations for Erf, Erfc, and Erfinv Ops

上级 afe290db
...@@ -16,6 +16,7 @@ from aesara.ifelse import IfElse ...@@ -16,6 +16,7 @@ from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python
from aesara.scalar import Softplus from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.math import Erf, Erfc, Erfinv
from aesara.scan.op import Scan from aesara.scan.op import Scan
from aesara.scan.utils import ScanArgs from aesara.scan.utils import ScanArgs
from aesara.tensor.basic import ( from aesara.tensor.basic import (
...@@ -1047,3 +1048,48 @@ def jax_funcify_RandomVariable(op, node, **kwargs): ...@@ -1047,3 +1048,48 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
return (rng, smpl_value) return (rng, smpl_value)
return random_variable return random_variable
@jax_funcify.register(Erf)
def jax_funcify_Erf(op, node, **kwargs):
def erf(x):
return jax.scipy.special.erf(x)
return erf
@jax_funcify.register(Erfc)
def jax_funcify_Erfc(op, **kwargs):
def erfc(x):
return jax.scipy.special.erfc(x)
return erfc
# Commented out because jax.scipy does not have erfcx,
# but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcx.
# See https://github.com/google/jax/issues/1987 for context.
# @jax_funcify.register(Erfcx)
# def jax_funcify_Erfcx(op, **kwargs):
# def erfcx(x):
# return jax.scipy.special.erfcx(x)
# return erfcx
@jax_funcify.register(Erfinv)
def jax_funcify_Erfinv(op, **kwargs):
def erfinv(x):
return jax.scipy.special.erfinv(x)
return erfinv
# Commented out because jax.scipy does not have Erfcinv,
# but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcinv.
# @jax_funcify.register(Erfcinv)
# def jax_funcify_Erfcinv(op, **kwargs):
# def erfcinv(x):
# return jax.scipy.special.erfcinv(x)
# return erfcinv
...@@ -30,7 +30,7 @@ from aesara.tensor import subtensor as aet_subtensor ...@@ -30,7 +30,7 @@ from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import MaxAndArgmax from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import all as aet_all from aesara.tensor.math import all as aet_all
from aesara.tensor.math import clip, cosh, gammaln, log from aesara.tensor.math import clip, cosh, erf, erfc, erfinv, gammaln, log
from aesara.tensor.math import max as aet_max from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod, sigmoid, softplus from aesara.tensor.math import maximum, prod, sigmoid, softplus
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
...@@ -1254,3 +1254,27 @@ def test_RandomStream(): ...@@ -1254,3 +1254,27 @@ def test_RandomStream():
jax_res_2 = fn() jax_res_2 = fn()
assert np.array_equal(jax_res_1, jax_res_2) assert np.array_equal(jax_res_1, jax_res_2)
def test_erf():
x = scalar("x")
out = erf(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
def test_erfc():
x = scalar("x")
out = erfc(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
def test_erfinv():
x = scalar("x")
out = erfinv(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论