提交 755ba97a authored 作者: abergeron's avatar abergeron

Merge pull request #3189 from nouiz/fix_nan

[BUG, TEST] Fix Sng handling of nan and fix related test
...@@ -384,6 +384,14 @@ class G_reshape(test_basic.T_reshape): ...@@ -384,6 +384,14 @@ class G_reshape(test_basic.T_reshape):
assert self.op == GpuReshape assert self.op == GpuReshape
class G_comparison(test_basic.test_comparison):
def setUp(self):
utt.seed_rng()
self.mode = mode_with_gpu
self.shared = gpuarray_shared_constructor
self.dtypes = ['float64', 'float32']
class G_Join_and_Split(test_basic.T_Join_and_Split): class G_Join_and_Split(test_basic.T_Join_and_Split):
def setUp(self): def setUp(self):
super(G_Join_and_Split, self).setUp() super(G_Join_and_Split, self).setUp()
......
...@@ -121,7 +121,20 @@ class GpuArrayType(Type): ...@@ -121,7 +121,20 @@ class GpuArrayType(Type):
return False return False
if a.typecode != b.typecode: if a.typecode != b.typecode:
return False return False
return numpy.asarray(compare(a, '==', b)).all() a_eq_b = numpy.asarray(compare(a, '==', b))
if a_eq_b.all():
return True
# maybe the trouble is that there are NaNs
a = numpy.asarray(a)
b = numpy.asarray(b)
a_missing = numpy.isnan(a)
if a_missing.any():
b_missing = numpy.isnan(b)
return numpy.all(a_eq_b + (a_missing == b_missing))
else:
return False
@staticmethod @staticmethod
def values_eq_approx(a, b, def values_eq_approx(a, b,
...@@ -157,7 +170,15 @@ class GpuArrayType(Type): ...@@ -157,7 +170,15 @@ class GpuArrayType(Type):
op_tmpl="res[i] = (fabs(%%(a)s - %%(b)s) <" op_tmpl="res[i] = (fabs(%%(a)s - %%(b)s) <"
"(%(atol_)s + %(rtol_)s * fabs(%%(b)s)))" % "(%(atol_)s + %(rtol_)s * fabs(%%(b)s)))" %
locals()) locals())
return numpy.asarray(res).all() ret = numpy.asarray(res).all()
if ret:
return True
# maybe the trouble is that there are NaNs
an = numpy.asarray(a)
bn = numpy.asarray(b)
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
@staticmethod @staticmethod
def may_share_memory(a, b): def may_share_memory(a, b):
......
...@@ -1096,8 +1096,6 @@ class EQ(LogicalComparison): ...@@ -1096,8 +1096,6 @@ class EQ(LogicalComparison):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x, y) = inputs (x, y) = inputs
(z,) = outputs (z,) = outputs
if node.inputs[0].type in complex_types:
raise NotImplementedError()
return "%(z)s = (%(x)s == %(y)s);" % locals() return "%(z)s = (%(x)s == %(y)s);" % locals()
eq = EQ() eq = EQ()
...@@ -2104,7 +2102,7 @@ class Sgn(UnaryScalarOp): ...@@ -2104,7 +2102,7 @@ class Sgn(UnaryScalarOp):
(z,) = outputs (z,) = outputs
type = node.inputs[0].type type = node.inputs[0].type
if type in float_types: if type in float_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals() return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals()
if type in int_types: if type in int_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals() return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
raise TypeError() # complex has no sgn raise TypeError() # complex has no sgn
...@@ -2112,7 +2110,7 @@ class Sgn(UnaryScalarOp): ...@@ -2112,7 +2110,7 @@ class Sgn(UnaryScalarOp):
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version() s = super(Sgn, self).c_code_cache_version()
if s: if s:
return (3,) + s return (4,) + s
else: # if parent is unversioned, we are too else: # if parent is unversioned, we are too
return s return s
sgn = Sgn(same_out_nocomplex, name='sgn') sgn = Sgn(same_out_nocomplex, name='sgn')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论