提交 ed44bb59 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make int64 default dtype for bool reduction

This is consistent with numpy
上级 91333f00
...@@ -1848,6 +1848,7 @@ class CAReduceDtype(CAReduce): ...@@ -1848,6 +1848,7 @@ class CAReduceDtype(CAReduce):
if dtype is None: if dtype is None:
# If input has a discrete dtype, upcast it to 64 # If input has a discrete dtype, upcast it to 64
return dict( return dict(
bool='int64',
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
...@@ -1863,6 +1864,7 @@ class CAReduceDtype(CAReduce): ...@@ -1863,6 +1864,7 @@ class CAReduceDtype(CAReduce):
acc_dtype = self.acc_dtype acc_dtype = self.acc_dtype
if acc_dtype is None: if acc_dtype is None:
return dict( return dict(
bool='int64',
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
......
...@@ -795,6 +795,7 @@ class T_reduce_dtype(unittest.TestCase): ...@@ -795,6 +795,7 @@ class T_reduce_dtype(unittest.TestCase):
x = tensor.matrix(dtype=dtype) x = tensor.matrix(dtype=dtype)
s = getattr(x, method)(axis=axis) s = getattr(x, method)(axis=axis)
assert s.dtype == dict( assert s.dtype == dict(
bool='int64',
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
...@@ -820,6 +821,7 @@ class T_reduce_dtype(unittest.TestCase): ...@@ -820,6 +821,7 @@ class T_reduce_dtype(unittest.TestCase):
x = tensor.matrix(dtype=dtype) x = tensor.matrix(dtype=dtype)
s = getattr(x, method)(axis=axis) s = getattr(x, method)(axis=axis)
assert s.owner.op.acc_dtype == dict( assert s.owner.op.acc_dtype == dict(
bool='int64',
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
...@@ -1017,6 +1019,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -1017,6 +1019,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = ProdWithoutZeros(axis=axis)(tensor.matrix(dtype=dtype)) x = ProdWithoutZeros(axis=axis)(tensor.matrix(dtype=dtype))
assert x.dtype == dict( assert x.dtype == dict(
bool='int64',
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
...@@ -1035,6 +1038,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -1035,6 +1038,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
x = tensor.matrix(dtype=dtype) x = tensor.matrix(dtype=dtype)
p = ProdWithoutZeros(axis=axis)(x) p = ProdWithoutZeros(axis=axis)(x)
assert p.owner.op.acc_dtype == dict( assert p.owner.op.acc_dtype == dict(
bool='int64',
int8='int64', int8='int64',
int16='int64', int16='int64',
int32='int64', int32='int64',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论