提交 7d303ff1 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: sentient07

Fix cycle with GraphToGPU.

上级 5c8e2cbf
...@@ -376,9 +376,20 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -376,9 +376,20 @@ class GraphToGPU(NavigatorOptimizer):
if new_o.type != o.type: if new_o.type != o.type:
assert isinstance(o.type, tensor.TensorType) assert isinstance(o.type, tensor.TensorType)
assert isinstance(new_o.type, GpuArrayType) assert isinstance(new_o.type, GpuArrayType)
assert new_o.owner
# This condition is needed in the case one input is an
# output of the graph. Without this, it would
# introduce cycle as we don't replace correctly that
# case. It would also add extra transfer to/from the
# gpu.
if (isinstance(new_o.owner.op, GpuFromHost) and
new_o.owner.inputs[0].type == o.type):
new_o = new_o.owner.inputs[0]
else:
new_o = host_from_gpu(new_o) new_o = host_from_gpu(new_o)
new_nodes.append(new_o) new_nodes.append(new_o)
fgraph.replace_all_validate(zip(fgraph.outputs, new_nodes)) fgraph.replace_all_validate(zip(fgraph.outputs, new_nodes),
reason=self.__class__.__name__)
return (self, toposort_timing, time_opts, node_created, process_count) return (self, toposort_timing, time_opts, node_created, process_count)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论