提交 8411ea02 authored 作者: affanv14's avatar affanv14

optimization fix

上级 c749cbc6
...@@ -3054,7 +3054,7 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs): ...@@ -3054,7 +3054,7 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
conv_mode=conv_mode, conv_mode=conv_mode,
num_groups=op.num_groups) num_groups=op.num_groups)
elif isinstance(op, AbstractConv2d_gradWeights): elif isinstance(op, AbstractConv2d_gradWeights):
shape = (inp2.shape[1], inp1.shape[1], shape = (inp2.shape[1], inp1.shape[1] // op.num_groups,
inputs[2][0], inputs[2][1]) inputs[2][0], inputs[2][1])
rval = dnn_gradweight(inp1, inp2, shape, rval = dnn_gradweight(inp1, inp2, shape,
border_mode=op.border_mode, border_mode=op.border_mode,
...@@ -3063,7 +3063,7 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs): ...@@ -3063,7 +3063,7 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
conv_mode=conv_mode, conv_mode=conv_mode,
num_groups=op.num_groups) num_groups=op.num_groups)
elif isinstance(op, AbstractConv2d_gradInputs): elif isinstance(op, AbstractConv2d_gradInputs):
shape = (inp2.shape[0], inp1.shape[1], shape = (inp2.shape[0], inp1.shape[1] * op.num_groups,
inputs[2][0], inputs[2][1]) inputs[2][0], inputs[2][1])
rval = dnn_gradinput(inp1, inp2, shape, rval = dnn_gradinput(inp1, inp2, shape,
border_mode=op.border_mode, border_mode=op.border_mode,
...@@ -3103,7 +3103,7 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs): ...@@ -3103,7 +3103,7 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs):
conv_mode=conv_mode, conv_mode=conv_mode,
num_groups=op.num_groups) num_groups=op.num_groups)
elif isinstance(op, AbstractConv3d_gradWeights): elif isinstance(op, AbstractConv3d_gradWeights):
shape = (inp2.shape[1], inp1.shape[1], shape = (inp2.shape[1], inp1.shape[1] // op.num_groups,
inputs[2][0], inputs[2][1], inputs[2][2]) inputs[2][0], inputs[2][1], inputs[2][2])
rval = dnn_gradweight3d(inp1, inp2, shape, rval = dnn_gradweight3d(inp1, inp2, shape,
border_mode=op.border_mode, border_mode=op.border_mode,
...@@ -3112,7 +3112,7 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs): ...@@ -3112,7 +3112,7 @@ def local_abstractconv3d_cudnn_graph(op, context_name, inputs, outputs):
conv_mode=conv_mode, conv_mode=conv_mode,
num_groups=op.num_groups) num_groups=op.num_groups)
elif isinstance(op, AbstractConv3d_gradInputs): elif isinstance(op, AbstractConv3d_gradInputs):
shape = (inp2.shape[0], inp1.shape[1], shape = (inp2.shape[0], inp1.shape[1] * op.num_groups,
inputs[2][0], inputs[2][1], inputs[2][2]) inputs[2][0], inputs[2][1], inputs[2][2])
rval = dnn_gradinput3d(inp1, inp2, shape, rval = dnn_gradinput3d(inp1, inp2, shape,
border_mode=op.border_mode, border_mode=op.border_mode,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论