提交 616f8d26 作者: Brandon T. Willard 提交者: Brandon T. Willard

Fix input validation bugs in local_grad_log_erfc_neg

上级 1cfebf4c
...@@ -2710,7 +2710,7 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2710,7 +2710,7 @@ def local_grad_log_erfc_neg(fgraph, node):
"""Stability optimization for the grad of `log(erfc(x))`. """Stability optimization for the grad of `log(erfc(x))`.
([y*]exp(-(x**2)))/erfc(x) # The y* is optional ([y*]exp(-(x**2)))/erfc(x) # The y* is optional
([y*]exp(x**2))/erfc(-x) => [y*](when x>threashold, ([y*]exp(x**2))/erfc(-x) => [y*](when x > threshold,
sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6))) sqrt(pi)*-x/(1-1/(2*x**2)+3/(4*x**4)-15/(8*x**6)))
for float64: threshold=26.63 see at the end of the fct for the explanation for float64: threshold=26.63 see at the end of the fct for the explanation
...@@ -2727,11 +2727,14 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2727,11 +2727,14 @@ def local_grad_log_erfc_neg(fgraph, node):
return False return False
if not node.inputs[1].owner or node.inputs[1].owner.op != erfc: if not node.inputs[1].owner or node.inputs[1].owner.op != erfc:
return False return False
erfc_in = node.inputs[1] erfc_in = node.inputs[1]
erfc_x = erfc_in.owner.inputs[0] erfc_x = erfc_in.owner.inputs[0]
if not node.inputs[0].owner: if not node.inputs[0].owner:
return False return False
# TODO: All of this should be replaced with a single, simple unification
# The mul is optional. # The mul is optional.
if node.inputs[0].owner.op != mul: if node.inputs[0].owner.op != mul:
mul_in = None mul_in = None
...@@ -2746,12 +2749,15 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2746,12 +2749,15 @@ def local_grad_log_erfc_neg(fgraph, node):
if inp.owner and inp.owner.op == exp: if inp.owner and inp.owner.op == exp:
exp_in = inp exp_in = inp
break break
else:
return False
if len(mul_in.owner.inputs) == 2: if len(mul_in.owner.inputs) == 2:
y = [mul_in.owner.inputs[1 - idx]] y = [mul_in.owner.inputs[1 - idx]]
else: else:
y = mul_in.owner.inputs[:] y = mul_in.owner.inputs[:]
del y[idx] del y[idx]
del mul_in
if not exp_in.owner.inputs[0].owner: if not exp_in.owner.inputs[0].owner:
return False return False
...@@ -2848,6 +2854,9 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2848,6 +2854,9 @@ def local_grad_log_erfc_neg(fgraph, node):
# We use that flag to don't apply the optimization recursively # We use that flag to don't apply the optimization recursively
return False return False
if erfc_x is not x:
return None
# we move the y outside the div. # we move the y outside the div.
true_div_no_mul = true_div(exp_in, erfc_in) true_div_no_mul = true_div(exp_in, erfc_in)
true_div_no_mul.owner.tag.local_grad_log_erfc_neg = True true_div_no_mul.owner.tag.local_grad_log_erfc_neg = True
...@@ -2864,10 +2873,14 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2864,10 +2873,14 @@ def local_grad_log_erfc_neg(fgraph, node):
# threshold = 10.1 # threshold = 10.1
elif x.dtype == "float64": elif x.dtype == "float64":
threshold = 26.641747557 threshold = 26.641747557
ret = switch(x < threshold, true_div_no_mul, stab_value) ret = switch(x < threshold, true_div_no_mul, stab_value)
if y: if y:
ret = mul(ret, *y) ret = mul(ret, *y)
ret.tag.values_eq_approx = values_eq_approx_remove_inf_nan ret.tag.values_eq_approx = values_eq_approx_remove_inf_nan
return [ret] return [ret]
......
...@@ -22,7 +22,7 @@ from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, ou ...@@ -22,7 +22,7 @@ from aesara.graph.opt import LocalOptGroup, TopoOptimizer, check_stack_trace, ou
from aesara.graph.optdb import Query from aesara.graph.optdb import Query
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import inplace from aesara.tensor import inplace
from aesara.tensor.basic import Alloc, join from aesara.tensor.basic import Alloc, join, switch
from aesara.tensor.basic_opt import local_dimshuffle_lift from aesara.tensor.basic_opt import local_dimshuffle_lift
from aesara.tensor.blas import Dot22, Gemv from aesara.tensor.blas import Dot22, Gemv
from aesara.tensor.blas_c import CGemv from aesara.tensor.blas_c import CGemv
...@@ -73,6 +73,7 @@ from aesara.tensor.math import sum as tt_sum ...@@ -73,6 +73,7 @@ from aesara.tensor.math import sum as tt_sum
from aesara.tensor.math import tan, tanh, true_div, xor from aesara.tensor.math import tan, tanh, true_div, xor
from aesara.tensor.math_opt import ( from aesara.tensor.math_opt import (
local_add_specialize, local_add_specialize,
local_grad_log_erfc_neg,
local_greedy_distributor, local_greedy_distributor,
mul_canonizer, mul_canonizer,
) )
...@@ -2839,7 +2840,14 @@ class TestLocalErfc: ...@@ -2839,7 +2840,14 @@ class TestLocalErfc:
) )
assert all(np.isfinite(f(val))) assert all(np.isfinite(f(val)))
@np.errstate(divide="ignore", invalid="ignore")
def test_local_grad_log_erfc_neg(self): def test_local_grad_log_erfc_neg(self):
# TODO: This evaluation is questionable; is the transform's math not
# already established? It doesn't look like these tests are performing
# a real numerical evaluation of the underlying math. Instead, it
# looks like they're being used as an extremely poor way of validating
# the transform results. It would be better to remove these numerical
# evaluations and confirm the transform output directly and exactly.
val = [ val = [
-100, -100,
-30, -30,
...@@ -2868,81 +2876,83 @@ class TestLocalErfc: ...@@ -2868,81 +2876,83 @@ class TestLocalErfc:
30, 30,
100, 100,
] ]
if config.mode in ["DebugMode", "DEBUG_MODE", "FAST_COMPILE"]:
# python mode don't like the inv(0) in computation,
# but the switch don't select this value.
# So it is computed for no good reason.
val.remove(0)
if config.mode in ["DebugMode", "DEBUG_MODE"] and config.floatX == "float32":
# In float32 their is a plage of values close to 10 that we stabilize as it give bigger error then the stabilized version.
# The orig value in float32 -30.0, the stab value -20.1 the orig value in float64 -18.1.
val.remove(10)
val = np.asarray(val, dtype=config.floatX) val = np.asarray(val, dtype=config.floatX)
x = vector("x") x = vector("x")
y = vector("y") y = vector("y")
# their is some nan that will happear in the graph for the log of the negatives values # Test cases for which the requisite form isn't present
mode = copy.copy(self.mode) no_matches = [
([x, y], exp(sqr(x)) / erfc(y)),
([x, y], exp(neg(x)) / erfc(y)),
([x, y], exp(x * 1) / erfc(y)),
([x, y], exp(neg(sqr(x))) / erfc(y)),
([x], mul(1.0, 2.0, x) / erfc(x)),
]
for inputs, no_match in no_matches:
fg = FunctionGraph(inputs, [no_match], clone=False)
TopoOptimizer(
LocalOptGroup(local_grad_log_erfc_neg), order="out_to_in"
).optimize(fg)
# Make sure that the graph hasn't been changed
assert fg.outputs[0] is no_match
# Some `nan`s will appear in the graph for the log of negatives values
mode = Mode("py", self.mode.optimizer)
mode.check_isfinite = False mode.check_isfinite = False
mode_fusion = copy.copy(self.mode_fusion)
mode_fusion.check_isfinite = False
f = function([x], aesara.gradient.grad(log(erfc(x)).sum(), x), mode=mode) # Make sure that we catch our target graph in a way that it's naturally
# produced
log_erfc_grad = aesara.gradient.grad(log(erfc(x)).sum(), x)
f = function([x], log_erfc_grad, mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes) # The resulting graph should be `mul(switch(...), y)`
assert f.maker.fgraph.outputs[0].owner.op == mul
assert f.maker.fgraph.outputs[0].owner.inputs[0].owner.op == switch
assert all(np.isfinite(f(val))) assert all(np.isfinite(f(val)))
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
# test with a different mul constant # Test with a different `mul` and `constant`
f = function([x], mul(exp(neg(sqr(x))), -10.12837917) / erfc(x), mode=mode) f = function([x], mul(exp(neg(sqr(x))), -10.12837917) / erfc(x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].owner.op == mul
assert f.maker.fgraph.outputs[0].owner.inputs[0].owner.op == switch
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(val))) assert all(np.isfinite(f(val)))
# test that we work without the mul # Test it works without the `mul`
f = function([x], exp(neg(sqr(x))) / erfc(x), mode=mode) f = function([x], exp(neg(sqr(x))) / erfc(x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 22, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(val)))
# test that we don't work if x!=y assert f.maker.fgraph.outputs[0].owner.op == switch
f = function([x, y], exp(neg(sqr(x))) / erfc(y), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 5, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
f(val, val - 3) assert all(np.isfinite(f(val)))
# test that we work without the sqr and neg # Test that it works without the `sqr` and `neg`
f = function([x], exp(mul(-1, x, x)) / erfc(x), mode=mode) f = function([x], exp(mul(-1, x, x)) / erfc(x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 21, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].owner.op == switch
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
assert all(np.isfinite(f(val))) assert all(np.isfinite(f(val)))
# test that it work correctly if x is x*2 in the graph. # Test that it works correctly when `x` is multiplied by a constant
f = function([x], aesara.gradient.grad(log(erfc(2 * x)).sum(), x), mode=mode) f = function([x], aesara.gradient.grad(log(erfc(2 * x)).sum(), x), mode=mode)
assert len(f.maker.fgraph.apply_nodes) == 23, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].owner.op == mul
assert f.maker.fgraph.outputs[0].owner.inputs[0].owner.op == switch
assert np.isfinite(f(val)).all() assert np.isfinite(f(val)).all()
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
# I suppose this tests whether or not the transform is applied before
# fusion?
mode_fusion = copy.copy(self.mode_fusion)
mode_fusion.check_isfinite = False
f = function([x], aesara.gradient.grad(log(erfc(x)).sum(), x), mode=mode_fusion) f = function([x], aesara.gradient.grad(log(erfc(x)).sum(), x), mode=mode_fusion)
assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes) assert len(f.maker.fgraph.apply_nodes) == 1, len(f.maker.fgraph.apply_nodes)
assert f.maker.fgraph.outputs[0].dtype == config.floatX assert f.maker.fgraph.outputs[0].dtype == config.floatX
# TODO: fix this problem
if config.floatX == "float32" and config.mode in [
"DebugMode",
"DEBUG_MODE",
]:
# The python code upcast somewhere internally some value of float32
# to python float for part of its computation. That make that the c
# and python code do not generate the same value. You can ignore
# this error. This happen in an intermediate step that don't show
# in the final result.
# Showing this test error is a duplicate of the one in test_local_log_erfc. We hide it.
pass
else:
assert all(np.isfinite(f(val)))
def speed_local_log_erfc(self): def speed_local_log_erfc(self):
val = np.random.rand(1e6) val = np.random.rand(1e6)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论