提交 943bddf4 authored 作者: James Bergstra's avatar James Bergstra

added support for ndarray subtypes in values_eq_approx

上级 a8171dd2
...@@ -497,19 +497,13 @@ class TensorType(Type): ...@@ -497,19 +497,13 @@ class TensorType(Type):
we allow any value in b in that position. we allow any value in b in that position.
Event +-inf Event +-inf
""" """
if type(a) is numpy.ndarray and type(b) is numpy.ndarray: if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
if a.shape != b.shape: if a.shape != b.shape:
return False return False
if a.dtype != b.dtype: if a.dtype != b.dtype:
return False return False
if 'int' in str(a.dtype): if 'int' in str(a.dtype):
return numpy.all(a==b) return numpy.all(a==b)
#elif a.shape == (): #for comparing scalars, use broadcasting.
## Note: according to James B, there was a reason for the
## following two lines, that may seem weird at first glance.
## If someone can figure out what it is, please say it here!
#ones = numpy.ones(2)
#return _allclose(ones * a, ones*b) ### dtype handling is wrong here
else: else:
cmp = _allclose(a, b) cmp = _allclose(a, b)
if cmp: if cmp:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论