提交 7a90c786 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Don't transfer int inputs to the GPU by default.

Also recognize the target attribute to prevent transferring of floats.
上级 f279798b
......@@ -159,7 +159,6 @@ class InputToGpuOptimizer(Optimizer):
Transfer the input to the gpu to start the rolling wave.
"""
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
......@@ -173,7 +172,15 @@ class InputToGpuOptimizer(Optimizer):
for cl in input.clients)):
continue
ctx_name = getattr(input.tag, 'context_name', None)
target = getattr(input.tag, 'target', None)
if target == 'cpu':
continue
if (not input.type.dtype.startswith('float') and
not hasattr(input.tag, 'target')):
continue
ctx_name = getattr(input.tag, 'context_name', target)
try:
new_input = host_from_gpu(GpuFromHost(ctx_name)(input))
fgraph.replace_validate(input, new_input,
......@@ -182,7 +189,8 @@ class InputToGpuOptimizer(Optimizer):
# This could fail if the inputs are not TensorTypes
pass
except ContextNotDefined:
if hasattr(input.tag, 'context_name'):
if (hasattr(input.tag, 'context_name') or
hasattr(input.tag, 'target')):
raise
# If there is no context tag and no default context
# then it stays on the CPU
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论