提交 c50df515 authored 作者: Frederic's avatar Frederic

Fix Sng handling of nan and fix related test

上级 669cf0d1
...@@ -2104,7 +2104,7 @@ class Sgn(UnaryScalarOp): ...@@ -2104,7 +2104,7 @@ class Sgn(UnaryScalarOp):
(z,) = outputs (z,) = outputs
type = node.inputs[0].type type = node.inputs[0].type
if type in float_types: if type in float_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals() return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals()
if type in int_types: if type in int_types:
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals() return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
raise TypeError() # complex has no sgn raise TypeError() # complex has no sgn
...@@ -2112,7 +2112,7 @@ class Sgn(UnaryScalarOp): ...@@ -2112,7 +2112,7 @@ class Sgn(UnaryScalarOp):
def c_code_cache_version(self): def c_code_cache_version(self):
s = super(Sgn, self).c_code_cache_version() s = super(Sgn, self).c_code_cache_version()
if s: if s:
return (3,) + s return (4,) + s
else: # if parent is unversioned, we are too else: # if parent is unversioned, we are too
return s return s
sgn = Sgn(same_out_nocomplex, name='sgn') sgn = Sgn(same_out_nocomplex, name='sgn')
......
...@@ -4149,9 +4149,15 @@ class test_comparison(unittest.TestCase): ...@@ -4149,9 +4149,15 @@ class test_comparison(unittest.TestCase):
(shared(l.astype(dtype)), r, False), (shared(l.astype(dtype)), r, False),
(shared(l.astype(dtype)), constant(r), False), (shared(l.astype(dtype)), constant(r), False),
]: ]:
mode = get_default_mode()
mode.check_isfinite = False
try: try:
fn1 = inplace_func([], isclose(x, y, equal_nan=False)) o1 = isclose(x, y, equal_nan=False)
fn2 = inplace_func([], isclose(x, y, equal_nan=True)) fn1 = inplace_func([], o1, mode=mode)
o2 = isclose(x, y, equal_nan=True)
fn2 = inplace_func([], o2, mode=mode)
v1 = fn1() v1 = fn1()
v2 = fn2() v2 = fn2()
self.assertTrue( self.assertTrue(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论