提交 3f734f94 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Unify how atol and rtol are determined.

上级 ef9e88ed
......@@ -162,13 +162,7 @@ class GpuArrayType(Type):
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
narrow = 'float32', 'complex64'
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol_ = theano.tensor.basic.float32_atol
rtol_ = theano.tensor.basic.float32_rtol
else:
atol_ = theano.tensor.basic.float64_atol
rtol_ = theano.tensor.basic.float64_rtol
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if rtol is not None:
rtol_ = rtol
if atol is not None:
......
......@@ -459,18 +459,29 @@ if int(config.tensor.cmp_sloppy) > 1:
# When config.tensor.cmp_sloppy>1 we are even more sloppy. This is
# useful to test the GPU as they don't use extended precision and
# this cause some difference bigger then the normal sloppy.
float16_atol = 5e-3
float16_rtol = 1e-2
float32_atol = 5e-4
float32_rtol = 1e-3
float64_rtol = 1e-4
float64_atol = 1e-3
elif int(config.tensor.cmp_sloppy):
float16_atol = 1e-3
float16_rtol = 5e-3
float32_atol = 1e-4
float32_rtol = 1e-3
float64_rtol = 1e-4
float64_atol = 1e-3
else:
# If you change those value in test don't forget to put them back
# when the test end. Don't forget the case when the test fail.
float16_atol = 5e-4
float16_rtol = 5e-4
float32_atol = 1e-5
float32_rtol = 1e-5
......@@ -481,16 +492,25 @@ else:
float64_rtol = 1.0000000000000001e-06
def _get_atol_rtol(a, b):
tiny = ('float16',)
narrow = ('float32', 'complex64')
if (str(a.dtype) in tiny) or (str(b.dtype) in tiny):
atol = float16_atol
rtol = float16_rtol
elif (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol = float32_atol
rtol = float32_rtol
else:
atol = float64_atol
rtol = float64_rtol
return atol, rtol
def _allclose(a, b, rtol=None, atol=None):
a = numpy.asarray(a)
b = numpy.asarray(b)
narrow = 'float32', 'complex64'
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol_ = float32_atol
rtol_ = float32_rtol
else:
atol_ = float64_atol
rtol_ = float64_rtol
atol_, rtol_ = _get_atol_rtol(a, b)
if rtol is not None:
rtol_ = rtol
if atol is not None:
......
......@@ -309,15 +309,7 @@ def str_diagnostic(expected, value, rtol, atol):
print(ssio.getvalue(), file=sio)
except Exception:
pass
# Use the same formula as in _allclose to find the tolerance used
narrow = 'float32', 'complex64'
if ((str(expected.dtype) in narrow) or
(str(value.dtype) in narrow)):
atol_ = T.basic.float32_atol
rtol_ = T.basic.float32_rtol
else:
atol_ = T.basic.float64_atol
rtol_ = T.basic.float64_rtol
atol_, rtol_ = T.basic._get_atol_rtol(expected, value)
if rtol is not None:
rtol_ = rtol
if atol is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论