提交 cf208ead authored 作者: Frederic Bastien's avatar Frederic Bastien

In the stability optimization of the grad of log(erfc(x)), accept more equivalent…

In the stability optimization of the grad of log(erfc(x)), accept more equivalent types of graph, and test that the optimizations are applied in the stabilization phase.
上级 80209cb4
......@@ -2633,7 +2633,7 @@ def local_grad_log_erfc_neg(node):
if not exp.owner.inputs[0].owner:
return False
#import pdb;pdb.set_trace()
if exp.owner.inputs[0].owner.op == T.neg:
neg = exp.owner.inputs[0]
if not neg.owner.inputs[0].owner or neg.owner.inputs[0].owner.op != T.sqr:
......@@ -2641,26 +2641,50 @@ def local_grad_log_erfc_neg(node):
sqr = neg.owner.inputs[0]
x = sqr.owner.inputs[0]
elif exp.owner.inputs[0].owner.op == T.mul:
#in many case the neg are replaced by mul.
#in many case the neg are replaced by mul in the graph.
#This also allow to stabilize log(erfc(cst*x))
mul_neg = exp.owner.inputs[0]
#in case that multiple mul are not fused together, we do it here.
def check_input(inputs):
new_inputs=[]
for i in inputs:
if i.owner and i.owner.op == T.mul:
new_inputs.extend(check_input(i.owner.inputs))
else:
new_inputs.append(i)
return new_inputs
mul_inputs = check_input(mul_neg.owner.inputs)
#put the constant first
for i in range(len(mul_inputs)):
if isinstance(i, Constant):
if i==0:
break
else:
tmp=mul_inputs[0]
mul_inputs[0]=mul_inputs[i]
mul_inputs[i]=tmp
break
mul_neg = T.mul(*mul_inputs)
try:
cst = get_constant_value(mul_neg.owner.inputs[0])
except TypeError:
return False
if cst>=0:
if cst!=-1:
return False#todo implement that case
if len(mul_neg.owner.inputs) == 2:
if not mul_neg.owner.inputs[1].owner or mul_neg.owner.inputs[1].owner.op != T.sqr:
return False
sqr = mul_neg.owner.inputs[1]
x = sqr.owner.inputs[0]
#elif len(mul_neg.owner.inputs) == 3:
# if mul_neg.owner.inputs[1] is mul_neg.owner.inputs[2]:
# return False
# x = mul_neg.owner.inputs[1]
# import pdb;pdb.set_trace()
# return False
elif len(mul_neg.owner.inputs) == 3:
if mul_neg.owner.inputs[1] is not mul_neg.owner.inputs[2]:
return False
x = mul_neg.owner.inputs[1]
else:
return False
else:
return False
......
......@@ -1670,14 +1670,15 @@ class T_local_erf(unittest.TestCase):
class T_local_erfc(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.mode.get_default_mode().including('canonicalize').including('fast_run').excluding('fusion').excluding('gpu')
self.mode_fusion = theano.compile.mode.get_default_mode().including('canonicalize').including('fast_run').excluding('gpu')
self.mode = self.mode_fusion.excluding('fusion')
self.mode._optimizer.position_cutoff = 1.50001
def test_local_one_minus_erfc(self):
""" test opt: 1-erfc(x) => erf(x) and -erfc(x)+1 => erf(x)
"""
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
x = T.vector()
x = T.vector('x')
f = theano.function([x],1-T.erfc(x), mode=self.mode)
theano.printing.debugprint(f)
......@@ -1701,7 +1702,7 @@ class T_local_erfc(unittest.TestCase):
def test_local_erf_neg_minus_one(self):
""" test opt: (-1)+erfc(-x)=>erf(x)"""
val = numpy.asarray([-30,-3,-2,-1,0,1,2,3,30])
x = T.vector()
x = T.vector('x')
f = theano.function([x],-1+T.erfc(-x), mode=self.mode)
theano.printing.debugprint(f)
......@@ -1724,7 +1725,7 @@ class T_local_erfc(unittest.TestCase):
#python mode don't like the inv(0)
val.remove(0)
val = numpy.asarray(val)
x = T.vector()
x = T.vector('x')
#their is some nan that will happear in the graph for the log of the negatives values
mode = copy.copy(self.mode)
......@@ -1767,8 +1768,8 @@ class T_local_erfc(unittest.TestCase):
# The orig value in float32 -30.0, the stab value -20.1 the orig value in float64 -18.1.
val.remove(10)
val = numpy.asarray(val)
x = T.vector()
y = T.vector()
x = T.vector('x')
y = T.vector('y')
#their is some nan that will happear in the graph for the log of the negatives values
mode = copy.copy(self.mode)
......@@ -1780,25 +1781,22 @@ class T_local_erfc(unittest.TestCase):
f = theano.function([x],T.grad(T.log(T.erfc(x)),x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==22, len(f.maker.env.nodes)
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
assert len(f.maker.env.nodes)==23, len(f.maker.env.nodes)
assert all(numpy.isfinite(f(val)))
assert f.maker.env.outputs[0].dtype==theano.config.floatX
#test with a different mul constant
f = theano.function([x],T.mul(T.exp(T.neg(T.sqr(x))),-10.12837917)/T.erfc(x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==22, len(f.maker.env.nodes)
assert len(f.maker.env.nodes)==23, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
assert all(numpy.isfinite(f(val)))
#test that we work without the mul
f = theano.function([x],T.exp(T.neg(T.sqr(x)))/T.erfc(x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==21, len(f.maker.env.nodes)
assert len(f.maker.env.nodes)==23, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
assert all(numpy.isfinite(f(val)))
#test that we don't work if x!=y
......@@ -1806,21 +1804,19 @@ class T_local_erfc(unittest.TestCase):
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==5, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
f(val,val-3)
#test that we work without the sqr and neg
f = theano.function([x],T.exp(T.mul(-1,x,x))/T.erfc(x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==21, len(f.maker.env.nodes)
assert len(f.maker.env.nodes)==22, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
assert all(numpy.isfinite(f(val)))
f = theano.function([x],T.grad(T.log(T.erfc(x)),x), mode=mode_fusion)
assert len(f.maker.env.nodes)==1, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert not any([hasattr(n.op,'scalar_op') and n.op.scalar_op==scal.pow for n in f.maker.env.nodes])
#TODO: fix this problem
if theano.config.floatX=="float32" and theano.config.mode in ["DebugMode", "DEBUG_MODE"]:
#Showing this test error is a duplicate of the one in test_local_log_erfc. We hide it.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论