提交 b333f6bb authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix T.mean() to use float32 for the sum.

上级 bf9413c8
...@@ -3155,11 +3155,9 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False, ...@@ -3155,11 +3155,9 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False,
sum_dtype = dtype sum_dtype = dtype
else: else:
sum_dtype = None sum_dtype = None
# float16 overflows on the cast way too often
# float16 overflows way too fast for sum if input.dtype == 'float16':
if ((sum_dtype == 'float16' or input.dtype == 'float16') and sum_dtype = 'float32'
acc_dtype != 'float16'):
sum_dtype == 'float32'
s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims, s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims,
acc_dtype=acc_dtype) acc_dtype=acc_dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论