提交 efa90388 authored 作者: Frederic Bastien's avatar Frederic Bastien

add a stabilization optimization for the grad of log(erfc).

上级 8bd24b6b
...@@ -2589,6 +2589,105 @@ def local_log_erfc(node): ...@@ -2589,6 +2589,105 @@ def local_log_erfc(node):
return [T.switch(x<threshold,node.outputs[0],stab_value)] return [T.switch(x<threshold,node.outputs[0],stab_value)]
#Stability optimization of the grad of log(erfc(x))
#([cst*]exp(-(x**2)))/erfc(x) # The cst* is optional
#exp(-(x**2))/erfc(x) => when x>=threshold, sqrt(pi)*x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))
#for float64: threshold=26.63 (the code uses 26.641747557), see the end of the function for the explanation
#for float32: threshold=9.3, see the end of the function for the explanation
#TODO: remove the constraint that there are only 2 inputs to mul and that exp(-(x**2)) is the second one.
#TODO: at the test point 10 in float32, there is instability in the original value.
#      The original gives -30.0, the stabilized version -20.1 and float64 -18.1.
#      Make the test not generate an error in that case!
@register_stabilize
@register_specialize
@gof.local_optimizer([T.true_div])
def local_grad_log_erfc_neg(node):
    """Stabilize ``[cst *] exp(-(x**2)) / erfc(x)`` (the grad of log(erfc(x))).

    The matched pattern is a ``true_div`` whose denominator is ``erfc(x)``
    and whose numerator is either ``exp(neg(sqr(x)))`` or a two-input
    ``mul(cst, exp(neg(sqr(x))))`` where ``cst`` has no owner.

    When the pattern matches, the node is replaced by::

        switch(x < threshold, exp(-(x**2)) / erfc(x), stab_value) * cst

    where ``stab_value`` is built from the first terms of the asymptotic
    expansion of erfc and ``threshold`` depends on ``x.dtype`` (9.3 for
    float32, 26.641747557 for float64).  The notes that follow this
    function in the file describe how the thresholds were chosen.

    :param node: the Apply node the local optimizer is asked to inspect.
    :return: ``False`` when the pattern does not match, otherwise a
        one-element list holding the replacement variable.
    """
    if node.op != T.true_div:
        return False
    if hasattr(node.tag, 'local_grad_log_erfc_neg'):
        # The inner true_div we build below carries this flag so that the
        # optimization is not applied recursively to its own output.
        return False

    numerator = node.inputs[0]
    erfc = node.inputs[1]
    if not erfc.owner or erfc.owner.op != T.erfc:
        return False
    if not numerator.owner:
        return False

    # The mul is optional.
    if numerator.owner.op != T.mul:
        cst = 1
        exp = numerator
    else:
        factors = numerator.owner.inputs
        # Only the form mul(cst, exp(...)) is supported: exactly two inputs,
        # the first one without an owner.
        if factors[0].owner or len(factors) != 2:
            return False
        cst = factors[0]
        exp = factors[1]
    if not exp.owner or exp.owner.op != T.exp:
        return False

    neg = exp.owner.inputs[0]
    if not neg.owner or neg.owner.op != T.neg:
        return False
    sqr = neg.owner.inputs[0]
    if not sqr.owner or sqr.owner.op != T.sqr:
        return False
    x = sqr.owner.inputs[0]

    # The rewrite is only valid when erfc is applied to the very same
    # variable that is squared inside the exp; otherwise the switch on x
    # and the stabilized value would describe a different expression.
    if erfc.owner.inputs[0] is not x:
        return False

    # Thresholds are only known for these two dtypes; leave the graph
    # untouched for anything else instead of failing on an unbound name.
    if x.dtype == 'float32':
        threshold = 9.3
    elif x.dtype == 'float64':
        threshold = 26.641747557
    else:
        return False

    # We move the cst outside the div.
    true_div_no_mul = T.true_div(exp, erfc)
    true_div_no_mul.owner.tag.local_grad_log_erfc_neg = True

    # Stabilized value (labelled "aaron value" in the original code).
    stab_value = (x
                  * T.pow(1 - 1 / (2 * (x ** 2)) + 3 / (4 * (x ** 4))
                          - 15 / (8 * (x ** 6)), -1)
                  * T.cast(T.sqrt(numpy.pi), dtype=x.dtype))

    ret = T.switch(x < threshold, true_div_no_mul, stab_value) * cst
    return [ret]
