提交 d85528d4 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make CAReduce work in more cases with bool instead of crashing.

上级 127ae529
...@@ -1539,8 +1539,8 @@ class CAReduce(Op): ...@@ -1539,8 +1539,8 @@ class CAReduce(Op):
scal_name = 'maximum' scal_name = 'maximum'
if input.type.dtype in ["float32", "float64"]: if input.type.dtype in ["float32", "float64"]:
identity = "-__builtin_inf()" identity = "-__builtin_inf()"
elif input.type.dtype.startswith("uint"): elif input.type.dtype.startswith("uint") or input.type.dtype == 'bool':
# numpy does not define NPY_MIN_UINT* # numpy does not define NPY_MIN_UINT* and NPY_MIN_BOOL
identity = "0" identity = "0"
else: else:
identity = "NPY_MIN_" + str(input.type.dtype).upper() identity = "NPY_MIN_" + str(input.type.dtype).upper()
...@@ -1548,6 +1548,9 @@ class CAReduce(Op): ...@@ -1548,6 +1548,9 @@ class CAReduce(Op):
scal_name = 'minimum' scal_name = 'minimum'
if input.type.dtype in ["float32", "float64"]: if input.type.dtype in ["float32", "float64"]:
identity = "__builtin_inf()" identity = "__builtin_inf()"
elif input.type.dtype == 'bool':
# numpy does not define NPY_MIN_UINT* and NPY_MAX_BOOL
identity = "1"
else: else:
identity = "NPY_MAX_" + str(input.type.dtype).upper() identity = "NPY_MAX_" + str(input.type.dtype).upper()
fail = sub["fail"] fail = sub["fail"]
......
...@@ -503,7 +503,7 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -503,7 +503,7 @@ class test_CAReduce(unittest_tools.InferShapeTester):
assert xv.size == 0 assert xv.size == 0
def test_perform(self): def test_perform(self):
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]: for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.mul, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.mul, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.maximum, dtype=dtype) self.with_linker(gof.PerformLinker(), scalar.maximum, dtype=dtype)
...@@ -537,17 +537,17 @@ class test_CAReduce(unittest_tools.InferShapeTester): ...@@ -537,17 +537,17 @@ class test_CAReduce(unittest_tools.InferShapeTester):
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.") raise SkipTest("G++ not available, so we need to skip this test.")
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]: for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.add, dtype=dtype) self.with_linker(gof.CLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.mul, dtype=dtype) self.with_linker(gof.CLinker(), scalar.mul, dtype=dtype)
for dtype in ["floatX", "int8", "uint8"]: for dtype in ["bool", "floatX", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.minimum, dtype=dtype) self.with_linker(gof.CLinker(), scalar.minimum, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype) self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype, self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype,
tensor_op=tensor.all) tensor_op=tensor.all)
self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype, self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype,
tensor_op=tensor.any) tensor_op=tensor.any)
for dtype in ["int8", "uint8"]: for dtype in ["bool", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype) self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype) self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype) self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论