Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
2dbeb781
提交
2dbeb781
authored
9月 25, 2021
作者:
Eric Ma
提交者:
Ricardo Vieira
10月 09, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add JAX implementations for Erf, Erfc, and Erfinv Ops
上级
afe290db
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
71 行增加
和
1 行删除
+71
-1
dispatch.py
aesara/link/jax/dispatch.py
+46
-0
test_jax.py
tests/link/test_jax.py
+25
-1
没有找到文件。
aesara/link/jax/dispatch.py
浏览文件 @
2dbeb781
...
...
@@ -16,6 +16,7 @@ from aesara.ifelse import IfElse
from
aesara.link.utils
import
fgraph_to_python
from
aesara.scalar
import
Softplus
from
aesara.scalar.basic
import
Cast
,
Clip
,
Composite
,
Identity
,
ScalarOp
,
Second
from
aesara.scalar.math
import
Erf
,
Erfc
,
Erfinv
from
aesara.scan.op
import
Scan
from
aesara.scan.utils
import
ScanArgs
from
aesara.tensor.basic
import
(
...
...
@@ -1047,3 +1048,48 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
return
(
rng
,
smpl_value
)
return
random_variable
@jax_funcify.register
(
Erf
)
def
jax_funcify_Erf
(
op
,
node
,
**
kwargs
):
def
erf
(
x
):
return
jax
.
scipy
.
special
.
erf
(
x
)
return
erf
@jax_funcify.register
(
Erfc
)
def
jax_funcify_Erfc
(
op
,
**
kwargs
):
def
erfc
(
x
):
return
jax
.
scipy
.
special
.
erfc
(
x
)
return
erfc
# Commented out because jax.scipy does not have erfcx,
# but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcx.
# See https://github.com/google/jax/issues/1987 for context.
# @jax_funcify.register(Erfcx)
# def jax_funcify_Erfcx(op, **kwargs):
# def erfcx(x):
# return jax.scipy.special.erfcx(x)
# return erfcx
@jax_funcify.register
(
Erfinv
)
def
jax_funcify_Erfinv
(
op
,
**
kwargs
):
def
erfinv
(
x
):
return
jax
.
scipy
.
special
.
erfinv
(
x
)
return
erfinv
# Commented out because jax.scipy does not have Erfcinv,
# but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcinv.
# @jax_funcify.register(Erfcinv)
# def jax_funcify_Erfcinv(op, **kwargs):
# def erfcinv(x):
# return jax.scipy.special.erfcinv(x)
# return erfcinv
tests/link/test_jax.py
浏览文件 @
2dbeb781
...
...
@@ -30,7 +30,7 @@ from aesara.tensor import subtensor as aet_subtensor
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.math
import
MaxAndArgmax
from
aesara.tensor.math
import
all
as
aet_all
from
aesara.tensor.math
import
clip
,
cosh
,
gammaln
,
log
from
aesara.tensor.math
import
clip
,
cosh
,
erf
,
erfc
,
erfinv
,
gammaln
,
log
from
aesara.tensor.math
import
max
as
aet_max
from
aesara.tensor.math
import
maximum
,
prod
,
sigmoid
,
softplus
from
aesara.tensor.math
import
sum
as
aet_sum
...
...
@@ -1254,3 +1254,27 @@ def test_RandomStream():
jax_res_2
=
fn
()
assert
np
.
array_equal
(
jax_res_1
,
jax_res_2
)
def
test_erf
():
x
=
scalar
(
"x"
)
out
=
erf
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
1.0
])
def
test_erfc
():
x
=
scalar
(
"x"
)
out
=
erfc
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
1.0
])
def
test_erfinv
():
x
=
scalar
(
"x"
)
out
=
erfinv
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
1.0
])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论