提交 8bd24b6b authored 作者: Frederic Bastien's avatar Frederic Bastien

Added the allow_remove_nan paramter to TensorType.values_eq_approx

上级 7886aa24
......@@ -488,11 +488,14 @@ class TensorType(Type):
else:
return False
@staticmethod
def values_eq_approx(a, b, allow_remove_inf = False):
def values_eq_approx(a, b, allow_remove_inf = True, allow_remove_nan = False):
"""
:param allow_remove_inf: If True, when their is an inf in a,
we allow any value in b in that position.
Event -inf
:param allow_remove_nan: If True, when their is a nan in a,
we allow any value in b in that position.
Event +-inf
"""
if type(a) is numpy.ndarray and type(b) is numpy.ndarray:
if a.shape != b.shape:
......@@ -551,6 +554,8 @@ class TensorType(Type):
if allow_remove_inf:
both_inf += a_inf
if allow_remove_nan:
both_missing += a_missing
# Combine all information.
return (cmp_elemwise + both_missing + both_inf).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论