提交 1427753f authored 作者: --global's avatar --global

Fix Gpuarray optimization for PdbBreakpoint

上级 17b360f2
...@@ -342,8 +342,9 @@ def local_gpu_pdbbreakpoint_op(node): ...@@ -342,8 +342,9 @@ def local_gpu_pdbbreakpoint_op(node):
new_inputs = node.inputs[:1] new_inputs = node.inputs[:1]
input_transfered = [] input_transfered = []
# Propagate the transfers to gpu through the PdbBreakpoint node # Go through the monitored variables, only transfering on GPU those
# while leaving the PdbBreakpoint node fully on the host # for which the input comes from the GPU or the output will be
# transfered on the GPU.
nb_monitored_vars = len(node.outputs) nb_monitored_vars = len(node.outputs)
for i in range(nb_monitored_vars): for i in range(nb_monitored_vars):
...@@ -352,23 +353,22 @@ def local_gpu_pdbbreakpoint_op(node): ...@@ -352,23 +353,22 @@ def local_gpu_pdbbreakpoint_op(node):
input_is_from_gpu = (inp.owner and input_is_from_gpu = (inp.owner and
isinstance(inp.owner.op, HostFromGpu)) isinstance(inp.owner.op, HostFromGpu))
output_used = len(out.clients) > 0 output_goes_to_gpu = any([c[0] != "output" and
output_goes_to_gpu = all([c[0] != "output" and
isinstance(c[0].op, GpuFromHost) isinstance(c[0].op, GpuFromHost)
for c in out.clients]) for c in out.clients])
if input_is_from_gpu and output_used and not output_goes_to_gpu: if input_is_from_gpu:
# The op should be applied on the GPU version of the input # The op should be applied on the GPU version of the input
new_inputs.append(inp.owner.inputs[0]) new_inputs.append(inp.owner.inputs[0])
input_transfered.append(True) input_transfered.append(True)
elif not input_is_from_gpu and output_used and output_goes_to_gpu: elif output_goes_to_gpu:
# The input should be transfered to the gpu # The input should be transfered to the gpu
new_inputs.append(gpu_from_host(inp)) new_inputs.append(gpu_from_host(inp))
input_transfered.append(True) input_transfered.append(True)
else: else:
# Both are on the gpu or on the host. No transfer is required. # No transfer is required.
new_inputs.append(inp) new_inputs.append(inp)
input_transfered.append(False) input_transfered.append(False)
...@@ -378,12 +378,7 @@ def local_gpu_pdbbreakpoint_op(node): ...@@ -378,12 +378,7 @@ def local_gpu_pdbbreakpoint_op(node):
return False return False
# Apply the op on the new inputs # Apply the op on the new inputs
new_op_outputs = node.op(*new_inputs) new_op_outputs = node.op(*new_inputs, return_list=True)
# Ensure that new_op_outputs is a list of outputs (in case the op has
# only one output)
if not isinstance(new_op_outputs, list):
new_op_outputs = [new_op_outputs]
# Propagate the transfer to the gpu through the outputs that require # Propagate the transfer to the gpu through the outputs that require
# it # it
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论