提交 30335bbd authored 作者: Reyhane Askari's avatar Reyhane Askari

fix for PdbBreakpoint stacktrace

上级 404cea07
...@@ -962,18 +962,19 @@ def local_gpu_pdbbreakpoint_op(node): ...@@ -962,18 +962,19 @@ def local_gpu_pdbbreakpoint_op(node):
return False return False
# Apply the op on the new inputs # Apply the op on the new inputs
new_op_outputs = node.op(*new_inputs, return_list=True) with inherit_stack_trace(node.outputs):
new_op_outputs = node.op(*new_inputs, return_list=True)
# Propagate the transfer to the gpu through the outputs that require
# it # Propagate the transfer to the gpu through the outputs that require
new_outputs = [] # it
for i in range(len(new_op_outputs)): new_outputs = []
if input_transfered[i]: for i in range(len(new_op_outputs)):
new_outputs.append(new_op_outputs[i].transfer('cpu')) if input_transfered[i]:
else: new_outputs.append(new_op_outputs[i].transfer('cpu'))
new_outputs.append(new_op_outputs[i]) else:
new_outputs.append(new_op_outputs[i])
return new_outputs return new_outputs
return False return False
......
...@@ -44,7 +44,6 @@ def _check_stack_trace(thing): ...@@ -44,7 +44,6 @@ def _check_stack_trace(thing):
basic_ops.GpuContiguous, basic_ops.GpuContiguous,
GpuElemwise, GpuElemwise,
theano.printing.Print, theano.printing.Print,
PdbBreakpoint,
)) ))
return check_stack_trace(thing, ops_to_check=_ops_to_check, return check_stack_trace(thing, ops_to_check=_ops_to_check,
bug_print="ignore") bug_print="ignore")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论