提交 b2c949cd 作者: jiakai

padding support for dnn_conv

上级 32876587
......@@ -141,13 +141,15 @@ class GpuDnnConvDesc(GpuOp):
def c_compiler(self):
return NVCC_compiler
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv'):
assert border_mode in ('valid', 'full')
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv',
padding=(0, 0)):
assert border_mode in ('valid', 'full', 'padding')
self.border_mode = border_mode
assert len(subsample) == 2
self.subsample = subsample
assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode
self.padding = padding
def make_node(self, img_shape, kern_shape):
if img_shape.type.ndim != 1 or img_shape.type.dtype != 'int64':
......@@ -162,11 +164,19 @@ class GpuDnnConvDesc(GpuOp):
img_shape, kern_shape = inputs
desc, = outputs
if self.border_mode == "valid":
bmode = 1
pad_h_spec, pad_w_spec = map(int, self.padding)
assert pad_h_spec >= 0 and pad_w_spec >= 0
if self.border_mode == 'padding':
bmode = 2
else:
assert self.border_mode == "full"
bmode = 0
assert pad_h_spec == 0 and pad_w_spec == 0, \
'padding not zero, but border_mode != "padding"'
if self.border_mode == "valid":
bmode = 1
else:
assert self.border_mode == "full"
bmode = 0
if self.conv_mode == 'conv':
conv_flag = 'CUDNN_CONVOLUTION'
......@@ -185,7 +195,10 @@ class GpuDnnConvDesc(GpuOp):
%(fail)s
}
if (%(bmode)d == 1) {
if (%(bmode)d == 2) {
pad_h%(name)s = %(pad_h_spec)d;
pad_w%(name)s = %(pad_w_spec)d;
} else if (%(bmode)d == 1) {
pad_h%(name)s = 0;
pad_w%(name)s = 0;
} else if (%(bmode)d == 0) {
......@@ -218,10 +231,11 @@ class GpuDnnConvDesc(GpuOp):
}
""" % dict(name=name, img_shape=img_shape, kern_shape=kern_shape, desc=desc,
bmode=bmode, conv_flag=conv_flag, fail=sub['fail'],
subsx=self.subsample[0], subsy=self.subsample[1])
subsx=self.subsample[0], subsy=self.subsample[1],
pad_h_spec=pad_h_spec, pad_w_spec=pad_w_spec)
def c_code_cache_version(self):
return (1,)
return (2,)
class GpuDnnConvBase(DnnBase):
......@@ -450,7 +464,7 @@ class GpuDnnConvGradI(GpuDnnConvBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv'):
conv_mode='conv', padding=(0, 0)):
"""
GPU convolution using cuDNN from NVIDIA.
......@@ -459,7 +473,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
:param img: images to do the convolution over
:param kerns: convolution filters
:param border_mode: one of 'valid', 'full' (default: 'valid')
:param border_mode: one of 'valid', 'full' or 'padding'(default: 'valid')
:param subsample: perform subsampling of the output (default: (1, 1))
:param conv_mode: perform convolution (kernels flipped) or cross-correlation. One of 'conv', 'cross'. (default: 'conv')
......@@ -470,7 +484,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns)
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample,
conv_mode=conv_mode)(img.shape, kerns.shape)
conv_mode=conv_mode, padding=padding)(img.shape, kerns.shape)
return GpuDnnConv()(img, kerns, desc)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论