提交 8d803f8c authored 作者: James Bergstra's avatar James Bergstra

Modified tensor.basic._allclose used in values_eq_approx to be more lenient with

single-precision floats.
上级 b4ccfa60
...@@ -199,6 +199,15 @@ def _wrap_tensor_into_member(x): ...@@ -199,6 +199,15 @@ def _wrap_tensor_into_member(x):
return compile.module.Member(constant(x)) return compile.module.Member(constant(x))
compile.module.register_wrapper(_obj_is_wrappable_as_tensor, _wrap_tensor_into_member) compile.module.register_wrapper(_obj_is_wrappable_as_tensor, _wrap_tensor_into_member)
def _allclose(a, b):
narrow = 'float32', 'complex64'
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol = 1e-5
rtol = 1e-3 # Sensible??
return numpy.allclose(a,b, atol=atol, rtol=rtol)
else:
# keep defaults of in numpy.allclose
return numpy.allclose(a,b)
class TensorType(Type): class TensorType(Type):
"""Symbolic `Type` representing a numpy.ndarray value.""" """Symbolic `Type` representing a numpy.ndarray value."""
...@@ -299,12 +308,9 @@ class TensorType(Type): ...@@ -299,12 +308,9 @@ class TensorType(Type):
# following two lines, that may seem weird at first glance. # following two lines, that may seem weird at first glance.
# If someone can figure out what it is, please say it here! # If someone can figure out what it is, please say it here!
ones = numpy.ones(2) ones = numpy.ones(2)
return numpy.allclose(ones * a, ones*b) return _allclose(ones * a, ones*b)
#elif str(a.dtype).startswith('complex'):
# print >> sys.stderr, 'WARNING: skipping comparison of complex'
# return True
else: else:
cmp = numpy.allclose(a,b) cmp = _allclose(a, b)
if cmp: if cmp:
# Numpy claims they are close, this is good enough for us. # Numpy claims they are close, this is good enough for us.
return True return True
...@@ -320,6 +326,9 @@ class TensorType(Type): ...@@ -320,6 +326,9 @@ class TensorType(Type):
if not a_missing.any(): if not a_missing.any():
# There are no missing values in a, thus this is not the # There are no missing values in a, thus this is not the
# reason why numpy.allclose(a, b) returned False. # reason why numpy.allclose(a, b) returned False.
_info('numpy allclose failed for abs_err %f and rel_err %f' %(
numpy.max( abs(a-b)),
numpy.max( abs(a-b)/(abs(a)+abs(b)))))
return False return False
# The following line is what numpy.allclose bases its decision # The following line is what numpy.allclose bases its decision
# upon, according to its documentation. # upon, according to its documentation.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论