提交 25aca3fa authored 作者: Frederic Bastien's avatar Frederic Bastien

make TensorType.values_eq_approx handle correctly case when their is a mix of…

make TensorType.values_eq_approx handle correctly case when their is a mix of inf and nan. When their is nan, we want to accept values that have the same inf.
上级 248919d0
...@@ -492,6 +492,7 @@ class TensorType(Type): ...@@ -492,6 +492,7 @@ class TensorType(Type):
""" """
:param allow_remove_inf: If True, when their is an inf in a, :param allow_remove_inf: If True, when their is an inf in a,
we allow any value in b in that position. we allow any value in b in that position.
Event -inf
""" """
if type(a) is numpy.ndarray and type(b) is numpy.ndarray: if type(a) is numpy.ndarray and type(b) is numpy.ndarray:
if a.shape != b.shape: if a.shape != b.shape:
...@@ -537,11 +538,22 @@ class TensorType(Type): ...@@ -537,11 +538,22 @@ class TensorType(Type):
(atol + rtol * numpy.absolute(b))) (atol + rtol * numpy.absolute(b)))
# Find places where both a and b have missing values. # Find places where both a and b have missing values.
both_missing = a_missing * numpy.isnan(b) both_missing = a_missing * numpy.isnan(b)
# Find places where both a and b have inf of the same sign.
both_inf = a_inf * numpy.isinf(b)
#check the sign of the inf:
for idx,v in enumerate(both_inf):
if v:
both_inf[idx] = a[idx]==b[idx]
#cmp_elemwise is True when we have inf and -inf.
#So we need to override it.
cmp_elemwise[idx]=a[idx]==b[idx]
if allow_remove_inf: if allow_remove_inf:
both_missing += a_inf both_inf += a_inf
# Combine all information. # Combine all information.
return (cmp_elemwise + both_missing).all() return (cmp_elemwise + both_missing + both_inf).all()
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论