提交 45e52dd0 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make mean return the right type and keep the accumulator at least float32 if possible.

上级 9c9a5a51
......@@ -2756,19 +2756,17 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
out = makeKeepDims(input, out, axis)
return out
# float16 has very low precision so we do some things differently
f16 = (input.dtype == 'float16')
if dtype is not None:
# The summation will be done with the specified dtype.
# sum() will complain if it is not suitable.
sum_dtype = dtype
else:
# Let sum() infer the appropriate dtype.
sum_dtype = None
if f16 and sum_dtype is None and acc_dtype != 'float16':
sum_dtype = 'float32'
# float16 overflows way too fast for sum
if ((sum_dtype == 'float16' or input.dtype == 'float16') and
acc_dtype != 'float16'):
sum_dtype == 'float32'
s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims,
acc_dtype=acc_dtype)
......@@ -2777,7 +2775,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
# Cast shp into a float type
# TODO Once we have a consistent casting policy, we could simply
# use true_div.
if s.dtype in ('float32', 'complex64'):
if s.dtype in ('float16', 'float32', 'complex64'):
shp = cast(shp, 'float32')
else:
shp = cast(shp, 'float64')
......@@ -2795,7 +2793,7 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
for i in axis:
s = true_div(s, shp[i])
if f16:
if dtype == 'float16' or (dtype is None and input.dtype == 'float16'):
s = cast(s, 'float16')
return s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论