提交 d1407834 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix non-cudnn reductions that were all subtly broken by 91bc16c3.

上级 42275cba
......@@ -3783,7 +3783,7 @@ def local_dnn_reduction(node):
with inherit_stack_trace(node.outputs):
ret = GpuDnnReduction(scal,
node.op.axis,
node.op.acc_dtype,
node.op._acc_dtype(node.inputs[0].dtype),
node.op.dtype,
False)(node.inputs[0])
return [post(ret)]
......
......@@ -1207,7 +1207,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
return False
x, = inputs
idtype = x.dtype
adtype = getattr(op, 'acc_dtype', idtype)
adtype = getattr(op, 'acc_dtype', None)
odtype = getattr(op, 'dtype', outputs[0].dtype)
# Force accumulator to float32 for float32 inputs since tree
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论