Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f7d1bb3b
提交
f7d1bb3b
authored
6月 22, 2021
作者:
Ricardo
提交者:
Brandon T. Willard
7月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement `log1mexp` op
Closes #360
上级
16c2c5cf
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
91 行增加
和
0 行删除
+91
-0
math.py
aesara/scalar/math.py
+47
-0
inplace.py
aesara/tensor/inplace.py
+5
-0
math.py
aesara/tensor/math.py
+6
-0
test_math_scipy.py
tests/tensor/test_math_scipy.py
+33
-0
没有找到文件。
aesara/scalar/math.py
浏览文件 @
f7d1bb3b
...
@@ -20,6 +20,7 @@ from aesara.scalar.basic import (
...
@@ -20,6 +20,7 @@ from aesara.scalar.basic import (
exp
,
exp
,
float64
,
float64
,
float_types
,
float_types
,
true_div
,
upcast
,
upcast
,
upgrade_to_float
,
upgrade_to_float
,
upgrade_to_float64
,
upgrade_to_float64
,
...
@@ -997,3 +998,49 @@ class Softplus(UnaryScalarOp):
...
@@ -997,3 +998,49 @@ class Softplus(UnaryScalarOp):
softplus
=
Softplus
(
upgrade_to_float
,
name
=
"scalar_softplus"
)
softplus
=
Softplus
(
upgrade_to_float
,
name
=
"scalar_softplus"
)
class
Log1mexp
(
UnaryScalarOp
):
r"""
Compute log(1 - exp(x)), also known as log1mexp
This function is numerically more stable than the naive approach.
For details, see
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
References
----------
.. [Machler2012] Martin Mächler (2012).
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
"""
@staticmethod
def
static_impl
(
x
):
if
x
<
np
.
log
(
0.5
):
return
np
.
log1p
(
-
np
.
exp
(
x
))
else
:
return
np
.
log
(
-
np
.
expm1
(
x
))
def
impl
(
self
,
x
):
return
Log1mexp
.
static_impl
(
x
)
def
grad
(
self
,
inp
,
grads
):
(
x
,)
=
inp
(
gz
,)
=
grads
return
[
gz
*
true_div
(
1.0
,
1.0
-
exp
(
-
x
))]
def
c_code
(
self
,
node
,
name
,
inp
,
out
,
sub
):
(
x
,)
=
inp
(
z
,)
=
out
if
node
.
inputs
[
0
]
.
type
in
float_types
:
if
node
.
inputs
[
0
]
.
type
==
float64
:
return
f
"{z} = {x} < -0.6931471805599453 ? log1p(-exp({x})) : log(-expm1({x}));"
else
:
return
f
"{z} = {x} < -0.6931471805599453f ? log1p(-exp({x})) : log(-expm1({x}));"
else
:
raise
NotImplementedError
(
"only floating point is implemented"
)
log1mexp
=
Log1mexp
(
upgrade_to_float
,
name
=
"scalar_log1mexp"
)
aesara/tensor/inplace.py
浏览文件 @
f7d1bb3b
...
@@ -318,6 +318,11 @@ def softplus_inplace(x):
...
@@ -318,6 +318,11 @@ def softplus_inplace(x):
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""
"""Compute log(1 + exp(x)), also known as softplus or log1pexp"""
@scalar_elemwise
def
log1mexp_inplace
(
x
):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
@scalar_elemwise
def
second_inplace
(
a
):
def
second_inplace
(
a
):
"""Fill `a` with `b`"""
"""Fill `a` with `b`"""
...
...
aesara/tensor/math.py
浏览文件 @
f7d1bb3b
...
@@ -1424,6 +1424,11 @@ def softplus(x):
...
@@ -1424,6 +1424,11 @@ def softplus(x):
log1pexp
=
softplus
log1pexp
=
softplus
@scalar_elemwise
def
log1mexp
(
x
):
"""Compute log(1 - exp(x)), also known as log1mexp"""
@scalar_elemwise
@scalar_elemwise
def
real
(
z
):
def
real
(
z
):
"""Return real component of complex-valued tensor `z`"""
"""Return real component of complex-valued tensor `z`"""
...
@@ -2903,6 +2908,7 @@ __all__ = [
...
@@ -2903,6 +2908,7 @@ __all__ = [
"expit"
,
"expit"
,
"softplus"
,
"softplus"
,
"log1pexp"
,
"log1pexp"
,
"log1mexp"
,
"real"
,
"real"
,
"imag"
,
"imag"
,
"angle"
,
"angle"
,
...
...
tests/tensor/test_math_scipy.py
浏览文件 @
f7d1bb3b
...
@@ -567,6 +567,39 @@ class TestSoftplus:
...
@@ -567,6 +567,39 @@ class TestSoftplus:
np
.
testing
.
assert_allclose
(
y_th
,
y_np
,
rtol
=
10e-10
)
np
.
testing
.
assert_allclose
(
y_th
,
y_np
,
rtol
=
10e-10
)
_good_broadcast_unary_log1mexp
=
dict
(
normal
=
(
random_ranged
(
-
10.0
,
0
,
(
2
,
3
)),),
float32
=
(
random_ranged
(
-
10.0
,
0
,
(
2
,
3
))
.
astype
(
"float32"
),),
empty
=
(
np
.
asarray
([],
dtype
=
config
.
floatX
),),
int
=
(
integers_ranged
(
-
10
,
-
1
,
(
2
,
3
)),),
)
_grad_broadcast_unary_log1mexp
=
dict
(
normal
=
(
random_ranged
(
-
10.0
,
0.0
,
(
2
,
3
)),),
)
def
expected_log1mexp
(
x
):
return
check_floatX
(
x
,
np
.
log
(
-
np
.
expm1
(
x
)))
TestLog1mexpBroadcast
=
makeBroadcastTester
(
op
=
aet
.
log1mexp
,
expected
=
expected_log1mexp
,
good
=
_good_broadcast_unary_log1mexp
,
grad
=
_grad_broadcast_unary_log1mexp
,
eps
=
1e-8
,
)
TestLog1mexpInplaceBroadcast
=
makeBroadcastTester
(
op
=
inplace
.
log1mexp_inplace
,
expected
=
expected_log1mexp
,
good
=
_good_broadcast_unary_log1mexp
,
eps
=
1e-8
,
inplace
=
True
,
)
def
test_deprecated_module
():
def
test_deprecated_module
():
with
pytest
.
warns
(
DeprecationWarning
):
with
pytest
.
warns
(
DeprecationWarning
):
import
aesara.scalar.basic_scipy
# noqa: F401
import
aesara.scalar.basic_scipy
# noqa: F401
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论