提交 d6fc8e30 authored 作者: James Bergstra's avatar James Bergstra

fix div-by-zero in tensor_grad

上级 5a0baf88
......@@ -444,7 +444,7 @@ class numeric_grad(object):
self.gf = self.gf[0]
@staticmethod
def abs_rel_err(a,b):
def abs_rel_err(a, b, eps=1e-8):
"""Return absolute and relative error between a and b.
The relative error is a small number when a and b are close, relative to how big they are.
......@@ -455,11 +455,11 @@ class numeric_grad(object):
The tuple (abs_err, rel_err) is returned
"""
abs_err = abs(a-b)
rel_err = abs_err / (abs(a) + abs(b))
abs_err = abs(a - b)
rel_err = abs_err / (abs(a) + abs(b) + eps)
return (abs_err, rel_err)
def abs_rel_errors(self, g_pt):
def abs_rel_errors(self, g_pt, eps=1e-8):
"""Return the abs and rel error of gradient estimate `g_pt`
`g_pt` must be a list of ndarrays of the same length as self.gf,
......@@ -479,7 +479,7 @@ class numeric_grad(object):
raise ValueError(
'argument element %i has wrong shape %s' % (
i, str((a.shape, b.shape))))
errs.append(numeric_grad.abs_rel_err(a,b))
errs.append(numeric_grad.abs_rel_err(a, b, eps))
return errs
def max_err(self, g_pt, abs_tol, rel_tol):
......@@ -499,7 +499,11 @@ class numeric_grad(object):
abs_rel_errs = self.abs_rel_errors(g_pt)
for abs_err, rel_err in abs_rel_errs:
scaled_err = numpy.minimum(abs_err/abs_tol, rel_err/rel_tol)
if not numpy.all(numpy.isfinite(abs_err)):
raise ValueError('abs_err not finite', repr(abs_err))
if not numpy.all(numpy.isfinite(rel_err)):
raise ValueError('rel_err not finite', repr(rel_err))
scaled_err = numpy.minimum(abs_err / abs_tol, rel_err / rel_tol)
max_i = scaled_err.argmax()
pos.append(max_i)
......@@ -513,8 +517,8 @@ class numeric_grad(object):
return (max_arg, pos[max_arg], abs_errs[max_arg], rel_errs[max_arg])
def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=None,
mode=None, cast_to_output_type=False):
def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None,
rel_tol=None, mode=None, cast_to_output_type=False):
""" Test a gradient by Finite Difference Method. Raise error on failure.
Example:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论