提交 6a13ffda authored 作者: Reyhane Askari's avatar Reyhane Askari

moved inherit_stack_trace inside inplace_allocempty decorator

上级 a1f21688
...@@ -3383,19 +3383,16 @@ def local_abstractconv_gi_cudnn(node): ...@@ -3383,19 +3383,16 @@ def local_abstractconv_gi_cudnn(node):
@inplace_allocempty(GpuDnnConv, 2) @inplace_allocempty(GpuDnnConv, 2)
def local_dnn_conv_inplace(node, inputs): def local_dnn_conv_inplace(node, inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConv(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)] return [GpuDnnConv(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
@inplace_allocempty(GpuDnnConvGradW, 2) @inplace_allocempty(GpuDnnConvGradW, 2)
def local_dnn_convgw_inplace(node, inputs): def local_dnn_convgw_inplace(node, inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConvGradW(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)] return [GpuDnnConvGradW(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
@inplace_allocempty(GpuDnnConvGradI, 2) @inplace_allocempty(GpuDnnConvGradI, 2)
def local_dnn_convgi_inplace(node, inputs): def local_dnn_convgi_inplace(node, inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConvGradI(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)] return [GpuDnnConvGradI(algo=node.op.algo, inplace=True, num_groups=node.op.num_groups)(*inputs)]
optdb.register('local_dnna_conv_inplace', optdb.register('local_dnna_conv_inplace',
......
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from theano import tensor, scalar as scal, Constant from theano import tensor, scalar as scal, Constant
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.gof.opt import inherit_stack_trace
from theano.tensor import (DimShuffle, get_scalar_constant_value, from theano.tensor import (DimShuffle, get_scalar_constant_value,
NotScalarConstantError) NotScalarConstantError)
...@@ -326,6 +327,7 @@ def inplace_allocempty(op, idx): ...@@ -326,6 +327,7 @@ def inplace_allocempty(op, idx):
len(alloc.clients) > 1): len(alloc.clients) > 1):
alloc_op = GpuAllocEmpty(alloc.owner.op.dtype, alloc.owner.op.context_name) alloc_op = GpuAllocEmpty(alloc.owner.op.dtype, alloc.owner.op.context_name)
inputs[idx] = alloc_op(*alloc.owner.inputs) inputs[idx] = alloc_op(*alloc.owner.inputs)
with inherit_stack_trace(node.outputs):
return maker(node, inputs) return maker(node, inputs)
return opt return opt
return wrapper return wrapper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论