提交 126240a6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Prevent the op_lifter from moving to GPU stuff that only depends on scalars.

上级 5bb6abe5
......@@ -193,7 +193,8 @@ def op_lifter(OP, cuda_only=False):
context_name = None
# We replace if any input is a host_from_gpu
for i in node.inputs:
if i.owner and i.owner.op == host_from_gpu:
if (i.owner and i.owner.op == host_from_gpu and
move_to_gpu(i)):
context_name = i.owner.inputs[0].type.context_name
replace = True
break
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论