提交 b602f833 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6413 from ReyhaneAskari/gpu_stack_follow_up

Gpu stack follow up
...@@ -962,18 +962,19 @@ def local_gpu_pdbbreakpoint_op(node): ...@@ -962,18 +962,19 @@ def local_gpu_pdbbreakpoint_op(node):
return False return False
# Apply the op on the new inputs # Apply the op on the new inputs
new_op_outputs = node.op(*new_inputs, return_list=True) with inherit_stack_trace(node.outputs):
new_op_outputs = node.op(*new_inputs, return_list=True)
# Propagate the transfer to the gpu through the outputs that require
# it # Propagate the transfer to the gpu through the outputs that require
new_outputs = [] # it
for i in range(len(new_op_outputs)): new_outputs = []
if input_transfered[i]: for i in range(len(new_op_outputs)):
new_outputs.append(new_op_outputs[i].transfer('cpu')) if input_transfered[i]:
else: new_outputs.append(new_op_outputs[i].transfer('cpu'))
new_outputs.append(new_op_outputs[i]) else:
new_outputs.append(new_op_outputs[i])
return new_outputs return new_outputs
return False return False
...@@ -2407,8 +2408,8 @@ def local_gpu_elemwise_careduce(node): ...@@ -2407,8 +2408,8 @@ def local_gpu_elemwise_careduce(node):
inp = node.inputs[0].owner.inputs[0] inp = node.inputs[0].owner.inputs[0]
props = node.op._props_dict() props = node.op._props_dict()
props["pre_scalar_op"] = scalar.basic.sqr props["pre_scalar_op"] = scalar.basic.sqr
out = GpuCAReduceCuda(**props)(inp)
with inherit_stack_trace(node.outputs): with inherit_stack_trace(node.outputs):
out = GpuCAReduceCuda(**props)(inp)
return [out] return [out]
......
...@@ -40,11 +40,9 @@ def _check_stack_trace(thing): ...@@ -40,11 +40,9 @@ def _check_stack_trace(thing):
theano.tensor.elemwise.Elemwise, theano.tensor.elemwise.Elemwise,
theano.ifelse.IfElse, theano.ifelse.IfElse,
GpuFromHost, HostFromGpu, GpuFromHost, HostFromGpu,
GpuCAReduceCuda,
basic_ops.GpuContiguous, basic_ops.GpuContiguous,
GpuElemwise, GpuElemwise,
theano.printing.Print, theano.printing.Print,
PdbBreakpoint,
)) ))
return check_stack_trace(thing, ops_to_check=_ops_to_check, return check_stack_trace(thing, ops_to_check=_ops_to_check,
bug_print="ignore") bug_print="ignore")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论