提交 5cde2e60 authored 作者: Frederic Bastien's avatar Frederic Bastien

Implemented a new system that allows an optimization to specify a…

Implemented a new system that allows an optimization to specify a values_eq_approx to use for a variable. This is useful to make DebugMode accept that a stabilization optimization removes inf or nan from the graph.
上级 86d1f68b
......@@ -626,7 +626,11 @@ def _find_bad_optimizations0(order, reasons, r_vals):
r_val = r_vals[r]
assert r.type == new_r.type
if not r.type.values_eq_approx(r_val, new_r_val):
if hasattr(new_r,'values_eq_approx'):
check = new_r.values_eq_approx(r_val, new_r_val)
else:
check = r.type.values_eq_approx(r_val, new_r_val)
if not check:
raise BadOptimization(old_r=r,
new_r=new_r,
old_r_val=r_val,
......
......@@ -562,6 +562,18 @@ class TensorType(Type):
return False
@staticmethod
def values_eq_approx_remove_inf(a, b):
return TensorType.values_eq_approx(a,b,True)
@staticmethod
def values_eq_approx_remove_nan(a, b):
return TensorType.values_eq_approx(a,b,False,True)
@staticmethod
def values_eq_approx_remove_inf_nan(a, b):
return TensorType.values_eq_approx(a,b,True,True)
def __hash__(self):
"""Hash equal for same kinds of TensorType"""
return hashtype(self) ^ hash(self.dtype) ^ hash(self.broadcastable)
......
......@@ -2353,7 +2353,10 @@ def local_log_add(node):
if len(pre_exp) == len(zi):
# all arguments to add are exp(<something>)
max_pre = T.maximum(*pre_exp)
return [max_pre + T.log1p(T.exp(T.add(*[p - max_pre for p in pre_exp])))]
ret = max_pre + T.log1p(T.exp(T.add(*[p - max_pre for p in pre_exp])))
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf
return [ret]
def add_calculate(num, denum, aslist = False, out_type=None):
#TODO: make sure that this function and mul_calculate are similar
......@@ -2698,7 +2701,9 @@ def local_log_erfc(node):
elif node.outputs[0].dtype=='float64':
threshold = 26.641747557
return [T.switch(x<threshold,node.outputs[0],stab_value)]
ret = T.switch(x<threshold,node.outputs[0],stab_value)
ret.values_eq_approx = ret.type.values_eq_approx_remove_inf
return [ret]
#Stability optimization of the grad of log(erfc(x))
......@@ -2838,7 +2843,7 @@ def local_grad_log_erfc_neg(node):
elif x.dtype=='float64':
threshold = 26.641747557
ret = T.switch(x<threshold,true_div_no_mul,stab_value)*y
#ret.values_eq_approx = ret.type.values_eq_approx_remove_inf_nan
ret.values_eq_approx = ret.type.values_eq_approx_remove_nan
return [ret]
"""
......
......@@ -1049,15 +1049,24 @@ def test_log_add():
m = 'FAST_RUN'
m = compile.mode.get_mode(m)
m = m.excluding('fusion')
m = copy.copy(m)
#No need to put them back as we have a new object
m.check_isfinite=False
# check some basic cases
x = dvector()
y = dvector()
f = function([x,y], T.log(T.exp(x) + T.exp(y)), mode=m)
theano.printing.debugprint( f)
print f([10000], [10000]) # causes overflow if handled incorrectly
assert numpy.isfinite(f([10000], [10000]))
assert numpy.allclose(f([10000], [10000]), 10000+numpy.log1p(1))
#test that it gives the same result when it doesn't overflow
print f([10], [10]) # don't causes overflow
assert numpy.allclose(f([10], [10]), 10+numpy.log1p(1))
# test that it also works with more than two args, (this currently fails)
x = dvector()
......@@ -1068,7 +1077,7 @@ def test_log_add():
try:
print f([10000], [10000]) # causes overflow if handled incorrectly
assert numpy.allclose(f([10000], [10000]), 20000)
except:
except AssertionError:
raise KnownFailureTest
#TODO: test that the optimization works in the presence of broadcasting.
......@@ -1781,10 +1790,8 @@ class T_local_erfc(unittest.TestCase):
#there are some NaNs that will appear in the graph for the log of the negative values
mode = copy.copy(self.mode)
mode.check_isfinite = False
mode.allow_remove_inf = True
mode_fusion = copy.copy(self.mode_fusion)
mode_fusion.check_isfinite = False
mode_fusion.allow_remove_inf = True
f = theano.function([x],T.log(T.erfc(x)), mode=mode)
#theano.printing.debugprint(f)
......@@ -1823,10 +1830,8 @@ class T_local_erfc(unittest.TestCase):
#there are some NaNs that will appear in the graph for the log of the negative values
mode = copy.copy(self.mode)
mode.check_isfinite = False
mode.allow_remove_inf = True
mode_fusion = copy.copy(self.mode_fusion)
mode_fusion.check_isfinite = False
mode_fusion.allow_remove_inf = True
f = theano.function([x],T.grad(T.log(T.erfc(x)),x), mode=mode)
#theano.printing.debugprint(f)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论