提交 7a32c2df authored 作者: Frederic Bastien's avatar Frederic Bastien

allow small rouding error for test of erf,erfc and their inplace variant.

上级 c396b0ca
......@@ -58,11 +58,11 @@ def safe_make_node(op, *inputs):
else:
return node.owner
def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}, mode = None, grad_rtol=None):
def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_runtime = {}, grad = {}, mode = None, grad_rtol=None, eps = 1e-10):
if grad is True:
grad = good
_op, _expected, _checks, _good, _bad_build, _bad_runtime, _grad, _mode, _grad_rtol = op, expected, checks, good, bad_build, bad_runtime, grad, mode, grad_rtol
_op, _expected, _checks, _good, _bad_build, _bad_runtime, _grad, _mode, _grad_rtol, _eps = op, expected, checks, good, bad_build, bad_runtime, grad, mode, grad_rtol, eps
class Checker(unittest.TestCase):
......@@ -106,6 +106,10 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
expecteds = self.expected(*inputs)
eps = 1e-10
if any([i.dtype=='float32' for i in inputs]):
eps=8e-6#1e-6
eps = numpy.max([eps,_eps])
try:
variables = f(*inputs)
except:
......@@ -117,9 +121,7 @@ def makeTester(name, op, expected, checks = {}, good = {}, bad_build = {}, bad_r
if not isinstance(expecteds, (list, tuple)):
expecteds = (expecteds, )
if any([i.dtype=='float32' for i in inputs]):
eps=8e-6#1e-6
else: eps = 1e-10
for i, (variable, expected) in enumerate(zip(variables, expecteds)):
if variable.dtype != expected.dtype or variable.shape != expected.shape or \
numpy.any(numpy.abs(variable - expected) > eps):
......@@ -665,12 +667,14 @@ ErfTester = makeBroadcastTester(op = erf,
expected = expected,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal,
eps = 2e-10,
mode = mode_no_scipy)
ErfInplaceTester = makeBroadcastTester(op = inplace.erf_inplace,
expected = expected,
good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal,
mode = mode_no_scipy,
eps = 2e-10,
inplace = True)
if imported_scipy_special:
......@@ -690,11 +694,13 @@ ErfcTester = makeBroadcastTester(op = erfc,
expected = expected,
good = _good_broadcast_unary_normal,
grad = _grad_broadcast_unary_normal,
eps = 2e-10,
mode = mode_no_scipy)
ErfcInplaceTester = makeBroadcastTester(op = inplace.erfc_inplace,
expected = expected,
good = _good_broadcast_unary_normal_no_int,
grad = _grad_broadcast_unary_normal,
eps = 2e-10,
mode = mode_no_scipy,
inplace = True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论