提交 1427753f authored 作者: --global's avatar --global

Fix Gpuarray optimization for PdbBreakpoint

上级 17b360f2
......@@ -342,8 +342,9 @@ def local_gpu_pdbbreakpoint_op(node):
new_inputs = node.inputs[:1]
input_transfered = []
# Propagate the transfers to gpu through the PdbBreakpoint node
# while leaving the PdbBreakpoint node fully on the host
# Go through the monitored variables, only transfering on GPU those
# for which the input comes from the GPU or the output will be
# transfered on the GPU.
nb_monitored_vars = len(node.outputs)
for i in range(nb_monitored_vars):
......@@ -352,23 +353,22 @@ def local_gpu_pdbbreakpoint_op(node):
input_is_from_gpu = (inp.owner and
isinstance(inp.owner.op, HostFromGpu))
output_used = len(out.clients) > 0
output_goes_to_gpu = all([c[0] != "output" and
output_goes_to_gpu = any([c[0] != "output" and
isinstance(c[0].op, GpuFromHost)
for c in out.clients])
if input_is_from_gpu and output_used and not output_goes_to_gpu:
if input_is_from_gpu:
# The op should be applied on the GPU version of the input
new_inputs.append(inp.owner.inputs[0])
input_transfered.append(True)
elif not input_is_from_gpu and output_used and output_goes_to_gpu:
elif output_goes_to_gpu:
# The input should be transfered to the gpu
new_inputs.append(gpu_from_host(inp))
input_transfered.append(True)
else:
# Both are on the gpu or on the host. No transfer is required.
# No transfer is required.
new_inputs.append(inp)
input_transfered.append(False)
......@@ -378,12 +378,7 @@ def local_gpu_pdbbreakpoint_op(node):
return False
# Apply the op on the new inputs
new_op_outputs = node.op(*new_inputs)
# Ensure that new_op_outputs is a list of outputs (in case the op has
# only one output)
if not isinstance(new_op_outputs, list):
new_op_outputs = [new_op_outputs]
new_op_outputs = node.op(*new_inputs, return_list=True)
# Propagate the transfer to the gpu through the outputs that require
# it
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论