提交 b2c949cd authored 作者: jiakai's avatar jiakai

padding support for dnn_conv

上级 32876587
...@@ -141,13 +141,15 @@ class GpuDnnConvDesc(GpuOp): ...@@ -141,13 +141,15 @@ class GpuDnnConvDesc(GpuOp):
def c_compiler(self): def c_compiler(self):
return NVCC_compiler return NVCC_compiler
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv'): def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv',
assert border_mode in ('valid', 'full') padding=(0, 0)):
assert border_mode in ('valid', 'full', 'padding')
self.border_mode = border_mode self.border_mode = border_mode
assert len(subsample) == 2 assert len(subsample) == 2
self.subsample = subsample self.subsample = subsample
assert conv_mode in ('conv', 'cross') assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode self.conv_mode = conv_mode
self.padding = padding
def make_node(self, img_shape, kern_shape): def make_node(self, img_shape, kern_shape):
if img_shape.type.ndim != 1 or img_shape.type.dtype != 'int64': if img_shape.type.ndim != 1 or img_shape.type.dtype != 'int64':
...@@ -162,11 +164,19 @@ class GpuDnnConvDesc(GpuOp): ...@@ -162,11 +164,19 @@ class GpuDnnConvDesc(GpuOp):
img_shape, kern_shape = inputs img_shape, kern_shape = inputs
desc, = outputs desc, = outputs
if self.border_mode == "valid": pad_h_spec, pad_w_spec = map(int, self.padding)
bmode = 1 assert pad_h_spec >= 0 and pad_w_spec >= 0
if self.border_mode == 'padding':
bmode = 2
else: else:
assert self.border_mode == "full" assert pad_h_spec == 0 and pad_w_spec == 0, \
bmode = 0 'padding not zero, but border_mode != "padding"'
if self.border_mode == "valid":
bmode = 1
else:
assert self.border_mode == "full"
bmode = 0
if self.conv_mode == 'conv': if self.conv_mode == 'conv':
conv_flag = 'CUDNN_CONVOLUTION' conv_flag = 'CUDNN_CONVOLUTION'
...@@ -185,7 +195,10 @@ class GpuDnnConvDesc(GpuOp): ...@@ -185,7 +195,10 @@ class GpuDnnConvDesc(GpuOp):
%(fail)s %(fail)s
} }
if (%(bmode)d == 1) { if (%(bmode)d == 2) {
pad_h%(name)s = %(pad_h_spec)d;
pad_w%(name)s = %(pad_w_spec)d;
} else if (%(bmode)d == 1) {
pad_h%(name)s = 0; pad_h%(name)s = 0;
pad_w%(name)s = 0; pad_w%(name)s = 0;
} else if (%(bmode)d == 0) { } else if (%(bmode)d == 0) {
...@@ -218,10 +231,11 @@ class GpuDnnConvDesc(GpuOp): ...@@ -218,10 +231,11 @@ class GpuDnnConvDesc(GpuOp):
} }
""" % dict(name=name, img_shape=img_shape, kern_shape=kern_shape, desc=desc, """ % dict(name=name, img_shape=img_shape, kern_shape=kern_shape, desc=desc,
bmode=bmode, conv_flag=conv_flag, fail=sub['fail'], bmode=bmode, conv_flag=conv_flag, fail=sub['fail'],
subsx=self.subsample[0], subsy=self.subsample[1]) subsx=self.subsample[0], subsy=self.subsample[1],
pad_h_spec=pad_h_spec, pad_w_spec=pad_w_spec)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
class GpuDnnConvBase(DnnBase): class GpuDnnConvBase(DnnBase):
...@@ -450,7 +464,7 @@ class GpuDnnConvGradI(GpuDnnConvBase): ...@@ -450,7 +464,7 @@ class GpuDnnConvGradI(GpuDnnConvBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv'): conv_mode='conv', padding=(0, 0)):
""" """
GPU convolution using cuDNN from NVIDIA. GPU convolution using cuDNN from NVIDIA.
...@@ -459,7 +473,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -459,7 +473,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
:param img: images to do the convolution over :param img: images to do the convolution over
:param kerns: convolution filters :param kerns: convolution filters
:param border_mode: one of 'valid', 'full' (default: 'valid') :param border_mode: one of 'valid', 'full' or 'padding'(default: 'valid')
:param subsample: perform subsampling of the output (default: (1, 1)) :param subsample: perform subsampling of the output (default: (1, 1))
:param conv_mode: perform convolution (kernels flipped) or cross-correlation. One of 'conv', 'cross'. (default: 'conv') :param conv_mode: perform convolution (kernels flipped) or cross-correlation. One of 'conv', 'cross'. (default: 'conv')
...@@ -470,7 +484,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -470,7 +484,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img = gpu_contiguous(img) img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns) kerns = gpu_contiguous(kerns)
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample,
conv_mode=conv_mode)(img.shape, kerns.shape) conv_mode=conv_mode, padding=padding)(img.shape, kerns.shape)
return GpuDnnConv()(img, kerns, desc) return GpuDnnConv()(img, kerns, desc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论