提交 755ba97a authored 作者: abergeron's avatar abergeron

Merge pull request #3189 from nouiz/fix_nan

[BUG, TEST] Fix Sng handling of nan and fix related test
......@@ -384,6 +384,14 @@ class G_reshape(test_basic.T_reshape):
assert self.op == GpuReshape
class G_comparison(test_basic.test_comparison):
def setUp(self):
utt.seed_rng()
self.mode = mode_with_gpu
self.shared = gpuarray_shared_constructor
self.dtypes = ['float64', 'float32']
class G_Join_and_Split(test_basic.T_Join_and_Split):
def setUp(self):
super(G_Join_and_Split, self).setUp()
......
......@@ -121,7 +121,20 @@ class GpuArrayType(Type):
return False
if a.typecode != b.typecode:
return False
return numpy.asarray(compare(a, '==', b)).all()
a_eq_b = numpy.asarray(compare(a, '==', b))
if a_eq_b.all():
return True
# maybe the trouble is that there are NaNs
a = numpy.asarray(a)
b = numpy.asarray(b)
a_missing = numpy.isnan(a)
if a_missing.any():
b_missing = numpy.isnan(b)
return numpy.all(a_eq_b + (a_missing == b_missing))
else:
return False
@staticmethod
def values_eq_approx(a, b,
......@@ -157,7 +170,15 @@ class GpuArrayType(Type):
op_tmpl="res[i] = (fabs(%%(a)s - %%(b)s) <"
"(%(atol_)s + %(rtol_)s * fabs(%%(b)s)))" %
locals())
return numpy.asarray(res).all()
ret = numpy.asarray(res).all()
if ret:
return True
# maybe the trouble is that there are NaNs
an = numpy.asarray(a)
bn = numpy.asarray(b)
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
@staticmethod
def may_share_memory(a, b):
......
......@@ -1096,8 +1096,6 @@ class EQ(LogicalComparison):
def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs
(z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s == %(y)s);" % locals()
eq = EQ()
......@@ -2104,7 +2102,7 @@ class Sgn(UnaryScalarOp):
(z,) = outputs
type = node.inputs[0].type
if type in float_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals()
if type in int_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
raise TypeError() # complex has no sgn
......@@ -2112,7 +2110,7 @@ class Sgn(UnaryScalarOp):
def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version()
if s:
return (3,) + s
return (4,) + s
else: # if parent is unversioned, we are too
return s
sgn = Sgn(same_out_nocomplex, name='sgn')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论