提交 a476757a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change mean() to always return the same dtype even if axis=[].

上级 872f34a3
...@@ -3182,6 +3182,9 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -3182,6 +3182,9 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
for i in axis: for i in axis:
s = true_div(s, shp[i]) s = true_div(s, shp[i])
if s.dtype != shp.dtype and s.dtype in discrete_dtypes:
s = cast(s, shp.dtype)
if dtype == 'float16' or (dtype is None and input.dtype == 'float16'): if dtype == 'float16' or (dtype is None and input.dtype == 'float16'):
s = cast(s, 'float16') s = cast(s, 'float16')
s.name = 'mean' s.name = 'mean'
......
...@@ -939,7 +939,7 @@ class T_mean_dtype(unittest.TestCase): ...@@ -939,7 +939,7 @@ class T_mean_dtype(unittest.TestCase):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype) x = tensor.matrix(dtype=dtype)
m = x.mean(axis=axis) m = x.mean(axis=axis)
if dtype in tensor.discrete_dtypes and axis != []: if dtype in tensor.discrete_dtypes:
assert m.dtype == 'float64' assert m.dtype == 'float64'
else: else:
assert m.dtype == dtype, (m, m.dtype, dtype) assert m.dtype == dtype, (m, m.dtype, dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论