提交 3c5d9bde authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5760 from nouiz/opt_crash_fix

Fix opt crash fix and more information when verify_grad detect an error.
......@@ -1712,6 +1712,9 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
if max_abs_err > abs_tol and max_rel_err > rel_tol:
raise verify_grad.E_grad(max_arg, max_err_pos,
analytic_grad[max_arg].shape,
analytic_grad[max_arg].flatten()[max_err_pos],
num_grad.gf[max_arg].flatten()[max_err_pos],
max_abs_err, max_rel_err,
abs_tol, rel_tol)
......@@ -1727,10 +1730,14 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
class GradientError(Exception):
"""This error is raised when a gradient is calculated, but incorrect."""
def __init__(self, arg, err_pos, abs_err, rel_err, abs_tol, rel_tol):
def __init__(self, arg, err_pos, shape, val1, val2,
abs_err, rel_err, abs_tol, rel_tol):
Exception.__init__(self) # to be compatible with python2.4
self.arg = arg
self.err_pos = err_pos
self.shape = shape
self.val1 = val1
self.val2 = val2
self.abs_err = abs_err
self.rel_err = rel_err
self.abs_tol = abs_tol
......@@ -1741,10 +1748,13 @@ class GradientError(Exception):
args_msg = ", ".join(str(a) for a in self.args)
return """\
GradientError: numeric gradient and analytic gradient exceed tolerance:
At position %i of argument %i,
At position %i of argument %i with shape %s,
val1 = %f , val2 = %f
abs. error = %f, abs. tolerance = %f
rel. error = %f, rel. tolerance = %f
Exception args: %s""" % (self.err_pos, self.arg,
self.shape,
self.val1, self.val2,
self.abs_err, self.abs_tol,
self.rel_err, self.rel_tol,
args_msg)
......
......@@ -876,10 +876,13 @@ class Validator(object):
if out.owner is None:
if isinstance(out, tensor.TensorConstant):
if hasattr(out, 'fgraph'):
if hasattr(out, 'fgraph') or getattr(out, 'cached', False):
# If out have an fgraph, we aren't sure if it
# is from the inner graph or outer graph, so
# clone it.
# As it will be used as is in an FunctionGraph
# (won't be cloned later), it can't be a
# cached variable
cloned_out = out.clone()
self.valid.add(cloned_out)
self.invalid.add(out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论