提交 291f8c15 作者: jiakai

better interface for dnn_conv

上级 b2c949cd
......@@ -141,15 +141,18 @@ class GpuDnnConvDesc(GpuOp):
def c_compiler(self):
return NVCC_compiler
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv',
padding=(0, 0)):
assert border_mode in ('valid', 'full', 'padding')
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv'):
if isinstance(border_mode, int):
border_mode = (border_mode, border_mode)
assert isinstance(border_mode, tuple) or \
border_mode in ('valid', 'full'), \
'invalid border_mode {}, which must be either "valid", "full", '\
'an integer or a pair of integers'.format(border_mode)
self.border_mode = border_mode
assert len(subsample) == 2
self.subsample = subsample
assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode
self.padding = padding
def make_node(self, img_shape, kern_shape):
if img_shape.type.ndim != 1 or img_shape.type.dtype != 'int64':
......@@ -164,14 +167,13 @@ class GpuDnnConvDesc(GpuOp):
img_shape, kern_shape = inputs
desc, = outputs
pad_h_spec, pad_w_spec = map(int, self.padding)
assert pad_h_spec >= 0 and pad_w_spec >= 0
if self.border_mode == 'padding':
if isinstance(self.border_mode, tuple):
pad_h_spec, pad_w_spec = map(int, self.border_mode)
assert pad_h_spec >= 0 and pad_w_spec >= 0
bmode = 2
else:
assert pad_h_spec == 0 and pad_w_spec == 0, \
'padding not zero, but border_mode != "padding"'
pad_h_spec = pad_w_spec = 0
if self.border_mode == "valid":
bmode = 1
else:
......@@ -464,7 +466,7 @@ class GpuDnnConvGradI(GpuDnnConvBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv', padding=(0, 0)):
conv_mode='conv'):
"""
GPU convolution using cuDNN from NVIDIA.
......@@ -473,7 +475,8 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
:param img: images to do the convolution over
:param kerns: convolution filters
:param border_mode: one of 'valid', 'full' or 'padding'(default: 'valid')
:param border_mode: one of 'valid', 'full'; additionally, the padding size
could be directly specified by an integer or a pair of integers
:param subsample: perform subsampling of the output (default: (1, 1))
:param conv_mode: perform convolution (kernels flipped) or cross-correlation. One of 'conv', 'cross'. (default: 'conv')
......@@ -484,7 +487,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns)
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample,
conv_mode=conv_mode, padding=padding)(img.shape, kerns.shape)
conv_mode=conv_mode)(img.shape, kerns.shape)
return GpuDnnConv()(img, kerns, desc)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论