提交 727477b5(Unverified) 作者: abergeron 提交者: GitHub

Merge pull request #6624 from abergeron/fix_dnn_conv_groups

Make sure to always pass num_groups
...@@ -1108,8 +1108,9 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1), ...@@ -1108,8 +1108,9 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp) out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
precision, _ = get_precision(precision, [img, kerns], for_grad=True) precision, _ = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), dilation=(1, 1), desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), dilation=(1, 1),
num_groups=num_groups,
conv_mode='cross', precision=precision)(out.shape) conv_mode='cross', precision=precision)(out.shape)
conv = GpuDnnConvGradW()(img, kerns, out, desc) conv = GpuDnnConvGradW(num_groups=num_groups)(img, kerns, out, desc)
return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3), ctx_name) return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3), ctx_name)
elif (border_mode == 'full' and subsample == (1, 1) and elif (border_mode == 'full' and subsample == (1, 1) and
...@@ -1128,8 +1129,9 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1), ...@@ -1128,8 +1129,9 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp) out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
precision, _ = get_precision(precision, [img, kerns], for_grad=True) precision, _ = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), dilation=dilation, desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1), dilation=dilation,
num_groups=num_groups,
conv_mode=conv_mode, precision=precision)(kerns.shape) conv_mode=conv_mode, precision=precision)(kerns.shape)
return GpuDnnConvGradI()(kerns, img, out, desc) return GpuDnnConvGradI(num_groups=num_groups)(kerns, img, out, desc)
# Standard case: We use GpuDnnConv with suitable padding. # Standard case: We use GpuDnnConv with suitable padding.
return _dnn_conv(img, kerns, algo=algo, border_mode=border_mode, subsample=subsample, dilation=dilation, return _dnn_conv(img, kerns, algo=algo, border_mode=border_mode, subsample=subsample, dilation=dilation,
...@@ -1213,8 +1215,9 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1 ...@@ -1213,8 +1215,9 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp) out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
precision, _ = get_precision(precision, [img, kerns], for_grad=True) precision, _ = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1, 1), dilation=(1, 1, 1), desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1, 1), dilation=(1, 1, 1),
num_groups=num_groups,
conv_mode='cross', precision=precision)(out.shape) conv_mode='cross', precision=precision)(out.shape)
conv = GpuDnnConvGradW()(img, kerns, out, desc) conv = GpuDnnConvGradW(num_groups=num_groups)(img, kerns, out, desc)
return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3, 4), ctx_name) return as_gpuarray_variable(conv.dimshuffle(1, 0, 2, 3, 4), ctx_name)
elif (border_mode == 'full' and subsample == (1, 1, 1) and elif (border_mode == 'full' and subsample == (1, 1, 1) and
...@@ -1234,8 +1237,9 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1 ...@@ -1234,8 +1237,9 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp) out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*out_shp)
precision, _ = get_precision(precision, [img, kerns], for_grad=True) precision, _ = get_precision(precision, [img, kerns], for_grad=True)
desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1, 1), dilation=dilation, desc = GpuDnnConvDesc(border_mode='valid', subsample=(1, 1, 1), dilation=dilation,
num_groups=num_groups,
conv_mode=conv_mode, precision=precision)(kerns.shape) conv_mode=conv_mode, precision=precision)(kerns.shape)
return GpuDnnConvGradI()(kerns, img, out, desc) return GpuDnnConvGradI(num_groups=num_groups)(kerns, img, out, desc)
# Standard case: We use GpuDnnConv with suitable padding. # Standard case: We use GpuDnnConv with suitable padding.
return _dnn_conv(img, kerns, algo=algo, border_mode=border_mode, subsample=subsample, dilation=dilation, return _dnn_conv(img, kerns, algo=algo, border_mode=border_mode, subsample=subsample, dilation=dilation,
...@@ -3435,9 +3439,11 @@ def local_abstractconv3d_cudnn_alt(node): ...@@ -3435,9 +3439,11 @@ def local_abstractconv3d_cudnn_alt(node):
subsample=subsample, subsample=subsample,
dilation=filter_dilation, dilation=filter_dilation,
conv_mode='cross', conv_mode='cross',
num_groups=num_groups,
precision=precision)(out.shape) precision=precision)(out.shape)
conv = GpuDnnConv(algo=None)(img, topgrad, out, desc) conv = GpuDnnConv(algo=None, num_groups=num_groups)(
img, topgrad, out, desc)
if conv_mode == 'conv': if conv_mode == 'conv':
conv = conv[:, :, ::-1, ::-1, ::-1] conv = conv[:, :, ::-1, ::-1, ::-1]
...@@ -3455,6 +3461,7 @@ def local_abstractconv3d_cudnn_alt(node): ...@@ -3455,6 +3461,7 @@ def local_abstractconv3d_cudnn_alt(node):
subsample=subsample, subsample=subsample,
dilation=filter_dilation, dilation=filter_dilation,
conv_mode=conv_mode, conv_mode=conv_mode,
num_groups=num_groups,
precision=precision)(kerns.shape) precision=precision)(kerns.shape)
tshape = [shape_i_op(i)(topgrad) for i in range(topgrad.ndim)] tshape = [shape_i_op(i)(topgrad) for i in range(topgrad.ndim)]
...@@ -3467,7 +3474,8 @@ def local_abstractconv3d_cudnn_alt(node): ...@@ -3467,7 +3474,8 @@ def local_abstractconv3d_cudnn_alt(node):
shape = assert_conv_shape(shape) shape = assert_conv_shape(shape)
out = GpuAllocEmpty(dtype=topgrad.dtype, context_name=ctx_name)(*shape) out = GpuAllocEmpty(dtype=topgrad.dtype, context_name=ctx_name)(*shape)
rval = GpuDnnConv(algo=None)(topgrad, kerns, out, desc) rval = GpuDnnConv(algo=None, num_groups=num_groups)(
topgrad, kerns, out, desc)
else: else:
return None return None
...@@ -3554,21 +3562,21 @@ def local_dnn_convi_alpha_merge(node, *inputs): ...@@ -3554,21 +3562,21 @@ def local_dnn_convi_alpha_merge(node, *inputs):
@output_merge(GpuDnnConv, alpha_in=4, beta_in=5, out_in=2) @output_merge(GpuDnnConv, alpha_in=4, beta_in=5, out_in=2)
def local_dnn_conv_output_merge(node, *inputs): def local_dnn_conv_output_merge(node, *inputs):
inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:] inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:]
return [GpuDnnConv(algo=node.op.algo)(*inputs)] return [GpuDnnConv(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@output_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5, out_in=2) @output_merge(GpuDnnConvGradW, alpha_in=4, beta_in=5, out_in=2)
def local_dnn_convw_output_merge(node, *inputs): def local_dnn_convw_output_merge(node, *inputs):
inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:] inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:]
return [GpuDnnConvGradW(algo=node.op.algo)(*inputs)] return [GpuDnnConvGradW(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
@register_opt('cudnn') @register_opt('cudnn')
@output_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5, out_in=2) @output_merge(GpuDnnConvGradI, alpha_in=4, beta_in=5, out_in=2)
def local_dnn_convi_output_merge(node, *inputs): def local_dnn_convi_output_merge(node, *inputs):
inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:] inputs = inputs[0:2] + (gpu_contiguous(inputs[2]),) + inputs[3:]
return [GpuDnnConvGradI(algo=node.op.algo)(*inputs)] return [GpuDnnConvGradI(algo=node.op.algo, num_groups=node.op.num_groups)(*inputs)]
def local_gpua_pool_dnn_alternative(op, ctx_name, inputs, outputs): def local_gpua_pool_dnn_alternative(op, ctx_name, inputs, outputs):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论