提交 e4f8a536 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Force accumulator to float32.

上级 493bcaaf
......@@ -1112,10 +1112,20 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
else:
return False
x, = inputs
idtype = x.dtype
adtype = getattr(op, 'acc_dtype', None)
odtype = getattr(op, 'dtype', outputs[0].dtype)
# Force accumulator to float32 for float32 inputs since tree
# reduction will not loose as much precision as linear
# accumulation and float64 is much slower on GPU.
if idtype == 'float32' and odtype == 'float32':
adtype = 'float32'
greduce = op2(
op.scalar_op, axis=op.axis,
dtype=getattr(op, 'dtype', outputs[0].dtype),
acc_dtype=getattr(op, 'acc_dtype', None))
dtype=odtype,
acc_dtype=adtype)
gvar = greduce(x)
# We need to have the make node called, otherwise the mask can
# be None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论