Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
8968a387
提交
8968a387
authored
4月 05, 2022
作者:
Ricardo
提交者:
Ricardo Vieira
11月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Harmonize softplus implementations
上级
671a821d
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
31 行增加
和
35 行删除
+31
-35
scalar.py
pytensor/link/jax/dispatch/scalar.py
+11
-3
math.py
pytensor/scalar/math.py
+20
-32
没有找到文件。
pytensor/link/jax/dispatch/scalar.py
浏览文件 @
8968a387
...
@@ -125,10 +125,18 @@ def jax_funcify_Psi(op, node, **kwargs):
...
@@ -125,10 +125,18 @@ def jax_funcify_Psi(op, node, **kwargs):
@jax_funcify.register(Softplus)
def jax_funcify_Softplus(op, **kwargs):
    """Return the JAX implementation of the Softplus (log1pexp) scalar op."""

    def softplus(x):
        # Piecewise approximation of log(1 + exp(x)), numerically equivalent
        # to the PyTensor C implementation of this op:
        #   x < -37.0 -> exp(x)            (log1p would underflow to 0)
        #   x <  18.0 -> log1p(exp(x))     (safe direct evaluation)
        #   x <  33.3 -> x + exp(-x)       (avoids overflow of exp(x))
        #   otherwise -> x                 (exp(-x) is negligible)
        return jnp.where(
            x < -37.0,
            jnp.exp(x),
            jnp.where(
                x < 18.0,
                jnp.log1p(jnp.exp(x)),
                jnp.where(x < 33.3, x + jnp.exp(-x), x),
            ),
        )

    return softplus
pytensor/scalar/math.py
浏览文件 @
8968a387
...
@@ -6,6 +6,7 @@ As SciPy is not always available, we treat them separately.
...
@@ -6,6 +6,7 @@ As SciPy is not always available, we treat them separately.
import
os
import
os
import
warnings
import
warnings
from
textwrap
import
dedent
import
numpy
as
np
import
numpy
as
np
import
scipy.special
import
scipy.special
...
@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp):
...
@@ -1134,7 +1135,8 @@ class Softplus(UnaryScalarOp):
r"""
r"""
Compute log(1 + exp(x)), also known as softplus or log1pexp
Compute log(1 + exp(x)), also known as softplus or log1pexp
This function is numerically more stable than the naive approach.
This function is numerically faster than the naive approach, and does not overflow
for large values of x.
For details, see
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
...
@@ -1172,44 +1174,30 @@ class Softplus(UnaryScalarOp):
...
@@ -1172,44 +1174,30 @@ class Softplus(UnaryScalarOp):
def c_code(self, node, name, inp, out, sub):
    """Return C code evaluating softplus, ``z = log1p(exp(x))``, piecewise.

    The same breakpoints are used for every float precision, which may be
    suboptimal: the reference paper (Machler, 2012) only analyzed double
    precision. float16 reuses the float32 limits since the computation
    happens in float32 anyway.

    Raises
    ------
    NotImplementedError
        If the input is not a floating-point type.
    """
    (x,) = inp
    (z,) = out
    if node.inputs[0].type in float_types:
        if node.inputs[0].type == float64:
            return dedent(
                f"""
                {z} = (
                    {x} < -37.0 ? exp({x}) :
                    {x} < 18.0 ? log1p(exp({x})) :
                    {x} < 33.3 ? {x} + exp(-{x}) :
                    {x}
                );
                """
            )
        # Single-precision literals for float32 (and float16) inputs.
        return dedent(
            f"""
            {z} = (
                {x} < -37.0f ? exp({x}) :
                {x} < 18.0f ? log1p(exp({x})) :
                {x} < 33.3f ? {x} + exp(-{x}) :
                {x}
            );
            """
        )
    raise NotImplementedError("only floatingpoint is implemented")
...
@@ -1217,7 +1205,7 @@ class Softplus(UnaryScalarOp):
...
@@ -1217,7 +1205,7 @@ class Softplus(UnaryScalarOp):
def c_code_cache_version(self):
    """Extend the parent op's C-code cache version with this op's tag.

    The local tag was bumped to 3 when the C implementation switched to
    the harmonized piecewise softplus approximation.
    """
    parent_version = super().c_code_cache_version()
    if not parent_version:
        # The parent opted out of caching; propagate that decision.
        return parent_version
    return (3,) + parent_version
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论