提交 5df8dbe9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make InputsToGpuOptimizer skip processing if all clients are transfers or…

Make InputsToGpuOptimizer skip processing if all clients are transfers or outputs, no matter their number. Also, fix the error condition.
上级 4cc62ca6
...@@ -165,9 +165,9 @@ class InputToGpuOptimizer(Optimizer): ...@@ -165,9 +165,9 @@ class InputToGpuOptimizer(Optimizer):
if isinstance(input.type, GpuArrayType): if isinstance(input.type, GpuArrayType):
continue continue
if (len(input.clients) == 1 and # If all clients are outputs or transfers don't do anything.
(input.clients[0][0] == 'output' or if (all(cl[0] == 'output' or isinstance(cl[0].op, GpuFromHost)
isinstance(input.clients[0][0].op, GpuFromHost))): for cl in input.clients)):
continue continue
ctx_name = getattr(input.tag, 'context_name', None) ctx_name = getattr(input.tag, 'context_name', None)
...@@ -179,10 +179,10 @@ class InputToGpuOptimizer(Optimizer): ...@@ -179,10 +179,10 @@ class InputToGpuOptimizer(Optimizer):
# This could fail if the inputs are not TensorTypes # This could fail if the inputs are not TensorTypes
pass pass
except ContextNotDefined: except ContextNotDefined:
if hasattr(input.tag, 'context_name'):
raise
# If there is no context tag and no default context # If there is no context tag and no default context
# then it stays on the CPU # then it stays on the CPU
if not hasattr(input.tag, 'context_name'):
raise
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论