提交 1adbf0a3 authored 作者: James Bergstra's avatar James Bergstra

tensor.basic - Changed E_grad from a string to an Exception subclass.

上级 19fa61d3
......@@ -4004,7 +4004,19 @@ def verify_grad(op, pt, n_tests=2, rng=None, eps=None, tol=None, mode=None, cast
max_err, max_err_pos = num_grad.max_err(analytic_grad)
if max_err > tol:
raise Exception(verify_grad.E_grad, (max_err, tol, max_err_pos))
raise verify_grad.E_grad(tol, num_grad, analytic_grad)
class GradientError(Exception):
"""This error is raised when a gradient is calculated, but incorrect."""
def __init__(self, tol, num_grad, analytic_grad):
self.num_grad = num_grad
self.analytic_grad = analytic_grad
self.tol = tol
def __str__(self):
max_errs = [numpy.max(e) for e in self.num_grad.abs_rel_errors(self.analytic_grad)]
return "GradientError: numeric gradient and analytic gradient differ than %f (%s)" %(
self.tol, max_errs)
def abs_rel_errors(self):
return self.num_grad.abs_rel_errors(self.analytic_grad)
verify_grad.E_grad = 'gradient error exceeded tolerance'
"""This error is raised when a gradient is calculated, but incorrect."""
verify_grad.E_grad = GradientError
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论