提交 5df8dbe9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make InputsToGpuOptimizer skip processing if all clients are transfers or…

Make InputsToGpuOptimizer skip processing if all clients are transfers or outputs, no matter their number. Also, fix the error condition.
上级 4cc62ca6
......@@ -165,9 +165,9 @@ class InputToGpuOptimizer(Optimizer):
if isinstance(input.type, GpuArrayType):
continue
if (len(input.clients) == 1 and
(input.clients[0][0] == 'output' or
isinstance(input.clients[0][0].op, GpuFromHost))):
# If all clients are outputs or transfers don't do anything.
if (all(cl[0] == 'output' or isinstance(cl[0].op, GpuFromHost)
for cl in input.clients)):
continue
ctx_name = getattr(input.tag, 'context_name', None)
......@@ -179,10 +179,10 @@ class InputToGpuOptimizer(Optimizer):
# This could fail if the inputs are not TensorTypes
pass
except ContextNotDefined:
if hasattr(input.tag, 'context_name'):
raise
# If there is no context tag and no default context
# then it stays on the CPU
if not hasattr(input.tag, 'context_name'):
raise
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论