Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a7902a17
提交
a7902a17
authored
6月 14, 2022
作者:
Kyle Caron
提交者:
Ricardo Vieira
6月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
jax implementation of log1mexp op
上级
2ccd9cca
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
20 行增加
和
2 行删除
+20
-2
dispatch.py
aesara/link/jax/dispatch.py
+11
-1
test_jax.py
tests/link/test_jax.py
+9
-1
没有找到文件。
aesara/link/jax/dispatch.py
浏览文件 @
a7902a17
...
@@ -17,7 +17,7 @@ from aesara.link.utils import fgraph_to_python
...
@@ -17,7 +17,7 @@ from aesara.link.utils import fgraph_to_python
from
aesara.raise_op
import
CheckAndRaise
from
aesara.raise_op
import
CheckAndRaise
from
aesara.scalar
import
Softplus
from
aesara.scalar
import
Softplus
from
aesara.scalar.basic
import
Cast
,
Clip
,
Composite
,
Identity
,
ScalarOp
,
Second
from
aesara.scalar.basic
import
Cast
,
Clip
,
Composite
,
Identity
,
ScalarOp
,
Second
from
aesara.scalar.math
import
Erf
,
Erfc
,
Erfinv
,
Psi
from
aesara.scalar.math
import
Erf
,
Erfc
,
Erfinv
,
Log1mexp
,
Psi
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
from
aesara.scan.utils
import
ScanArgs
from
aesara.scan.utils
import
ScanArgs
from
aesara.tensor.basic
import
(
from
aesara.tensor.basic
import
(
...
@@ -1119,6 +1119,16 @@ def jax_funcify_Erfc(op, **kwargs):
...
@@ -1119,6 +1119,16 @@ def jax_funcify_Erfc(op, **kwargs):
return
erfc
return
erfc
@jax_funcify.register
(
Log1mexp
)
def
jax_funcify_Log1mexp
(
op
,
node
,
**
kwargs
):
def
log1mexp
(
x
):
return
jnp
.
where
(
x
<
jnp
.
log
(
0.5
),
jnp
.
log1p
(
-
jnp
.
exp
(
x
)),
jnp
.
log
(
-
jnp
.
expm1
(
x
))
)
return
log1mexp
# Commented out because jax.scipy does not have erfcx,
# Commented out because jax.scipy does not have erfcx,
# but leaving the implementation in here just in case we ever see
# but leaving the implementation in here just in case we ever see
# a JAX implementation of Erfcx.
# a JAX implementation of Erfcx.
...
...
tests/link/test_jax.py
浏览文件 @
a7902a17
...
@@ -32,7 +32,7 @@ from aesara.tensor import subtensor as at_subtensor
...
@@ -32,7 +32,7 @@ from aesara.tensor import subtensor as at_subtensor
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.elemwise
import
Elemwise
from
aesara.tensor.math
import
MaxAndArgmax
from
aesara.tensor.math
import
MaxAndArgmax
from
aesara.tensor.math
import
all
as
at_all
from
aesara.tensor.math
import
all
as
at_all
from
aesara.tensor.math
import
clip
,
cosh
,
erf
,
erfc
,
erfinv
,
gammaln
,
log
from
aesara.tensor.math
import
clip
,
cosh
,
erf
,
erfc
,
erfinv
,
gammaln
,
log
,
log1mexp
from
aesara.tensor.math
import
max
as
at_max
from
aesara.tensor.math
import
max
as
at_max
from
aesara.tensor.math
import
maximum
,
prod
,
psi
,
sigmoid
,
softplus
from
aesara.tensor.math
import
maximum
,
prod
,
psi
,
sigmoid
,
softplus
from
aesara.tensor.math
import
sum
as
at_sum
from
aesara.tensor.math
import
sum
as
at_sum
...
@@ -1394,3 +1394,11 @@ def test_psi():
...
@@ -1394,3 +1394,11 @@ def test_psi():
out
=
psi
(
x
)
out
=
psi
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
3.0
])
compare_jax_and_py
(
fg
,
[
3.0
])
def
test_log1mexp
():
x
=
vector
(
"x"
)
out
=
log1mexp
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[[
-
1.0
,
-
0.75
,
-
0.5
,
-
0.25
]])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论