提交 c50df515 authored 作者: Frederic's avatar Frederic

Fix Sng handling of nan and fix related test

上级 669cf0d1
......@@ -2104,7 +2104,7 @@ class Sgn(UnaryScalarOp):
(z,) = outputs
type = node.inputs[0].type
if type in float_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals()
if type in int_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
raise TypeError() # complex has no sgn
......@@ -2112,7 +2112,7 @@ class Sgn(UnaryScalarOp):
def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version()
if s:
return (3,) + s
return (4,) + s
else: # if parent is unversioned, we are too
return s
sgn = Sgn(same_out_nocomplex, name='sgn')
......
......@@ -4149,9 +4149,15 @@ class test_comparison(unittest.TestCase):
(shared(l.astype(dtype)), r, False),
(shared(l.astype(dtype)), constant(r), False),
]:
mode = get_default_mode()
mode.check_isfinite = False
try:
fn1 = inplace_func([], isclose(x, y, equal_nan=False))
fn2 = inplace_func([], isclose(x, y, equal_nan=True))
o1 = isclose(x, y, equal_nan=False)
fn1 = inplace_func([], o1, mode=mode)
o2 = isclose(x, y, equal_nan=True)
fn2 = inplace_func([], o2, mode=mode)
v1 = fn1()
v2 = fn2()
self.assertTrue(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论