提交 3ecd6ebf authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix C code of minimum/maximum when there are NaNs

Reinstate test.
上级 5a31a368
......@@ -753,7 +753,7 @@ class ScalarOp(Op):
return self.__class__.__name__
def c_code_cache_version(self):
return (3,)
return (4,)
class UnaryScalarOp(ScalarOp):
......@@ -1078,7 +1078,9 @@ class Maximum(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError()
return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" % locals()
# Test for both y>x and x>=y to detect NaN
return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
......@@ -1103,7 +1105,8 @@ class Minimum(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError()
return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" % locals()
return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types
......
......@@ -389,10 +389,6 @@ class test_CAReduce(unittest_tools.InferShapeTester):
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype)
@dec.knownfailureif(
True,
("When there is nan in the input of CAReduce,"
" we don't have a good output. "))
def test_c_nan(self):
for dtype in ["floatX", "complex64", "complex128"]:
self.with_linker(gof.CLinker(), scalar.add, dtype=dtype,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论