提交 5d1a9db8 authored 作者: Frederic's avatar Frederic

Generalize local_grad_log_erfc_neg as it is now needed by other opt change

上级 037fca35
...@@ -5523,8 +5523,8 @@ def constant_folding(node): ...@@ -5523,8 +5523,8 @@ def constant_folding(node):
return rval return rval
topo_constant_folding=in2out(constant_folding, ignore_newtrees=False, topo_constant_folding = in2out(constant_folding, ignore_newtrees=False,
name="topo_constant_folding") name="topo_constant_folding")
register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True) register_canonicalize(topo_constant_folding, 'fast_compile', final_opt=True)
register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True) register_stabilize(topo_constant_folding, 'fast_compile', final_opt=True)
register_specialize(topo_constant_folding, 'fast_compile', final_opt=True) register_specialize(topo_constant_folding, 'fast_compile', final_opt=True)
...@@ -5769,7 +5769,7 @@ def local_log_erfc(node): ...@@ -5769,7 +5769,7 @@ def local_log_erfc(node):
# sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))) # sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
# for float64: threshold=26.63 see at the end of the fct for the explaination # for float64: threshold=26.63 see at the end of the fct for the explaination
# for float32: threshold=9.3 see at the end of the fct for the explaination # for float32: threshold=9.3 see at the end of the fct for the explaination
# TODO: remove the contraint that there are only 2 inputs to mul and exp(x**2) # TODO: remove the contraint that there are only 2 inputs to exp(x**2)
# is the second. # is the second.
# TODO: at the test point 10 in float32, there is instability in the original # TODO: at the test point 10 in float32, there is instability in the original
# value. The original gives -30.0, the stab -20.1 and in float64 -18.1. # value. The original gives -30.0, the stab -20.1 and in float64 -18.1.
...@@ -5796,14 +5796,17 @@ def local_grad_log_erfc_neg(node): ...@@ -5796,14 +5796,17 @@ def local_grad_log_erfc_neg(node):
exp = node.inputs[0] exp = node.inputs[0]
else: else:
mul = node.inputs[0] mul = node.inputs[0]
if mul.owner.inputs[0].owner or len(mul.owner.inputs) != 2: exp = None
return False for idx, inp in enumerate(mul.owner.inputs):
y = mul.owner.inputs[0] if inp.owner and inp.owner.op == T.exp:
if (not mul.owner.inputs[1].owner or exp = inp
mul.owner.inputs[1].owner.op != T.exp): break
return False if len(mul.owner.inputs) == 2:
exp = mul.owner.inputs[1] y = [mul.owner.inputs[1-idx]]
else:
y = mul.owner.inputs[:]
del y[idx]
del mul
if not exp.owner.inputs[0].owner: if not exp.owner.inputs[0].owner:
return False return False
...@@ -5905,9 +5908,10 @@ def local_grad_log_erfc_neg(node): ...@@ -5905,9 +5908,10 @@ def local_grad_log_erfc_neg(node):
# threshold = 10.1 # threshold = 10.1
elif x.dtype == 'float64': elif x.dtype == 'float64':
threshold = 26.641747557 threshold = 26.641747557
ret = T.switch(x < threshold, true_div_no_mul, stab_value) * y ret = T.switch(x < threshold, true_div_no_mul, stab_value)
if y != 1:
ret = T.mul(ret, *y)
ret.values_eq_approx = values_eq_approx_remove_inf_nan ret.values_eq_approx = values_eq_approx_remove_inf_nan
return [ret] return [ret]
""" """
The libm used for the test is amdlibm The libm used for the test is amdlibm
......
...@@ -4532,7 +4532,8 @@ class T_local_erfc(unittest.TestCase): ...@@ -4532,7 +4532,8 @@ class T_local_erfc(unittest.TestCase):
mode_fusion.check_isfinite = False mode_fusion.check_isfinite = False
f = theano.function([x], T.grad(T.log(T.erfc(x)).sum(), x), mode=mode) f = theano.function([x], T.grad(T.log(T.erfc(x)).sum(), x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes) # The useless alloc in the graph will get removed by later optimization
assert len(f.maker.fgraph.apply_nodes) == 25, len(f.maker.fgraph.apply_nodes)
assert all(numpy.isfinite(f(val))) assert all(numpy.isfinite(f(val)))
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
...@@ -4545,7 +4546,7 @@ class T_local_erfc(unittest.TestCase): ...@@ -4545,7 +4546,7 @@ class T_local_erfc(unittest.TestCase):
# test that we work without the mul # test that we work without the mul
f = theano.function([x], T.exp(T.neg(T.sqr(x))) / T.erfc(x), mode=mode) f = theano.function([x], T.exp(T.neg(T.sqr(x))) / T.erfc(x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val))) assert all(numpy.isfinite(f(val)))
...@@ -4558,14 +4559,15 @@ class T_local_erfc(unittest.TestCase): ...@@ -4558,14 +4559,15 @@ class T_local_erfc(unittest.TestCase):
# test that we work without the sqr and neg # test that we work without the sqr and neg
f = theano.function([x], T.exp(T.mul(-1, x, x)) / T.erfc(x), mode=mode) f = theano.function([x], T.exp(T.mul(-1, x, x)) / T.erfc(x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 21, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
assert all(numpy.isfinite(f(val))) assert all(numpy.isfinite(f(val)))
# test that it work correctly if x is x*2 in the graph. # test that it work correctly if x is x*2 in the graph.
f = theano.function([x], T.grad(T.log(T.erfc(2 * x)).sum(), f = theano.function([x], T.grad(T.log(T.erfc(2 * x)).sum(),
x), mode=mode) x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes) # The useless alloc in the graph will get removed by later optimization
assert len(f.maker.fgraph.apply_nodes) == 25, len(f.maker.fgraph.apply_nodes)
assert numpy.isfinite(f(val)).all() assert numpy.isfinite(f(val)).all()
assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX assert f.maker.fgraph.outputs[0].dtype == theano.config.floatX
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论