提交 5d1a9db8 authored 作者: Frederic's avatar Frederic

Generalize local_grad_log_erfc_neg as it is now needed by other opt change

上级 037fca35
......@@ -5523,8 +5523,8 @@ def constant_folding(node):
return rval
topo_constant_folding=in2out(constant_folding, ignore_newtrees=False,
name="topo_constant_folding")
topo_constant_folding = in2out(constant_folding, ignore_newtrees=False,
name="topo_constant_folding")
register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
......@@ -5769,7 +5769,7 @@ def local_log_erfc(node):
# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
# for float64: threshold=26.63 see at the end of the fct for the explaination
# for float32: threshold=9.3 see at the end of the fct for the explaination
# TODO: remove the contraint that there are only 2 inputs to mul and exp(x**2)
# TODO: remove the contraint that there are only 2 inputs to exp(x**2)
# is the second.
# TODO: at the test point 10 in float32, there is instability in the original
# value. The original gives -30.0, the stab -20.1 and in float64 -18.1.
......@@ -5796,14 +5796,17 @@ def local_grad_log_erfc_neg(node):
exp = node.inputs[0]
else:
mul = node.inputs[0]
if mul.owner.inputs[0].owner or len(mul.owner.inputs) != 2:
return False
y = mul.owner.inputs[0]
if (not mul.owner.inputs[1].owner or
mul.owner.inputs[1].owner.op != T.exp):
return False
exp = mul.owner.inputs[1]
exp = None
for idx, inp in enumerate(mul.owner.inputs):
if inp.owner and inp.owner.op == T.exp:
exp = inp
break
if len(mul.owner.inputs) == 2:
y = [mul.owner.inputs[1-idx]]
else:
y = mul.owner.inputs[:]
del y[idx]
del mul
if not exp.owner.inputs[0].owner:
return False
......@@ -5905,9 +5908,10 @@ def local_grad_log_erfc_neg(node):
# threshold = 10.1
elif x.dtype == 'float64':
threshold = 26.641747557
ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y
ret = T.switch(x < threshold, true_div_no_mul, stab_value)
if y != 1:
ret = T.mul(ret, *y)
ret.values_eq_approx = values_eq_approx_remove_inf_nan
return [ret]
"""
The libm used for the test is amdlibm
......
......@@ -4532,7 +4532,8 @@ class T_local_erfc(unittest.TestCase):
mode_fusion.check_isfinite = False
f = theano.function([x], T.grad(T.log(T.erfc(x)).sum(), x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
# The useless alloc in the graph will get removed by later optimization
assert len(f.maker.fgraph.apply_nodes) == 25, len(f.maker.fgraph.apply_nodes)
assert all(numpy.isfinite(f(val)))
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
......@@ -4545,7 +4546,7 @@ class T_local_erfc(unittest.TestCase):
# test that we work without the mul
f = theano.function([x], T.exp(T.neg(T.sqr(x))) / T.erfc(x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val)))
......@@ -4558,14 +4559,15 @@ class T_local_erfc(unittest.TestCase):
# test that we work without the sqr and neg
f = theano.function([x], T.exp(T.mul(-1, x, x)) / T.erfc(x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes)
assert len(f.maker.fgraph.apply_nodes) == 21, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val)))
# test that it work correctly if x is x*2 in the graph.
f = theano.function([x], T.grad(T.log(T.erfc(2 * x)).sum(),
x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
# The useless alloc in the graph will get removed by later optimization
assert len(f.maker.fgraph.apply_nodes) == 25, len(f.maker.fgraph.apply_nodes)
assert numpy.isfinite(f(val)).all()
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论