提交 632e0a1b 作者: affanv14

Move the num_groups parameter to the end of the signatures and add an algo parameter to dnn_gradweight and dnn_gradinput, forwarded to GpuDnnConvGradW and GpuDnnConvGradI

上级 2e0d1d1f
......@@ -868,8 +868,8 @@ class GpuDnnConvGradI(DnnBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
num_groups=1, conv_mode='conv', direction_hint=None, workmem=None,
algo=None, precision=None):
conv_mode='conv', direction_hint=None, workmem=None,
algo=None, precision=None, num_groups=1):
"""
GPU convolution using cuDNN from NVIDIA.
......@@ -1111,7 +1111,8 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), dilation=(1
def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid',
subsample=(1, 1), dilation=(1, 1), num_groups=1, conv_mode='conv', precision=None):
subsample=(1, 1), dilation=(1, 1), conv_mode='conv',
precision=None, algo=None, num_groups=1):
"""
TODO: document this
"""
......@@ -1126,7 +1127,7 @@ def dnn_gradweight(img, topgrad, kerns_shp, border_mode='valid',
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision)(kerns_shp)
out = GpuAllocEmpty(dtype=img.dtype, context_name=ctx_name)(*kerns_shp)
return GpuDnnConvGradW(num_groups=num_groups)(img, topgrad, out, desc)
return GpuDnnConvGradW(algo=algo, num_groups=num_groups)(img, topgrad, out, desc)
def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid',
......@@ -1139,7 +1140,8 @@ def dnn_gradweight3d(img, topgrad, kerns_shp, border_mode='valid',
def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
subsample=(1, 1), dilation=(1, 1), num_groups=1, conv_mode='conv', precision=None):
subsample=(1, 1), dilation=(1, 1), conv_mode='conv',
precision=None, algo=None, num_groups=1):
"""
TODO: document this
"""
......@@ -1154,7 +1156,7 @@ def dnn_gradinput(kerns, topgrad, img_shp, border_mode='valid',
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, dilation=dilation,
conv_mode=conv_mode, precision=precision)(kerns.shape)
out = GpuAllocEmpty(dtype=kerns.dtype, context_name=ctx_name)(*img_shp)
return GpuDnnConvGradI(num_groups=num_groups)(kerns, topgrad, out, desc)
return GpuDnnConvGradI(algo=algo, num_groups=num_groups)(kerns, topgrad, out, desc)
def dnn_gradinput3d(kerns, topgrad, img_shp, border_mode='valid',
......@@ -2679,9 +2681,9 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
border_mode=op.border_mode,
subsample=op.subsample,
dilation=op.filter_dilation,
num_groups=op.num_groups,
direction_hint='forward!',
conv_mode=conv_mode)
conv_mode=conv_mode,
num_groups=op.num_groups)
elif isinstance(op, AbstractConv2d_gradWeights):
shape = (inp2.shape[1], inp1.shape[1],
inputs[2][0], inputs[2][1])
......@@ -2689,8 +2691,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
border_mode=op.border_mode,
subsample=op.subsample,
dilation=op.filter_dilation,
num_groups=op.num_groups,
conv_mode=conv_mode)
conv_mode=conv_mode,
num_groups=op.num_groups)
elif isinstance(op, AbstractConv2d_gradInputs):
shape = (inp2.shape[0], inp1.shape[1],
inputs[2][0], inputs[2][1])
......@@ -2698,8 +2700,8 @@ def local_abstractconv_cudnn_graph(op, context_name, inputs, outputs):
border_mode=op.border_mode,
subsample=op.subsample,
dilation=op.filter_dilation,
num_groups=op.num_groups,
conv_mode=conv_mode)
conv_mode=conv_mode,
num_groups=op.num_groups)
return [rval]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论