提交 d3c8eb17 authored 作者: nouiz's avatar nouiz

Merge pull request #801 from lamblin/fix_careduce_nan

Fix CAReduce for 0 shapes and NaNs
...@@ -753,7 +753,7 @@ class ScalarOp(Op): ...@@ -753,7 +753,7 @@ class ScalarOp(Op):
return self.__class__.__name__ return self.__class__.__name__
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (4,)
class UnaryScalarOp(ScalarOp): class UnaryScalarOp(ScalarOp):
...@@ -1078,7 +1078,9 @@ class Maximum(BinaryScalarOp): ...@@ -1078,7 +1078,9 @@ class Maximum(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]): if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError() raise NotImplementedError()
return "%(z)s = ((%(y)s)>(%(x)s)? (%(y)s):(%(x)s));" % locals() # Test for both y>x and x>=y to detect NaN
return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
'((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
...@@ -1103,7 +1105,8 @@ class Minimum(BinaryScalarOp): ...@@ -1103,7 +1105,8 @@ class Minimum(BinaryScalarOp):
def c_code(self, node, name, (x, y), (z, ), sub): def c_code(self, node, name, (x, y), (z, ), sub):
if any([i.type in complex_types for i in node.inputs]): if any([i.type in complex_types for i in node.inputs]):
raise NotImplementedError() raise NotImplementedError()
return "%(z)s = ((%(y)s)<(%(x)s)? (%(y)s):(%(x)s));" % locals() return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
'((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())
def grad(self, (x, y), (gz, )): def grad(self, (x, y), (gz, )):
assert gz.type not in complex_types assert gz.type not in complex_types
......
...@@ -1209,8 +1209,12 @@ class CAReduce(Op): ...@@ -1209,8 +1209,12 @@ class CAReduce(Op):
# if available # if available
if variable.shape[dimension] == 0: if variable.shape[dimension] == 0:
if hasattr(self.scalar_op, 'identity'): if hasattr(self.scalar_op, 'identity'):
variable = numpy.array(self.scalar_op.identity) # Compute the shape of the output
break v_shape = list(variable.shape)
del v_shape[dimension]
variable = numpy.empty(tuple(v_shape),
dtype=variable.dtype)
variable.fill(self.scalar_op.identity)
else: else:
raise ValueError(( raise ValueError((
"Input (%s) has zero-size on axis %s, but " "Input (%s) has zero-size on axis %s, but "
......
...@@ -315,15 +315,18 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -315,15 +315,18 @@ class test_CAReduce(unittest_tools.InferShapeTester):
else: else:
self.fail() self.fail()
else: else:
#numpy.{all,any} return bool type. # numpy.{all,any} return bool type,
# but theano ops return an int8 array instead
if scalar_op in [scalar.and_, scalar.or_]: if scalar_op in [scalar.and_, scalar.or_]:
zv = numpy.asarray(zv, dtype=dtype) zv = numpy.asarray(zv, dtype='int8')
if test_nan: if test_nan:
self.assertTrue(theano.tensor.TensorType.values_eq(f(xv), self.assertTrue(theano.tensor.TensorType.values_eq(f(xv),
zv), zv),
(f(xv), zv)) (f(xv), zv))
else: else:
self.assertTrue(numpy.allclose(f(xv), zv), (f(xv), zv)) f_xv = f(xv)
self.assertTrue((f_xv.shape == zv.shape), (f_xv, zv))
self.assertTrue(numpy.allclose(f_xv, zv), (f_xv, zv))
#test CAReduce.infer_shape #test CAReduce.infer_shape
#the Shape op don't implement c_code! #the Shape op don't implement c_code!
...@@ -355,10 +358,6 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -355,10 +358,6 @@ class test_CAReduce(unittest_tools.InferShapeTester):
self.with_linker(gof.PerformLinker(), scalar.and_, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.xor, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.xor, dtype=dtype)
@dec.knownfailureif(
True,
("When there is nan in the input of CAReduce,"
" we don't have a good output. "))
def test_perform_nan(self): def test_perform_nan(self):
for dtype in ["floatX", "complex64", "complex128"]: for dtype in ["floatX", "complex64", "complex128"]:
self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype,
...@@ -370,12 +369,8 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -370,12 +369,8 @@ class test_CAReduce(unittest_tools.InferShapeTester):
self.with_linker(gof.PerformLinker(), scalar.minimum, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.minimum, dtype=dtype,
test_nan=True) test_nan=True)
self.with_linker(gof.PerformLinker(), scalar.or_, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.or_, dtype=dtype,
test_nan=True)
self.with_linker(gof.PerformLinker(), scalar.and_, dtype=dtype,
test_nan=True)
self.with_linker(gof.PerformLinker(), or_, dtype=dtype,
test_nan=True, tensor_op=tensor.any) test_nan=True, tensor_op=tensor.any)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype, self.with_linker(gof.PerformLinker(), scalar.and_, dtype=dtype,
test_nan=True, tensor_op=tensor.all) test_nan=True, tensor_op=tensor.all)
def test_c(self): def test_c(self):
...@@ -394,10 +389,6 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -394,10 +389,6 @@ class test_CAReduce(unittest_tools.InferShapeTester):
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype) self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype) self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype)
@dec.knownfailureif(
True,
("When there is nan in the input of CAReduce,"
" we don't have a good output. "))
def test_c_nan(self): def test_c_nan(self):
for dtype in ["floatX", "complex64", "complex128"]: for dtype in ["floatX", "complex64", "complex128"]:
self.with_linker(gof.CLinker(), scalar.add, dtype=dtype, self.with_linker(gof.CLinker(), scalar.add, dtype=dtype,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论