提交 ca5b9581 作者: Frédéric Bastien，提交者: GitHub

Merge pull request #6416 from ReyhaneAskari/gpu_stack_follow_up

Gpu stack follow up
...@@ -424,7 +424,8 @@ class GraphToGPU(Optimizer): ...@@ -424,7 +424,8 @@ class GraphToGPU(Optimizer):
outputs = [] outputs = []
if isinstance(new_ops, theano.Op): if isinstance(new_ops, theano.Op):
outputs = new_ops(*[mapping[i] for i in node.inputs], return_list=True) with inherit_stack_trace(node.outputs):
outputs = new_ops(*[mapping[i] for i in node.inputs], return_list=True)
elif not new_ops: elif not new_ops:
newnode = node.clone_with_new_inputs([mapping.get(i) for i in node.inputs]) newnode = node.clone_with_new_inputs([mapping.get(i) for i in node.inputs])
outputs = newnode.outputs outputs = newnode.outputs
...@@ -904,10 +905,11 @@ def gpu_print_wrapper(op, cnda): ...@@ -904,10 +905,11 @@ def gpu_print_wrapper(op, cnda):
@register_opt2([tensor.printing.Print], 'fast_compile')
def local_gpua_print_op(op, context_name, inputs, outputs):
    """Lift a CPU ``Print`` op to the GPU.

    Moves the single input to the target GPU context, then rebuilds the
    ``Print`` op with ``gpu_print_wrapper`` as its ``global_fn`` so that the
    printed value is fetched from the GPU correctly.  The stack trace of the
    original outputs is propagated to every variable created inside the
    ``inherit_stack_trace`` block.
    """
    xin, = inputs
    with inherit_stack_trace(outputs):
        xin_gpu = as_gpuarray_variable(xin, context_name=context_name)
        # Rebuild the op (preserving its exact class) with the GPU-aware
        # print wrapper, and keep a handle on the original op for reference.
        lifted = op.__class__(global_fn=gpu_print_wrapper)
        lifted.old_op = op
        return lifted(xin_gpu)
@register_opt('fast_compile') @register_opt('fast_compile')
......
...@@ -40,9 +40,6 @@ def _check_stack_trace(thing): ...@@ -40,9 +40,6 @@ def _check_stack_trace(thing):
theano.tensor.elemwise.Elemwise, theano.tensor.elemwise.Elemwise,
theano.ifelse.IfElse, theano.ifelse.IfElse,
GpuFromHost, HostFromGpu, GpuFromHost, HostFromGpu,
basic_ops.GpuContiguous,
GpuElemwise,
theano.printing.Print,
)) ))
return check_stack_trace(thing, ops_to_check=_ops_to_check, return check_stack_trace(thing, ops_to_check=_ops_to_check,
bug_print="ignore") bug_print="ignore")
......
...@@ -310,7 +310,7 @@ class Print(Op): ...@@ -310,7 +310,7 @@ class Print(Op):
self.global_fn = global_fn self.global_fn = global_fn
def make_node(self, xin):
    """Build an Apply node with one output of the same type as *xin*.

    ``xin.type()`` instantiates a fresh variable of the input's type, so the
    printed output is a pass-through of the input.
    """
    out = xin.type()
    return Apply(op=self, inputs=[xin], outputs=[out])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论