提交 195f9b1d authored 作者: sentient07's avatar sentient07

Fixing compilation of graphs with only one host_from_gpu OP

上级 345d9024
...@@ -389,17 +389,18 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -389,17 +389,18 @@ class GraphToGPU(NavigatorOptimizer):
if new_o.type != o.type: if new_o.type != o.type:
assert isinstance(o.type, tensor.TensorType) assert isinstance(o.type, tensor.TensorType)
assert isinstance(new_o.type, GpuArrayType) assert isinstance(new_o.type, GpuArrayType)
assert new_o.owner
# This condition is needed in the case one input is an # This condition is needed in the case one input is an
# output of the graph. Without this, it would # output of the graph. Without this, it would
# introduce cycle as we don't replace correctly that # introduce cycle as we don't replace correctly that
# case. It would also add extra transfer to/from the # case. It would also add extra transfer to/from the
# gpu. # gpu.
if (isinstance(new_o.owner.op, GpuFromHost) and if (new_o.owner and
isinstance(new_o.owner.op, GpuFromHost) and
new_o.owner.inputs[0].type == o.type): new_o.owner.inputs[0].type == o.type):
new_o = new_o.owner.inputs[0] new_o = new_o.owner.inputs[0]
else: else:
new_o = host_from_gpu(new_o) new_o = safe_to_cpu(new_o)
new_nodes.append(new_o) new_nodes.append(new_o)
fgraph.replace_all_validate(zip(fgraph.outputs, new_nodes), fgraph.replace_all_validate(zip(fgraph.outputs, new_nodes),
reason=self.__class__.__name__) reason=self.__class__.__name__)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论