提交 e102a900 authored 作者: Reyhane Askari's avatar Reyhane Askari

moved with inside if statements

上级 20f39e45
...@@ -3133,12 +3133,13 @@ def local_abstractconv_cudnn(node): ...@@ -3133,12 +3133,13 @@ def local_abstractconv_cudnn(node):
ctx = infer_context_name(*node.inputs) ctx = infer_context_name(*node.inputs)
if not isinstance(node.inputs[0].type, GpuArrayType): if not isinstance(node.inputs[0].type, GpuArrayType):
return return
with inherit_stack_trace(node.outputs): if node.op.unshared:
if node.op.unshared: return None
return None if isinstance(node.op, AbstractConv2d):
if isinstance(node.op, AbstractConv2d): with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs) return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
elif isinstance(node.op, AbstractConv3d): elif isinstance(node.op, AbstractConv3d):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs) return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
...@@ -3358,12 +3359,13 @@ def local_abstractconv_gw_cudnn(node): ...@@ -3358,12 +3359,13 @@ def local_abstractconv_gw_cudnn(node):
ctx = infer_context_name(*node.inputs) ctx = infer_context_name(*node.inputs)
if not isinstance(node.inputs[0].type, GpuArrayType): if not isinstance(node.inputs[0].type, GpuArrayType):
return return
with inherit_stack_trace(node.outputs): if node.op.unshared:
if node.op.unshared: return None
return None if isinstance(node.op, AbstractConv2d_gradWeights):
if isinstance(node.op, AbstractConv2d_gradWeights): with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs) return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
elif isinstance(node.op, AbstractConv3d_gradWeights): elif isinstance(node.op, AbstractConv3d_gradWeights):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs) return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
...@@ -3372,12 +3374,13 @@ def local_abstractconv_gi_cudnn(node): ...@@ -3372,12 +3374,13 @@ def local_abstractconv_gi_cudnn(node):
ctx = infer_context_name(*node.inputs) ctx = infer_context_name(*node.inputs)
if not isinstance(node.inputs[0].type, GpuArrayType): if not isinstance(node.inputs[0].type, GpuArrayType):
return return
with inherit_stack_trace(node.outputs): if node.op.unshared:
if node.op.unshared: return None
return None if isinstance(node.op, AbstractConv2d_gradInputs):
if isinstance(node.op, AbstractConv2d_gradInputs): with inherit_stack_trace(node.outputs):
return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs) return local_abstractconv_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
elif isinstance(node.op, AbstractConv3d_gradInputs): elif isinstance(node.op, AbstractConv3d_gradInputs):
with inherit_stack_trace(node.outputs):
return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs) return local_abstractconv3d_cudnn_graph(node.op, ctx, node.inputs, node.outputs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论