提交 3ce8f1cb authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6362 from affanv14/3dmetafix

fix 3d optimizers for grouped convolutions and add more test cases
......@@ -3258,6 +3258,7 @@ def local_abstractconv3d_cudnn_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
precision = get_precision(None, [inp1, inp2])
if node.op.filter_flip:
......@@ -3266,7 +3267,7 @@ def local_abstractconv3d_cudnn_alt(node):
conv_mode = 'cross'
if isinstance(op, AbstractConv3d):
if border_mode == 'half' or subsample != (1, 1, 1):
if border_mode == 'half' or subsample != (1, 1, 1) or num_groups > 1:
return None
if border_mode == 'full':
direction_hint = 'bprop inputs'
......@@ -3284,7 +3285,7 @@ def local_abstractconv3d_cudnn_alt(node):
elif isinstance(op, AbstractConv3d_gradWeights):
if(border_mode == 'valid' and subsample == (1, 1, 1) and
filter_dilation == (1, 1, 1)):
filter_dilation == (1, 1, 1) and num_groups == 1):
img = gpu_contiguous(inp1)
topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(img, topgrad)
......@@ -3315,7 +3316,7 @@ def local_abstractconv3d_cudnn_alt(node):
return None
elif isinstance(op, AbstractConv3d_gradInputs):
if border_mode == 'valid' and subsample == (1, 1, 1):
if border_mode == 'valid' and subsample == (1, 1, 1) and num_groups == 1:
kerns = gpu_contiguous(inp1.dimshuffle(1, 0, 2, 3, 4))
topgrad = gpu_contiguous(inp2)
ctx_name = infer_context_name(kerns, topgrad)
......
......@@ -1842,8 +1842,10 @@ def local_abstractconv3d_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if ((border_mode == 'full') and (subsample == (1, 1, 1))):
if((border_mode == 'full') and (subsample == (1, 1, 1)) and
(num_groups == 1)):
if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
kern = kern.dimshuffle(1, 0, 2, 3, 4)
......@@ -1853,7 +1855,7 @@ def local_abstractconv3d_alt(node):
gpu_contiguous(kern), gpu_contiguous(img))
elif(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and
border_mode == 'valid'):
border_mode == 'valid' and num_groups == 1):
if node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
rval = GpuCorr3dMM_gradWeights(border_mode,
......@@ -1881,8 +1883,10 @@ def local_abstractconv3d2d(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if subsample == (1, 1, 1) and filter_dilation == (1, 1, 1):
if(subsample == (1, 1, 1) and filter_dilation == (1, 1, 1) and
num_groups == 1):
reorder_array = [0, 2, 1, 3, 4]
rval = conv3d2d.conv3d(gpu_contiguous(img.dimshuffle(*reorder_array)),
gpu_contiguous(kern.dimshuffle(*reorder_array)),
......@@ -1968,8 +1972,10 @@ def local_abstractconv3d_gemm_gradweights_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1, 1) and filter_dilation == (1, 1, 1):
if(border_mode == 'valid' and subsample == (1, 1, 1) and
filter_dilation == (1, 1, 1) and num_groups == 1):
rval = GpuCorr3dMM(border_mode,
subsample,
filter_dilation)(
......@@ -2091,8 +2097,10 @@ def local_abstractconv3d_gradinputs_gemm_alt(node):
border_mode = node.op.border_mode
subsample = node.op.subsample
filter_dilation = node.op.filter_dilation
num_groups = node.op.num_groups
if border_mode == 'valid' and subsample == (1, 1, 1):
if(border_mode == 'valid' and subsample == (1, 1, 1) and
num_groups == 1):
if not node.op.filter_flip:
kern = kern[:, :, ::-1, ::-1, ::-1]
rval = GpuCorr3dMM(border_mode='full',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论