提交 aa3d16ff authored 作者: Frederic Bastien's avatar Frederic Bastien

Added flags DebugMode.allow_remove_inf and TensorType.value_eq_approx(…

Added flags DebugMode.allow_remove_inf and TensorType.value_eq_approx( allow_remove_inf=False) parameter. When True, we allow an optimization to replace an inf to some other value. Usefull to test stability optimization.
上级 3b212df4
......@@ -42,6 +42,9 @@ AddConfigVar('DebugMode.check_strides',
"On difference: (0) - ignore, (1) warn, or (2) raise error"),
IntParam(1, lambda i: i in (0,1,2)))
AddConfigVar('DebugMode.allow_remove_inf',
"True -> do we allow an optimization to replace an inf with another value",
BoolParam(False))
import logging
_logger=logging.getLogger("theano.compile.debugmode")
......@@ -599,7 +602,7 @@ def _lessbroken_deepcopy(a):
assert rval.dtype == a.dtype
return rval
def _find_bad_optimizations0(order, reasons, r_vals):
def _find_bad_optimizations0(order, reasons, r_vals, allow_remove_inf=False):
"""Use a simple algorithm to find broken optimizations.
This algorithm is simple to understand, but sometimes when there's a problem it identifies
......@@ -617,8 +620,8 @@ def _find_bad_optimizations0(order, reasons, r_vals):
new_r_val = r_vals[new_r]
r_val = r_vals[r]
assert r.type == new_r.type
if not r.type.values_eq_approx(r_val, new_r_val):
if not r.type.values_eq_approx(r_val, new_r_val, allow_remove_inf=allow_remove_inf):
raise BadOptimization(old_r=r,
new_r=new_r,
old_r_val=r_val,
......@@ -1045,7 +1048,8 @@ class _Linker(gof.link.LocalLinker):
if s[0] is not None:
print r, s
assert s[0] is None
allow_remove_inf = self.maker.mode.allow_remove_inf
#try:
# compute the value of all variables
for i, (thunk_py, thunk_c, node) in enumerate(zip(thunks_py, thunks_c, order)):
......@@ -1142,7 +1146,7 @@ class _Linker(gof.link.LocalLinker):
#print >> sys.stderr, i, "DEBUGMODE clearing output", r
# compares the version from thunk_py (in r_vals)
# to the version produced by thunk_c (in storage_map)
if not r.type.values_eq_approx(r_vals[r], storage_map[r][0]):
if not r.type.values_eq_approx(r_vals[r], storage_map[r][0], allow_remove_inf=allow_remove_inf):
raise BadCLinkerOutput(r, val_py=r_vals[r], val_c=storage_map[r][0])
else:
#print >> sys.stderr, i, "DEBUGMODE storing reference output %x" % id(storage_map[r][0])
......@@ -1162,7 +1166,7 @@ class _Linker(gof.link.LocalLinker):
if True:
gc.collect()
_find_bad_optimizations(order, env.equivalence_tracker.reasons, r_vals)
_find_bad_optimizations(order, env.equivalence_tracker.reasons, r_vals, allow_remove_inf=allow_remove_inf)
#####
# Postcondition: the input and output variables are in the storage map, nothing more
......@@ -1523,6 +1527,12 @@ class DebugMode(Mode):
different strides? (This can catch bugs, but is generally overly strict.) 0 no check, 1 warn, 2 err.
"""
allow_remove_inf = config.DebugMode.allow_remove_inf
"""
Default False. Do we allow that an optimization remove inf value.
This is usefull to test stabilization optimization.
"""
# This function will be used to create a FunctionMaker in
# function_module.function
def function_maker(self, i,o,m, *args, **kwargs):
......
......@@ -488,7 +488,11 @@ class TensorType(Type):
else:
return False
@staticmethod
def values_eq_approx(a, b):
def values_eq_approx(a, b, allow_remove_inf = False):
"""
:param allow_remove_inf: If True, when their is an inf in a,
we allow any value in b in that position.
"""
if type(a) is numpy.ndarray and type(b) is numpy.ndarray:
if a.shape != b.shape:
return False
......@@ -516,7 +520,9 @@ class TensorType(Type):
# for now we use a home-made recipe, that should probably be
# revisited in the future.
a_missing = numpy.isnan(a)
if not a_missing.any():
a_inf = numpy.isinf(a)
if not (a_missing.any() or (allow_remove_inf and a_inf.any())):
# There are no missing values in a, thus this is not the
# reason why numpy.allclose(a, b) returned False.
_info('numpy allclose failed for abs_err %f and rel_err %f' %(
......@@ -531,6 +537,9 @@ class TensorType(Type):
(atol + rtol * numpy.absolute(b)))
# Find places where both a and b have missing values.
both_missing = a_missing * numpy.isnan(b)
if allow_remove_inf:
both_missing += a_inf
# Combine all information.
return (cmp_elemwise + both_missing).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论