提交 b1620f88 authored 作者: Frederic Bastien's avatar Frederic Bastien

rename a variable and make the stability optimizer of the grad of log(erfc(x))…

Rename a variable and make the stability optimization for the gradient of log(erfc(x)) apply directly to the canonicalized graph.
上级 6d0827b7
...@@ -2594,8 +2594,8 @@ def local_log_erfc(node): ...@@ -2594,8 +2594,8 @@ def local_log_erfc(node):
#Stability optimization of the grad of log(erfc(x)) #Stability optimization of the grad of log(erfc(x))
#([cst*]exp(-(x**2)))/erfc(x) # The cst* is optional #([y*]exp(-(x**2)))/erfc(x) # The y* is optional
#exp(x**2)/erfc(-x) => when x>threshold, sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)) #([y*]exp(x**2))/erfc(-x) => [y*](when x>threshold, sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
#for float64: threshold=26.63; see the end of the function for the explanation #for float64: threshold=26.63; see the end of the function for the explanation
#for float32: threshold=9.3; see the end of the function for the explanation #for float32: threshold=9.3; see the end of the function for the explanation
#TODO: remove the constraint that there are only 2 inputs to mul and that exp(x**2) is the second. #TODO: remove the constraint that there are only 2 inputs to mul and that exp(x**2) is the second.
...@@ -2618,7 +2618,7 @@ def local_grad_log_erfc_neg(node): ...@@ -2618,7 +2618,7 @@ def local_grad_log_erfc_neg(node):
#The mul is optional. #The mul is optional.
if node.inputs[0].owner.op != T.mul: if node.inputs[0].owner.op != T.mul:
mul = None mul = None
cst = 1 y = 1
if not node.inputs[0].owner or node.inputs[0].owner.op != T.exp: if not node.inputs[0].owner or node.inputs[0].owner.op != T.exp:
return False return False
exp = node.inputs[0] exp = node.inputs[0]
...@@ -2626,18 +2626,44 @@ def local_grad_log_erfc_neg(node): ...@@ -2626,18 +2626,44 @@ def local_grad_log_erfc_neg(node):
mul = node.inputs[0] mul = node.inputs[0]
if mul.owner.inputs[0].owner or len(mul.owner.inputs)!=2: if mul.owner.inputs[0].owner or len(mul.owner.inputs)!=2:
return False return False
cst = mul.owner.inputs[0] y = mul.owner.inputs[0]
if not mul.owner.inputs[1].owner or mul.owner.inputs[1].owner.op != T.exp: if not mul.owner.inputs[1].owner or mul.owner.inputs[1].owner.op != T.exp:
return False return False
exp = mul.owner.inputs[1] exp = mul.owner.inputs[1]
if not exp.owner.inputs[0].owner or exp.owner.inputs[0].owner.op != T.neg: if not exp.owner.inputs[0].owner:
return False return False
#import pdb;pdb.set_trace()
if exp.owner.inputs[0].owner.op == T.neg:
neg = exp.owner.inputs[0] neg = exp.owner.inputs[0]
if not neg.owner.inputs[0].owner or neg.owner.inputs[0].owner.op != T.sqr: if not neg.owner.inputs[0].owner or neg.owner.inputs[0].owner.op != T.sqr:
return False return False
sqr = neg.owner.inputs[0] sqr = neg.owner.inputs[0]
x = sqr.owner.inputs[0] x = sqr.owner.inputs[0]
elif exp.owner.inputs[0].owner.op == T.mul:
#in many cases the neg is replaced by a mul.
#This also allows stabilizing log(erfc(cst*x))
mul_neg = exp.owner.inputs[0]
try:
cst = get_constant_value(mul_neg.owner.inputs[0])
except TypeError:
return False
if cst>=0:
return False#todo implement that case
if len(mul_neg.owner.inputs) == 2:
if not mul_neg.owner.inputs[1].owner or mul_neg.owner.inputs[1].owner.op != T.sqr:
return False
sqr = mul_neg.owner.inputs[1]
x = sqr.owner.inputs[0]
#elif len(mul_neg.owner.inputs) == 3:
# if mul_neg.owner.inputs[1] is mul_neg.owner.inputs[2]:
# return False
# x = mul_neg.owner.inputs[1]
# import pdb;pdb.set_trace()
# return False
else:
return False
if hasattr(node.tag, 'local_grad_log_erfc_neg'): if hasattr(node.tag, 'local_grad_log_erfc_neg'):
#We use that flag so the optimization is not applied recursively #We use that flag so the optimization is not applied recursively
...@@ -2646,7 +2672,7 @@ def local_grad_log_erfc_neg(node): ...@@ -2646,7 +2672,7 @@ def local_grad_log_erfc_neg(node):
if erfc_x is not x: if erfc_x is not x:
return False return False
#we move the cst outside the div. #we move the y outside the div.
true_div_no_mul = T.true_div(exp,erfc) true_div_no_mul = T.true_div(exp,erfc)
true_div_no_mul.owner.tag.local_grad_log_erfc_neg=True true_div_no_mul.owner.tag.local_grad_log_erfc_neg=True
...@@ -2658,12 +2684,13 @@ def local_grad_log_erfc_neg(node): ...@@ -2658,12 +2684,13 @@ def local_grad_log_erfc_neg(node):
#threshold = 10.1 #threshold = 10.1
elif x.dtype=='float64': elif x.dtype=='float64':
threshold = 26.641747557 threshold = 26.641747557
ret = T.switch(x<threshold,true_div_no_mul,stab_value)*cst ret = T.switch(x<threshold,true_div_no_mul,stab_value)*y
#ret.values_eq_approx = ret.type.values_eq_approx_remove_inf_nan
return [ret] return [ret]
""" """
The libm used for the test is amdlibm The libm used for the test is amdlibm
#([cst*]exp(-(x**2)))/erfc(x) # The mul is optional #([y*]exp(-(x**2)))/erfc(x) # The mul is optional
#exp(x**2)/erfc(-x) => when x>threshold, -x*(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))*sqrt(pi) #exp(x**2)/erfc(-x) => when x>threshold, -x*(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))*sqrt(pi)
#for float64: threshold=26.63 see below #for float64: threshold=26.63 see below
#for float32: threshold=9.3 see below #for float32: threshold=9.3 see below
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论