提交 6361bb2b authored 作者: Frederic's avatar Frederic

Allow to pass rtol and atol to TensorType.values_eq_approx

上级 f74426f1
...@@ -922,7 +922,8 @@ class TensorType(Type): ...@@ -922,7 +922,8 @@ class TensorType(Type):
return False return False
@staticmethod @staticmethod
def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False): def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False,
rtol=None, atol=None):
""" """
:param allow_remove_inf: If True, when there is an inf in a, :param allow_remove_inf: If True, when there is an inf in a,
we allow any value in b in that position. we allow any value in b in that position.
...@@ -930,6 +931,8 @@ class TensorType(Type): ...@@ -930,6 +931,8 @@ class TensorType(Type):
:param allow_remove_nan: If True, when there is a nan in a, :param allow_remove_nan: If True, when there is a nan in a,
we allow any value in b in that position. we allow any value in b in that position.
Event +-inf Event +-inf
:param rtol: relative tolerance, passed to _allclose
:param atol: absolute tolerance, passed to _allclose
""" """
if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray): if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
if a.shape != b.shape: if a.shape != b.shape:
...@@ -945,7 +948,7 @@ class TensorType(Type): ...@@ -945,7 +948,7 @@ class TensorType(Type):
a = a.reshape(1) a = a.reshape(1)
b = b.reshape(1) b = b.reshape(1)
cmp = _allclose(a, b) cmp = _allclose(a, b, rtol=rtol, atol=atol)
if cmp: if cmp:
# Numpy claims they are close, this is good enough for us. # Numpy claims they are close, this is good enough for us.
return True return True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论