提交 d85528d4 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make CAReduce work in more cases with bool instead of crashing.

上级 127ae529
......@@ -1539,8 +1539,8 @@ class CAReduce(Op):
scal_name = 'maximum'
if input.type.dtype in ["float32", "float64"]:
identity = "-__builtin_inf()"
elif input.type.dtype.startswith("uint"):
# numpy does not define NPY_MIN_UINT*
elif input.type.dtype.startswith("uint") or input.type.dtype == 'bool':
# numpy does not define NPY_MIN_UINT* and NPY_MIN_BOOL
identity = "0"
else:
identity = "NPY_MIN_" + str(input.type.dtype).upper()
......@@ -1548,6 +1548,9 @@ class CAReduce(Op):
scal_name = 'minimum'
if input.type.dtype in ["float32", "float64"]:
identity = "__builtin_inf()"
elif input.type.dtype == 'bool':
# numpy does not define NPY_MIN_UINT* and NPY_MAX_BOOL
identity = "1"
else:
identity = "NPY_MAX_" + str(input.type.dtype).upper()
fail = sub["fail"]
......
......@@ -503,7 +503,7 @@ class test_CAReduce(unittest_tools.InferShapeTester):
assert xv.size == 0
def test_perform(self):
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.mul, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.maximum, dtype=dtype)
......@@ -537,17 +537,17 @@ class test_CAReduce(unittest_tools.InferShapeTester):
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.mul, dtype=dtype)
for dtype in ["floatX", "int8", "uint8"]:
for dtype in ["bool", "floatX", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.minimum, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]:
for dtype in ["bool", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论