提交 7a90c786 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Don't transfer int inputs to the GPU by default.

Also recognize the target attribute to prevent transferring of floats.
上级 f279798b
...@@ -159,7 +159,6 @@ class InputToGpuOptimizer(Optimizer): ...@@ -159,7 +159,6 @@ class InputToGpuOptimizer(Optimizer):
Transfer the input to the gpu to start the rolling wave. Transfer the input to the gpu to start the rolling wave.
""" """
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
...@@ -173,7 +172,15 @@ class InputToGpuOptimizer(Optimizer): ...@@ -173,7 +172,15 @@ class InputToGpuOptimizer(Optimizer):
for cl in input.clients)): for cl in input.clients)):
continue continue
ctx_name = getattr(input.tag, 'context_name', None) target = getattr(input.tag, 'target', None)
if target == 'cpu':
continue
if (not input.type.dtype.startswith('float') and
not hasattr(input.tag, 'target')):
continue
ctx_name = getattr(input.tag, 'context_name', target)
try: try:
new_input = host_from_gpu(GpuFromHost(ctx_name)(input)) new_input = host_from_gpu(GpuFromHost(ctx_name)(input))
fgraph.replace_validate(input, new_input, fgraph.replace_validate(input, new_input,
...@@ -182,7 +189,8 @@ class InputToGpuOptimizer(Optimizer): ...@@ -182,7 +189,8 @@ class InputToGpuOptimizer(Optimizer):
# This could fail if the inputs are not TensorTypes # This could fail if the inputs are not TensorTypes
pass pass
except ContextNotDefined: except ContextNotDefined:
if hasattr(input.tag, 'context_name'): if (hasattr(input.tag, 'context_name') or
hasattr(input.tag, 'target')):
raise raise
# If there is no context tag and no default context # If there is no context tag and no default context
# then it stays on the CPU # then it stays on the CPU
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论