提交 ed44bb59 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make int64 default dtype for bool reduction

This is consistent with numpy
上级 91333f00
......@@ -1848,6 +1848,7 @@ class CAReduceDtype(CAReduce):
if dtype is None:
# If input has a discrete dtype, upcast it to 64
return dict(
bool='int64',
int8='int64',
int16='int64',
int32='int64',
......@@ -1863,6 +1864,7 @@ class CAReduceDtype(CAReduce):
acc_dtype = self.acc_dtype
if acc_dtype is None:
return dict(
bool='int64',
int8='int64',
int16='int64',
int32='int64',
......
......@@ -795,6 +795,7 @@ class T_reduce_dtype(unittest.TestCase):
x = tensor.matrix(dtype=dtype)
s = getattr(x, method)(axis=axis)
assert s.dtype == dict(
bool='int64',
int8='int64',
int16='int64',
int32='int64',
......@@ -820,6 +821,7 @@ class T_reduce_dtype(unittest.TestCase):
x = tensor.matrix(dtype=dtype)
s = getattr(x, method)(axis=axis)
assert s.owner.op.acc_dtype == dict(
bool='int64',
int8='int64',
int16='int64',
int32='int64',
......@@ -1017,6 +1019,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
axis = axes[idx % len(axes)]
x = ProdWithoutZeros(axis=axis)(tensor.matrix(dtype=dtype))
assert x.dtype == dict(
bool='int64',
int8='int64',
int16='int64',
int32='int64',
......@@ -1035,6 +1038,7 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
x = tensor.matrix(dtype=dtype)
p = ProdWithoutZeros(axis=axis)(x)
assert p.owner.op.acc_dtype == dict(
bool='int64',
int8='int64',
int16='int64',
int32='int64',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论