"""
The libm used for the test is amdlibm
#([cst*]exp(-(x**2)))/erfc(x) # The mul is optional
#exp(-(x**2))/erfc(x) => when x>=threshold, x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1))*sqrt(pi)
#for float64: threshold=26.63 see below
#for float32: threshold=9.3 see below
#TODO: remove the constraint that there are only 2 inputs to mul
#TODO: should we cast numpy.pi to x.dtype?
#float32 threshold is 9.3, as the approximation is more precise at that point and more stable.
import numpy, scipy.special
r = numpy.arange(9,10.06,.01)
p64=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in r]
p32=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in numpy.asarray(r,dtype='float32')]
a64=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1))*numpy.sqrt(numpy.pi) for x in r]
a32=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1))*numpy.float32(numpy.sqrt(numpy.pi)) for x in numpy.asarray(r,dtype='float32')]
for idx,(a,b,c,d,e) in enumerate(zip(r,p64,p32,a64,a32)):print a,b,c,d,e,c-b,e-b,numpy.absolute(c-b)<numpy.absolute(e-b)
for i in range(1,len(p32)): print r[i], p32[i]-p32[i-1]#, show that the value don't look stable at some point before inf.
#float64 threshold is 26.63; the approximation seems more precise at that point.
r = numpy.arange(26.2,26.7,.001)
#scipy.special.erfc(numpy.float128(x)) don't work
#p128=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in numpy.float128(r)]
#those value have been computed with g++ theano/misc/erfc_stability_threshold.c && ./a.out
p128=numpy.float128(['46.47206725', '46.47383842', '46.47560959', '46.47738076', '46.47915193', '46.48092309', '46.48269426', '46.48446543', '46.48623660', '46.48800777', '46.48977894', '46.49155011', '46.49332128', '46.49509245', '46.49686362', '46.49863479', '46.50040596', '46.50217713', '46.50394830', '46.50571947', '46.50749064', '46.50926181', '46.51103298', '46.51280415', '46.51457532', '46.51634649', '46.51811766', '46.51988883', '46.52166000', '46.52343118', '46.52520235', '46.52697352', '46.52874469', '46.53051586', '46.53228703', '46.53405820', '46.53582938', '46.53760055', '46.53937172', '46.54114289', '46.54291407', '46.54468524', '46.54645641', '46.54822758', '46.54999876', '46.55176993', '46.55354110', '46.55531227', '46.55708345', '46.55885462', '46.56062579', '46.56239697', '46.56416814', '46.56593931', '46.56771049', '46.56948166', '46.57125283', '46.57302401', '46.57479518', '46.57656636', '46.57833753', '46.58010871', '46.58187988', '46.58365105', '46.58542223', '46.58719340', '46.58896458', '46.59073575', '46.59250693', '46.59427810', '46.59604928', '46.59782045', '46.59959163', '46.60136280', '46.60313398', '46.60490516', '46.60667633', '46.60844751', '46.61021868', '46.61198986', '46.61376104', '46.61553221', '46.61730339', '46.61907456', '46.62084574', '46.62261692', '46.62438809', '46.62615927', '46.62793045', '46.62970163', '46.63147280', '46.63324398', '46.63501516', '46.63678633', '46.63855751', '46.64032869', '46.64209987', '46.64387104', '46.64564222', '46.64741340', '46.64918458', '46.65095576', '46.65272693', '46.65449811', '46.65626929', '46.65804047', '46.65981165', '46.66158283', '46.66335401', '46.66512519', '46.66689636', '46.66866754', '46.67043872', '46.67220990', '46.67398108', '46.67575226', '46.67752344', '46.67929462', '46.68106580', '46.68283698', '46.68460816', '46.68637934', '46.68815052', '46.68992170', '46.69169288', '46.69346406', '46.69523524', '46.69700642', '46.69877760', '46.70054878', '46.70231997', 
'46.70409115', '46.70586233', '46.70763351', '46.70940469', '46.71117587', '46.71294705', '46.71471824', '46.71648942', '46.71826060', '46.72003178', '46.72180296', '46.72357414', '46.72534533', '46.72711651', '46.72888769', '46.73065887', '46.73243006', '46.73420124', '46.73597242', '46.73774361', '46.73951479', '46.74128597', '46.74305715', '46.74482834', '46.74659952', '46.74837070', '46.75014189', '46.75191307', '46.75368426', '46.75545544', '46.75722662', '46.75899781', '46.76076899', '46.76254018', '46.76431136', '46.76608254', '46.76785373', '46.76962491', '46.77139610', '46.77316728', '46.77493847', '46.77670965', '46.77848084', '46.78025202', '46.78202321', '46.78379439', '46.78556558', '46.78733677', '46.78910795', '46.79087914', '46.79265032', '46.79442151', '46.79619269', '46.79796388', '46.79973507', '46.80150625', '46.80327744', '46.80504863', '46.80681981', '46.80859100', '46.81036219', '46.81213337', '46.81390456', '46.81567575', '46.81744693', '46.81921812', '46.82098931', '46.82276050', '46.82453168', '46.82630287', '46.82807406', '46.82984525', '46.83161644', '46.83338762', '46.83515881', '46.83693000', '46.83870119', '46.84047238', '46.84224357', '46.84401475', '46.84578594', '46.84755713', '46.84932832', '46.85109951', '46.85287070', '46.85464189', '46.85641308', '46.85818427', '46.85995546', '46.86172665', '46.86349784', '46.86526903', '46.86704022', '46.86881141', '46.87058260', '46.87235379', '46.87412498', '46.87589617', '46.87766736', '46.87943855', '46.88120974', '46.88298093', '46.88475212', '46.88652331', '46.88829450', '46.89006569', '46.89183688', '46.89360807', '46.89537927', '46.89715046', '46.89892165', '46.90069284', '46.90246403', '46.90423522', '46.90600642', '46.90777761', '46.90954880', '46.91131999', '46.91309119', '46.91486238', '46.91663357', '46.91840476', '46.92017596', '46.92194715', '46.92371834', '46.92548953', '46.92726073', '46.92903192', '46.93080311', '46.93257431', '46.93434550', '46.93611669', '46.93788789', 
'46.93965908', '46.94143028', '46.94320147', '46.94497266', '46.94674386', '46.94851505', '46.95028625', '46.95205744', '46.95382864', '46.95559983', '46.95737103', '46.95914222', '46.96091341', '46.96268461', '46.96445581', '46.96622700', '46.96799820', '46.96976939', '46.97154059', '46.97331178', '46.97508298', '46.97685417', '46.97862537', '46.98039657', '46.98216776', '46.98393896', '46.98571015', '46.98748135', '46.98925255', '46.99102374', '46.99279494', '46.99456614', '46.99633733', '46.99810853', '46.99987973', '47.00165092', '47.00342212', '47.00519332', '47.00696452', '47.00873571', '47.01050691', '47.01227811', '47.01404931', '47.01582050', '47.01759170', '47.01936290', '47.02113410', '47.02290530', '47.02467649', '47.02644769', '47.02821889', '47.02999009', '47.03176129', '47.03353249', '47.03530369', '47.03707489', '47.03884608', '47.04061728', '47.04238848', '47.04415968', '47.04593088', '47.04770208', '47.04947328', '47.05124448', '47.05301568', '47.05478688', '47.05655808', '47.05832928', '47.06010048', '47.06187168', '47.06364288', '47.06541408', '47.06718528', '47.06895648', '47.07072768', '47.07249888', '47.07427009', '47.07604129', '47.', '47.07958369', '47.08135489', '47.08312609', '47.08489729', '47.08666850', '47.08843970', '47.09021090', '47.09198210', '47.09375330', '47.09552450', '47.09729571', '47.09906691', '47.10083811', '47.10260931', '47.10438052', '47.10615172', '47.10792292', '47.10969412', '47.11146533', '47.11323653', '47.11500773', '47.11677894', '47.11855014', '47.12032134', '47.12209255', '47.12386375', '47.12563495', '47.12740616', '47.12917736', '47.13094857', '47.13271977', '47.13449097', '47.13626218', '47.13803338', '47.13980459', '47.14157579', '47.14334700', '47.14511820', '47.14688941', '47.14866061', '47.15043182', '47.15220302', '47.15397423', '47.15574543', '47.15751664', '47.15928784', '47.16105905', '47.16283025', '47.16460146', '47.16637266', '47.16814387', '47.16991508', '47.17168628', '47.17345749', 
'47.17522869', '47.17699990', '47.17877111', '47.18054231', '47.18231352', '47.18408473', '47.18585593', '47.18762714', '47.18939835', '47.19116956', '47.19294076', '47.19471197', '47.19648318', '47.19825439', '47.20002559', '47.20179680', '47.20356801', '47.20533922', '47.20711042', '47.20888163', '47.21065284', '47.21242405', '47.21419526', '47.21596647', '47.21773767', '47.21950888', '47.22128009', '47.22305130', '47.22482251', '47.22659372', '47.22836493', '47.23013614', '47.23190735', '47.23367855', '47.23544976', '47.23722097', '47.23899218', '47.24076339', '47.24253460', '47.24430581', '47.24607702', '47.24784823', '47.24961944', '47.25139065', '47.25316186', '47.25493307', '47.25670429', '47.25847550', '47.26024671', '47.26201792', '47.26378913', '47.26556034', '47.26733155', '47.26910276', '47.27087397', '47.27264518', '47.27441640', '47.27618761', '47.27795882', '47.27973003', '47.28150124', '47.28327246', '47.28504367', '47.28681488', '47.28858609', '47.29035730', '47.29212852', '47.29389973', '47.29567094', '47.29744215', '47.29921337', '47.30098458', '47.30275579', '47.30452701', '47.30629822', '47.30806943', '47.30984065', '47.31161186', '47.31338307', '47.31515429', '47.31692550', '47.31869671', '47.32046793', '47.32223914', '47.32401036', '47.32578157', '47.32755278', '47.32932400', '47.33109521', '47.33286643', '47.33463764', '47.33640886', '47.33818007', '47.33995129', '47.34172250', '47.34349372', '47.34526493', '47.34703615', '47.34880736', '47.35057858', '47.35234979', '47.35412101', '47.35589223'])
p64=[(numpy.exp(-(x**2)))/scipy.special.erfc(x)for x in r]
a128=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))**(-1))*numpy.float128(numpy.sqrt(numpy.pi)) for x in numpy.asarray(r,dtype='float128')]
a64=[x*((1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)+63/(7*x**8))**(-1))*numpy.sqrt(numpy.pi) for x in r]
for a,b,c,d in zip(r,p128,p64,a64):print a,b,c,d,c-b,d-b
for i in range(1,len(p64)): print i, p64[i]-p64[i-1]
"""
# ############### # ###############
# # Loop fusion # # # Loop fusion #
# ############### # ###############
......
...@@ -1754,6 +1754,61 @@ class T_local_erfc(unittest.TestCase): ...@@ -1754,6 +1754,61 @@ class T_local_erfc(unittest.TestCase):
raise KnownFailureTest("the python code upcast somewhere internally some value of float32 to python float for part of its computation. That make that the c and python code don't generate the same value. You can ignore this error.") raise KnownFailureTest("the python code upcast somewhere internally some value of float32 to python float for part of its computation. That make that the c and python code don't generate the same value. You can ignore this error.")
assert all(numpy.isfinite(f(val))) assert all(numpy.isfinite(f(val)))
def test_local_grad_log_erfc_neg(self):
    """Check the stabilization of the grad of log(erfc(x)).

    Four graphs are compiled: the grad itself, the same pattern with a
    different mul constant, the pattern without the mul, and the grad again
    with the fusion mode.  For each one the test checks the number of nodes
    in the compiled graph, the output dtype, that no ``pow`` scalar op
    remains, and (except in one known-failure configuration) that the
    outputs are finite on a range of test points.
    """
    val = [-100, -30, -27, -26.4, -26.2, -26, -11, -10, -9, -3, -2, -1, 0,
           1, 2, 3, 9, 10, 11, 27, 26.4, 26.2, 26, 28, 30, 100]
    in_debug_mode = theano.config.mode in ["DebugMode", "DEBUG_MODE"]
    is_float32 = theano.config.floatX == 'float32'

    if theano.config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]:
        # The python mode does not like the inv(0) in the computation, but
        # the switch does not select this value, so it is computed for no
        # good reason.
        val.remove(0)
    if in_debug_mode and is_float32:
        # In float32 there is a range of values close to 10 that we
        # stabilize, as the original gives a bigger error than the
        # stabilized version.  The original value in float32 is -30.0, the
        # stabilized value -20.1 and the original value in float64 -18.1.
        val.remove(10)
    val = numpy.asarray(val)
    x = T.vector()

    # Some nan will appear in the graph for the log of the negative values.
    mode = copy.copy(self.mode)
    mode.check_isfinite = False
    mode.allow_remove_inf = True
    mode_fusion = copy.copy(self.mode_fusion)
    mode_fusion.check_isfinite = False
    mode_fusion.allow_remove_inf = True

    def check_compiled(f, n_nodes, check_finite=True):
        # Shared assertions on a compiled function: graph size, output
        # dtype, absence of pow, and optionally finiteness of the outputs.
        nodes = f.maker.env.nodes
        assert len(nodes) == n_nodes, len(nodes)
        assert f.maker.env.outputs[0].dtype == theano.config.floatX
        assert not any([hasattr(n.op, 'scalar_op') and
                        n.op.scalar_op == scal.pow for n in nodes])
        if check_finite:
            assert all(numpy.isfinite(f(val)))

    f = theano.function([x], T.grad(T.log(T.erfc(x)), x), mode=mode)
    # theano.printing.debugprint(f)
    check_compiled(f, 22)

    # Test with a different mul constant.
    f = theano.function(
        [x], T.mul(T.exp(T.neg(T.sqr(x))), -10.12837917) / T.erfc(x),
        mode=mode)
    # theano.printing.debugprint(f)
    check_compiled(f, 22)

    # Test that we work without the mul.
    f = theano.function([x], T.exp(T.neg(T.sqr(x))) / T.erfc(x), mode=mode)
    # theano.printing.debugprint(f)
    check_compiled(f, 21)

    # TODO: fix this problem
    # In float32 DebugMode the finiteness check is skipped: the error it
    # would show duplicates the one in test_local_log_erfc (the python code
    # upcasts some float32 values to python float internally, so the c and
    # python code do not generate the same value in an intermediate step
    # that does not show in the final result).
    f = theano.function([x], T.grad(T.log(T.erfc(x)), x), mode=mode_fusion)
    check_compiled(f, 1, check_finite=not (is_float32 and in_debug_mode))
def speed_local_log_erfc(self): def speed_local_log_erfc(self):
val = numpy.random.rand(1e6) val = numpy.random.rand(1e6)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论