提交 62fcda84 authored 作者: Frederic's avatar Frederic

Reuse test_comparison in gpuarray to test sgn() isnan on GPU. Make gpuarray…

Reuse test_comparison in gpuarray to test sgn() isnan on GPU. Make gpuarray better handle nan. At the same time, enable c code for EQ with complex.
上级 c50df515
......@@ -384,6 +384,14 @@ class G_reshape(test_basic.T_reshape):
assert self.op == GpuReshape
class G_comparison(test_basic.test_comparison):
def setUp(self):
utt.seed_rng()
self.mode = mode_with_gpu
self.shared = gpuarray_shared_constructor
self.dtypes = ['float64', 'float32']
class G_Join_and_Split(test_basic.T_Join_and_Split):
def setUp(self):
super(G_Join_and_Split, self).setUp()
......
......@@ -121,7 +121,20 @@ class GpuArrayType(Type):
return False
if a.typecode != b.typecode:
return False
return numpy.asarray(compare(a, '==', b)).all()
a_eq_b = numpy.asarray(compare(a, '==', b))
if a_eq_b.all():
return True
# maybe the trouble is that there are NaNs
a = numpy.asarray(a)
b = numpy.asarray(b)
a_missing = numpy.isnan(a)
if a_missing.any():
b_missing = numpy.isnan(b)
return numpy.all(a_eq_b + (a_missing == b_missing))
else:
return False
@staticmethod
def values_eq_approx(a, b,
......@@ -157,7 +170,15 @@ class GpuArrayType(Type):
op_tmpl="res[i] = (fabs(%%(a)s - %%(b)s) <"
"(%(atol_)s + %(rtol_)s * fabs(%%(b)s)))" %
locals())
return numpy.asarray(res).all()
ret = numpy.asarray(res).all()
if ret:
return True
# maybe the trouble is that there are NaNs
an = numpy.asarray(a)
bn = numpy.asarray(b)
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
@staticmethod
def may_share_memory(a, b):
......
......@@ -1096,8 +1096,6 @@ class EQ(LogicalComparison):
def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s == %(y)s);" % locals()
eq = EQ()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论