提交 8bd24b6b authored 作者: Frederic Bastien's avatar Frederic Bastien

Added the allow_remove_nan paramter to TensorType.values_eq_approx

上级 7886aa24
...@@ -488,11 +488,14 @@ class TensorType(Type): ...@@ -488,11 +488,14 @@ class TensorType(Type):
else: else:
return False return False
@staticmethod @staticmethod
def values_eq_approx(a, b, allow_remove_inf = False): def values_eq_approx(a, b, allow_remove_inf = True, allow_remove_nan = False):
""" """
:param allow_remove_inf: If True, when their is an inf in a, :param allow_remove_inf: If True, when their is an inf in a,
we allow any value in b in that position. we allow any value in b in that position.
Event -inf Event -inf
:param allow_remove_nan: If True, when their is a nan in a,
we allow any value in b in that position.
Event +-inf
""" """
if type(a) is numpy.ndarray and type(b) is numpy.ndarray: if type(a) is numpy.ndarray and type(b) is numpy.ndarray:
if a.shape != b.shape: if a.shape != b.shape:
...@@ -551,6 +554,8 @@ class TensorType(Type): ...@@ -551,6 +554,8 @@ class TensorType(Type):
if allow_remove_inf: if allow_remove_inf:
both_inf += a_inf both_inf += a_inf
if allow_remove_nan:
both_missing += a_missing
# Combine all information. # Combine all information.
return (cmp_elemwise + both_missing + both_inf).all() return (cmp_elemwise + both_missing + both_inf).all()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论