提交 c1fd66d6 authored 作者: Reyhane Askari's avatar Reyhane Askari

moved inherit_stack_trace inside alpha_merge decorator

上级 6a13ffda
...@@ -3406,21 +3406,18 @@ optdb.register('local_dnna_conv_inplace', ...@@ -3406,21 +3406,18 @@ optdb.register('local_dnna_conv_inplace',
@register_opt('cudnn') @register_opt('cudnn')
@alpha_merge(GpuDnnConv, alpha_in=4, beta_in=5) @alpha_merge(GpuDnnConv, alpha_in=4, beta_in=5)
def local_dnn_conv_alpha_merge(node, *inputs): def local_dnn_conv_alpha_merge(node, *inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConv(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)] return [GpuDnnConv(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@alpha_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5) @alpha_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5)
def local_dnn_convw_alpha_merge(node, *inputs): def local_dnn_convw_alpha_merge(node, *inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConvGradW(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)] return [GpuDnnConvGradW(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@alpha_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5) @alpha_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5)
def local_dnn_convi_alpha_merge(node, *inputs): def local_dnn_convi_alpha_merge(node, *inputs):
with inherit_stack_trace(node.outputs):
return [GpuDnnConvGradI(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)] return [GpuDnnConvGradI(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
......
...@@ -185,6 +185,7 @@ def alpha_merge(cls, alpha_in, beta_in): ...@@ -185,6 +185,7 @@ def alpha_merge(cls, alpha_in, beta_in):
except NotScalarConstantError: except NotScalarConstantError:
inputs[alpha_in] = lr * targ.inputs[alpha_in] inputs[alpha_in] = lr * targ.inputs[alpha_in]
inputs[beta_in] = lr * targ.inputs[beta_in] inputs[beta_in] = lr * targ.inputs[beta_in]
with inherit_stack_trace(node.outputs):
return maker(targ, *inputs) return maker(targ, *inputs)
return opt return opt
return wrapper return wrapper
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论