提交 6361bb2b authored 作者: Frederic's avatar Frederic

Allow to pass rtol and atol to TensorType.values_eq_approx

上级 f74426f1
......@@ -922,7 +922,8 @@ class TensorType(Type):
return False
@staticmethod
def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False):
def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False,
rtol=None, atol=None):
"""
:param allow_remove_inf: If True, when there is an inf in a,
we allow any value in b in that position.
......@@ -930,6 +931,8 @@ class TensorType(Type):
:param allow_remove_nan: If True, when there is a nan in a,
we allow any value in b in that position.
Event +-inf
:param rtol: relative tolerance, passed to _allclose
:param atol: absolute tolerance, passed to _allclose
"""
if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
if a.shape != b.shape:
......@@ -945,7 +948,7 @@ class TensorType(Type):
a = a.reshape(1)
b = b.reshape(1)
cmp = _allclose(a, b)
cmp = _allclose(a, b, rtol=rtol, atol=atol)
if cmp:
# Numpy claims they are close, this is good enough for us.
return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论