提交 24fd08b6 authored 作者: Frederic Bastien's avatar Frederic Bastien

make the stabilization optimization of log(erfc(x)) work.

上级 f4029f30
......@@ -2707,6 +2707,10 @@ def local_grad_log_erfc_neg(node):
sqr = neg.owner.inputs[0]
x = sqr.owner.inputs[0]
elif exp.owner.inputs[0].owner.op == T.mul:
#We should compare that -(erfc_x**2) is equivalent to mul_neg
#Their is currently not easy way to do this in the general case
#So we implement some common case for now.
#in many case the neg are replaced by mul in the graph.
#This also allow to stabilize log(erfc(cst*x))
mul_neg = exp.owner.inputs[0]
......@@ -2735,11 +2739,10 @@ def local_grad_log_erfc_neg(node):
mul_neg = T.mul(*mul_inputs)
try:
cst = get_constant_value(mul_neg.owner.inputs[0])
cst2 = get_constant_value(mul_neg.owner.inputs[0])
except TypeError:
return False
if cst!=-1:
return False#todo implement that case
if len(mul_neg.owner.inputs) == 2:
if not mul_neg.owner.inputs[1].owner or mul_neg.owner.inputs[1].owner.op != T.sqr:
return False
......@@ -2752,6 +2755,26 @@ def local_grad_log_erfc_neg(node):
else:
return False
if cst2!=-1:
if (not erfc_x.owner or erfc_x.owner.op != T.mul
or len(erfc_x.owner.inputs)!=2):
#todo implement that case
return False
if erfc_x.owner.inputs[1] is not mul_neg.owner.inputs[1]:
return False
x = erfc_x
try:
cst = get_constant_value(erfc_x.owner.inputs[0])
except TypeError:
return False
if cst2 != -cst*2:
return False
#The constant is valid. Must check that the
elif erfc_x is not x:
return False
else:
return False
......@@ -2759,8 +2782,6 @@ def local_grad_log_erfc_neg(node):
#We use that flag to don't apply the optimization recursively
return False
if erfc_x is not x:
return False
#we move the y outside the div.
true_div_no_mul = T.true_div(exp,erfc)
......
......@@ -1863,6 +1863,13 @@ class T_local_erfc(unittest.TestCase):
assert f.maker.env.outputs[0].dtype==theano.config.floatX
assert all(numpy.isfinite(f(val)))
#test that it work correctly if x is x*2 in the graph.
f = theano.function([x],T.grad(T.log(T.erfc(2*x)),x), mode=mode)
#theano.printing.debugprint(f)
assert len(f.maker.env.nodes)==23, len(f.maker.env.nodes)
assert numpy.isfinite(f(val)).all()
assert f.maker.env.outputs[0].dtype==theano.config.floatX
f = theano.function([x],T.grad(T.log(T.erfc(x)),x), mode=mode_fusion)
assert len(f.maker.env.nodes)==1, len(f.maker.env.nodes)
assert f.maker.env.outputs[0].dtype==theano.config.floatX
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论