提交 6a13ffda authored 作者: Reyhane Askari's avatar Reyhane Askari

moved inherit_stack_trace inside inplace_allocempty decorator

上级 a1f21688
......@@ -3383,20 +3383,17 @@ def local_abstractconv_gi_cudnn(node):
@inplace_allocempty(GpuDnnConv, 2)
def local_dnn_conv_inplace(node, inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConv(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
return [GpuDnnConv(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
@inplace_allocempty(GpuDnnConvGradW, 2)
def local_dnn_convgw_inplace(node, inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConvGradW(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
return [GpuDnnConvGradW(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
@inplace_allocempty(GpuDnnConvGradI, 2)
def local_dnn_convgi_inplace(node, inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConvGradI(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
return [GpuDnnConvGradI(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
optdb.register('local_dnna_conv_inplace',
tensor.opt.in2out(local_dnn_conv_inplace,
......
......@@ -5,6 +5,7 @@ import numpy as np
from theano import tensor, scalar as scal, Constant
from theano.gof import local_optimizer
from theano.gof.opt import inherit_stack_trace
from theano.tensor import (DimShuffle, get_scalar_constant_value,
NotScalarConstantError)
......@@ -326,7 +327,8 @@ def inplace_allocempty(op, idx):
len(alloc.clients) > 1):
alloc_op = GpuAllocEmpty(alloc.owner.op.dtype, alloc.owner.op.context_name)
inputs[idx] = alloc_op(*alloc.owner.inputs)
return maker(node, inputs)
with inherit_stack_trace(node.outputs):
return maker(node, inputs)
return opt
return wrapper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论