提交 233e7813 authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

If a reduce upcast the input, don't move it to the GPU.

上级 0b5aa217
...@@ -787,7 +787,7 @@ def local_gpu_careduce(node): ...@@ -787,7 +787,7 @@ def local_gpu_careduce(node):
x, = node.inputs x, = node.inputs
# Otherwise, is some corner case, we will try to move it # Otherwise, is some corner case, we will try to move it
# to the GPU later and this cause not wanted user warning. # to the GPU later and this cause not wanted user warning.
if x.dtype != 'float32': if x.dtype != 'float32' or node.outputs[0].dtype != "float32":
return return
replace = False replace = False
if x.owner and isinstance(x.owner.op, HostFromGpu): if x.owner and isinstance(x.owner.op, HostFromGpu):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论