提交 2e10599b authored 作者: Reyhane Askari's avatar Reyhane Askari

moved inherit_stack_trace inside output_merge decorator

上级 c1fd66d6
...@@ -3424,25 +3424,22 @@ def local_dnn_convi_alpha_merge(node, *inputs): ...@@ -3424,25 +3424,22 @@ def local_dnn_convi_alpha_merge(node, *inputs):
@register_opt('cudnn') @register_opt('cudnn')
@output_merge(GpuDnnConv, alpha_in=4, beta_in=5, out_in=2) @output_merge(GpuDnnConv, alpha_in=4, beta_in=5, out_in=2)
def local_dnn_conv_output_merge(node, *inputs): def local_dnn_conv_output_merge(node, *inputs):
with inherit_stack_trace(node.outputs): inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:]
inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:] return [GpuDnnConv(algo=node.op.algo)(*inputs)]
return [GpuDnnConv(algo=node.op.algo)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@output_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5, out_in=2) @output_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5, out_in=2)
def local_dnn_convw_output_merge(node, *inputs): def local_dnn_convw_output_merge(node, *inputs):
with inherit_stack_trace(node.outputs): inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:]
inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:] return [GpuDnnConvGradW(algo=node.op.algo)(*inputs)]
return [GpuDnnConvGradW(algo=node.op.algo)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@output_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5, out_in=2) @output_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5, out_in=2)
def local_dnn_convi_output_merge(node, *inputs): def local_dnn_convi_output_merge(node, *inputs):
with inherit_stack_trace(node.outputs): inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:]
inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:] return [GpuDnnConvGradI(algo=node.op.algo)(*inputs)]
return [GpuDnnConvGradI(algo=node.op.algo)(*inputs)]
def local_gpua_pool_dnn_alternative(op, ctx_name, inputs, outputs): def local_gpua_pool_dnn_alternative(op, ctx_name, inputs, outputs):
......
...@@ -274,7 +274,8 @@ def output_merge(cls, alpha_in, beta_in, out_in): ...@@ -274,7 +274,8 @@ def output_merge(cls, alpha_in, beta_in, out_in):
inputs = list(targ.inputs) inputs = list(targ.inputs)
inputs[out_in] = W inputs[out_in] = W
inputs[beta_in] = _one.clone() inputs[beta_in] = _one.clone()
return maker(targ, *inputs) with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
return opt return opt
return wrapper return wrapper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论