提交 b333f6bb authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix T.mean() to use float32 for the sum.

上级 bf9413c8
......@@ -3155,11 +3155,9 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
sum_dtype = dtype
else:
sum_dtype = None
# float16 overflows way too fast for sum
if ((sum_dtype == 'float16' or input.dtype == 'float16') and
acc_dtype != 'float16'):
sum_dtype == 'float32'
# float16 overflows on the cast way too often
if input.dtype == 'float16':
sum_dtype = 'float32'
s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims,
acc_dtype=acc_dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论