Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
32bc96d7
提交
32bc96d7
authored
6月 15, 2015
作者:
David Warde-Farley
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2849 from hantek/nan_tutorial
Tutorial about NaNs
上级
8dbef49b
87867d4c
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
341 行增加
和
0 行删除
+341
-0
nanguardmode.txt
doc/library/compile/nanguardmode.txt
+19
-0
nan_tutorial.txt
doc/tutorial/nan_tutorial.txt
+66
-0
nanguardmode.py
theano/compile/nanguardmode.py
+207
-0
test_nanguardmode.py
theano/compile/tests/test_nanguardmode.py
+49
-0
没有找到文件。
doc/library/compile/nanguardmode.txt
0 → 100644
浏览文件 @
32bc96d7
.. _nanguardmode:
=================
:mod:`nanguardmode`
=================
.. module:: nanguardmode
:platform: Unix, Windows
:synopsis: defines NanGuardMode
.. moduleauthor:: LISA
Guide
=====
The NanGuardMode aims to prevent the model from outputing NaNs or Infs. It has
a number of self-checks, which can help to find out which apply node is
generating those incorrect outputs.
doc/tutorial/nan_tutorial.txt
0 → 100644
浏览文件 @
32bc96d7
.. _nan_tutorial:
=================
Dealing with NaNs
=================
Having a model yielding NaNs or Infs is quite common if some of the tiny
components in your model are not set properly. NaNs are hard to deal with
because sometimes it is caused by a bug or error in the code, sometimes it's
because of the numerical stability of your computational environment (library
versions, etc.), and even, sometimes it relates to your algorithm. Here we try
to outline common issues which cause the model to yield NaNs, as well as
provide nails and hammers to diagnose it.
Check Superparameters and Weight Initialization
-----------------------------------------------
Most frequently, the cause would be that some of the hyperparameters, especially
learning rates, are set incorrectly. A high learning rate can blow up your whole
model into NaN outputs even within one epoch of training. So the first and
easiest solution is try to lower it. Keep halving your learning rate until you
start to get resonable output values.
Other hyperparameters may also play a role. For example, are your training
algorithms involve regularization terms? If so, are their corresponding
penalties set reasonably? Search a wider hyperparameter space with a few (one or
two) training eopchs each to see if the NaNs could disappear.
Some models can be very sensitive to the initialization of weight vectors. If
those weights are not initialized in a proper range, then it is not surprising
that the model ends up with yielding NaNs.
Run in DebugMode
-----------------
If adjusting hyperparameters doesn't work for you, you can still get help from
Theano's DebugMode. Run your code in DebugMode with flag mode=DebugMode,
DebugMode.check_py=False. This will give you clue about which op is causing this
problem, and then you can inspect into that op in more detail. For a detailed
of using DebugMode, please refere to :ref:`debugmode`.
Theano's MonitorMode can also help. It can be used to step through the execution
of a function. You can inspect the inputs and outputs of each node being
executed when the function is called. For how to use that, please check
:ref:`faq_monitormode`.
Numerical Stability
-------------------
After you have located the op which causes the problem, it may turn out that the
NaNs yielded by that op are related to numerical issues. For example, :math:
`1 / log(p(x) + 1)` may result in NaNs for those nodes who have learned to yield
a low probability p(x) for some input x.
Algorithm Related
-----------------
In the most difficult situations, you may go through the above steps and find
nothing wrong. If the above methods fail to uncover the cause, there is a good
chance that something is wrong with your algorithm. Go back to the mathematics
and find out if everything is derived correctly.
theano/compile/nanguardmode.py
0 → 100644
浏览文件 @
32bc96d7
import
logging
import
collections
import
numpy
as
np
import
theano
import
theano.tensor
as
T
import
theano.sandbox.cuda
as
cuda
from
theano.compile
import
Mode
logger
=
logging
.
getLogger
(
"theano.compile.nanguardmode"
)
def
flatten
(
l
):
"""
Turns a nested graph of lists/tuples/other objects into a list of objects.
Parameters
----------
l : List/tuple/other objects, might be nested.
Returns
-------
A flattened list of objects
"""
if
isinstance
(
l
,
(
list
,
tuple
,
collections
.
ValuesView
)):
rval
=
[]
for
elem
in
l
:
if
isinstance
(
elem
,
(
list
,
tuple
)):
rval
.
extend
(
flatten
(
elem
))
else
:
rval
.
append
(
elem
)
else
:
return
[
l
]
return
rval
def
contains_nan
(
arr
):
"""
Test whether a numpy.ndarray contains any `np.nan` values.
Parameters
----------
arr : np.ndarray
Returns
-------
contains_nan : bool
`True` if the array contains any `np.nan` values, `False` otherwise.
Notes
-----
Tests for the presence of `np.nan`'s using `np.isnan(np.min(ndarray))`.
This approach is faster and more memory efficient than the obvious
alternative, calling `np.any(np.isnan(ndarray))`, which requires the
construction of a boolean array with the same shape as the input array.
"""
return
np
.
isnan
(
np
.
min
(
arr
))
def
contains_inf
(
arr
):
"""
Test whether a numpy.ndarray contains any `np.inf` values.
Parameters
----------
arr : np.ndarray
Returns
-------
contains_inf : bool
`True` if the array contains any `np.inf` values, `False` otherwise.
Notes
-----
Tests for the presence of `np.inf`'s by determining whether the
values returned by `np.nanmin(arr)` and `np.nanmax(arr)` are finite.
This approach is more memory efficient than the obvious alternative,
calling `np.any(np.isinf(ndarray))`, which requires the construction of a
boolean array with the same shape as the input array.
"""
return
np
.
isinf
(
np
.
nanmax
(
arr
))
or
np
.
isinf
(
np
.
nanmin
(
arr
))
class
NanGuardMode
(
Mode
):
"""
A Theano compilation Mode that makes the compiled function automatically
detect NaNs and Infs and detect an error if they occur.
Parameters
----------
nan_is_error : bool
If True, raise an error anytime a NaN is encountered
inf_is_error: bool
If True, raise an error anytime an Inf is encountered. Note that some
pylearn2 modules currently use np.inf as a default value (e.g.
mlp.max_pool) and these will cause an error if inf_is_error is True.
big_is_error: bool
If True, raise an error when a value greater than 1e10 is encountered.
"""
def
__init__
(
self
,
nan_is_error
,
inf_is_error
,
big_is_error
=
True
):
if
cuda
.
cuda_available
:
self
.
guard_input
=
cuda
.
fvector
(
'nan_guard'
)
if
nan_is_error
or
inf_is_error
:
self
.
gpumin
=
theano
.
function
(
[
self
.
guard_input
],
T
.
min
(
self
.
guard_input
),
mode
=
'FAST_RUN'
)
if
inf_is_error
:
self
.
gpumax
=
theano
.
function
(
[
self
.
guard_input
],
T
.
max
(
self
.
guard_input
),
mode
=
'FAST_RUN'
)
if
big_is_error
:
self
.
gpuabsmax
=
theano
.
function
(
[
self
.
guard_input
],
T
.
max
(
T
.
abs_
(
self
.
guard_input
)),
mode
=
'FAST_RUN'
)
def
do_check_on
(
var
,
nd
,
f
,
is_input
):
"""
Checks `var` for NaNs / Infs. If detected, raises an exception
and / or prints information about `nd`, `f`, and `is_input` to
help the user determine the cause of the invalid values.
Parameters
----------
var : numpy.ndarray
The value to be checked.
nd : theano.gof.Apply
The Apply node being executed
f : callable
The thunk for the apply node
is_input : bool
If True, `var` is an input to `nd`.
If False, it is an output.
"""
error
=
False
if
nan_is_error
:
err
=
False
if
cuda
.
cuda_available
and
isinstance
(
var
,
cuda
.
CudaNdarray
):
err
=
np
.
isnan
(
self
.
gpumin
(
var
.
reshape
(
var
.
size
)))
else
:
err
=
contains_nan
(
var
)
if
err
:
logger
.
error
(
'NaN detected'
)
error
=
True
if
inf_is_error
:
err
=
False
if
cuda
.
cuda_available
and
isinstance
(
var
,
cuda
.
CudaNdarray
):
err
=
(
np
.
isinf
(
self
.
gpumin
(
var
.
reshape
(
var
.
size
)))
or
np
.
isinf
(
self
.
gpumax
(
var
.
reshape
(
var
.
size
))))
else
:
err
=
contains_inf
(
var
)
if
err
:
logger
.
error
(
'Inf detected'
)
error
=
True
if
big_is_error
:
err
=
False
if
cuda
.
cuda_available
and
isinstance
(
var
,
cuda
.
CudaNdarray
):
err
=
(
self
.
gpuabsmax
(
var
.
reshape
(
var
.
size
))
>
1e10
)
else
:
err
=
(
np
.
abs
(
var
)
.
max
()
>
1e10
)
if
err
:
logger
.
error
(
'Big value detected'
)
error
=
True
if
error
:
if
is_input
:
logger
.
error
(
'In an input'
)
else
:
logger
.
error
(
'In an output'
)
logger
.
error
(
'Inputs: '
)
for
ivar
,
ival
in
zip
(
nd
.
inputs
,
f
.
inputs
):
logger
.
error
(
'var'
)
logger
.
error
(
ivar
)
logger
.
error
(
theano
.
printing
.
min_informative_str
(
ivar
))
logger
.
error
(
'val'
)
logger
.
error
(
ival
)
logger
.
error
(
'Node:'
)
logger
.
error
(
nd
)
assert
False
def
nan_check
(
i
,
node
,
fn
):
"""
Runs `fn` while checking its inputs and outputs for NaNs / Infs
Parameters
----------
i : currently ignored (TODO: determine why it is here or remove)
node : theano.gof.Apply
The Apply node currently being executed
fn : callable
The thunk to execute for this Apply node
"""
inputs
=
fn
.
inputs
# TODO: figure out why individual inputs are themselves lists
# sometimes
for
x
in
flatten
(
inputs
):
do_check_on
(
x
,
node
,
fn
,
True
)
fn
()
outputs
=
fn
.
outputs
for
j
,
x
in
enumerate
(
flatten
(
outputs
)):
do_check_on
(
x
,
node
,
fn
,
False
)
wrap_linker
=
theano
.
gof
.
WrapLinkerMany
([
theano
.
gof
.
OpWiseCLinker
()],
[
nan_check
])
super
(
NanGuardMode
,
self
)
.
__init__
(
wrap_linker
,
optimizer
=
theano
.
config
.
optimizer
)
theano/compile/tests/test_nanguardmode.py
0 → 100644
浏览文件 @
32bc96d7
"""
This test is for testing the NanGuardMode.
"""
from
theano.compile.nanguardmode
import
NanGuardMode
import
numpy
import
theano
import
theano.tensor
as
T
def
test_NanGuardMode
():
"""
Tests if NanGuardMode is working by feeding in numpy.inf and numpy.nans
intentionally. A working implementation should be able to capture all
the abnormalties.
"""
x
=
T
.
matrix
()
w
=
theano
.
shared
(
numpy
.
random
.
randn
(
5
,
7
)
.
astype
(
theano
.
config
.
floatX
))
y
=
T
.
dot
(
x
,
w
)
fun
=
theano
.
function
(
[
x
],
y
,
mode
=
NanGuardMode
(
nan_is_error
=
True
,
inf_is_error
=
True
)
)
a
=
numpy
.
random
.
randn
(
3
,
5
)
.
astype
(
theano
.
config
.
floatX
)
infa
=
numpy
.
tile
(
(
numpy
.
asarray
(
100.
)
**
1000000
)
.
astype
(
theano
.
config
.
floatX
),
(
3
,
5
))
nana
=
numpy
.
tile
(
numpy
.
asarray
(
numpy
.
nan
)
.
astype
(
theano
.
config
.
floatX
),
(
3
,
5
))
biga
=
numpy
.
tile
(
numpy
.
asarray
(
1e20
)
.
astype
(
theano
.
config
.
floatX
),
(
3
,
5
))
work
=
[
False
,
False
,
False
]
fun
(
a
)
# normal values
try
:
fun
(
infa
)
# INFs
except
AssertionError
:
work
[
0
]
=
True
try
:
fun
(
nana
)
# NANs
except
AssertionError
:
work
[
1
]
=
True
try
:
fun
(
biga
)
# big values
except
AssertionError
:
work
[
2
]
=
True
if
not
(
work
[
0
]
and
work
[
1
]
and
work
[
2
]):
raise
AssertionError
(
"NanGuardMode not working."
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论