提交 c95f7e8b authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5717 from nouiz/careduce

crash fix of careduce with bool
......@@ -1539,8 +1539,8 @@ class CAReduce(Op):
scal_name = 'maximum'
if input.type.dtype in ["float32", "float64"]:
identity = "-__builtin_inf()"
elif input.type.dtype.startswith("uint"):
# numpy does not define NPY_MIN_UINT*
elif input.type.dtype.startswith("uint") or input.type.dtype == 'bool':
# numpy does not define NPY_MIN_UINT* and NPY_MIN_BOOL
identity = "0"
else:
identity = "NPY_MIN_" + str(input.type.dtype).upper()
......@@ -1548,6 +1548,9 @@ class CAReduce(Op):
scal_name = 'minimum'
if input.type.dtype in ["float32", "float64"]:
identity = "__builtin_inf()"
elif input.type.dtype == 'bool':
# numpy does not define NPY_MAX_BOOL
identity = "1"
else:
identity = "NPY_MAX_" + str(input.type.dtype).upper()
fail = sub["fail"]
......
......@@ -426,6 +426,9 @@ class test_CAReduce(unittest_tools.InferShapeTester):
elif scalar_op == scalar.add:
for axis in reversed(sorted(tosum)):
zv = numpy.add.reduce(zv, axis)
if dtype == 'bool':
# numpy.add of a bool upcast, while CAReduce don't
zv = zv.astype(dtype)
elif scalar_op == scalar.mul:
for axis in reversed(sorted(tosum)):
zv = numpy.multiply.reduce(zv, axis)
......@@ -503,7 +506,7 @@ class test_CAReduce(unittest_tools.InferShapeTester):
assert xv.size == 0
def test_perform(self):
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.PerformLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.mul, dtype=dtype)
self.with_linker(gof.PerformLinker(), scalar.maximum, dtype=dtype)
......@@ -537,17 +540,17 @@ class test_CAReduce(unittest_tools.InferShapeTester):
if not theano.config.cxx:
raise SkipTest("G++ not available, so we need to skip this test.")
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
for dtype in ["bool", "floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.add, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.mul, dtype=dtype)
for dtype in ["floatX", "int8", "uint8"]:
for dtype in ["bool", "floatX", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.minimum, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.maximum, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype,
tensor_op=tensor.all)
self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype,
tensor_op=tensor.any)
for dtype in ["int8", "uint8"]:
for dtype in ["bool", "int8", "uint8"]:
self.with_linker(gof.CLinker(), scalar.or_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.and_, dtype=dtype)
self.with_linker(gof.CLinker(), scalar.xor, dtype=dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论