提交 45e52dd0 作者: Arnaud Bergeron

Make mean return the right type and keep the accumulator at least float32 if possible.

上级 9c9a5a51
...@@ -2756,19 +2756,17 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -2756,19 +2756,17 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
out = makeKeepDims(input, out, axis) out = makeKeepDims(input, out, axis)
return out return out
# float16 has very low precision so we do some things differently
f16 = (input.dtype == 'float16')
if dtype is not None: if dtype is not None:
# The summation will be done with the specified dtype. # The summation will be done with the specified dtype.
# sum() will complain if it is not suitable. # sum() will complain if it is not suitable.
sum_dtype = dtype sum_dtype = dtype
else: else:
# Let sum() infer the appropriate dtype.
sum_dtype = None sum_dtype = None
if f16 and sum_dtype is None and acc_dtype != 'float16': # float16 overflows way too fast for sum
sum_dtype = 'float32' if ((sum_dtype == 'float16' or input.dtype == 'float16') and
acc_dtype != 'float16'):
sum_dtype == 'float32'
s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims, s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims,
acc_dtype=acc_dtype) acc_dtype=acc_dtype)
...@@ -2777,7 +2775,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -2777,7 +2775,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
# Cast shp into a float type # Cast shp into a float type
# TODO Once we have a consistent casting policy, we could simply # TODO Once we have a consistent casting policy, we could simply
# use true_div. # use true_div.
if s.dtype in ('float32', 'complex64'): if s.dtype in ('float16', 'float32', 'complex64'):
shp = cast(shp, 'float32') shp = cast(shp, 'float32')
else: else:
shp = cast(shp, 'float64') shp = cast(shp, 'float64')
...@@ -2795,7 +2793,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -2795,7 +2793,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
for i in axis: for i in axis:
s = true_div(s, shp[i]) s = true_div(s, shp[i])
if f16: if dtype == 'float16' or (dtype is None and input.dtype == 'float16'):
s = cast(s, 'float16') s = cast(s, 'float16')
return s return s
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论