提交 25686186 authored 作者: James Bergstra's avatar James Bergstra

elaborated TensorType.values_eq to implement NaN==NaN

上级 78e1dec1
......@@ -321,7 +321,16 @@ class TensorType(Type):
return False
if a.dtype != b.dtype:
return False
return numpy.all(a==b)
a_eq_b = (a==b)
r = numpy.all(a_eq_b)
if r: return True
# maybe the trouble is that there are NaNs
a_missing = numpy.isnan(a)
if a_missing.any():
b_missing = numpy.isnan(b)
return numpy.all(a_eq_b + (a_missing == b_missing))
else:
return False
@staticmethod
def values_eq_approx(a, b):
if type(a) is numpy.ndarray and type(b) is numpy.ndarray:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论