提交 291f8c15 作者: jiakai

better interface for dnn_conv

上级 b2c949cd
...@@ -141,15 +141,18 @@ class GpuDnnConvDesc(GpuOp): ...@@ -141,15 +141,18 @@ class GpuDnnConvDesc(GpuOp):
def c_compiler(self): def c_compiler(self):
return NVCC_compiler return NVCC_compiler
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv', def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv'):
padding=(0, 0)): if isinstance(border_mode, int):
assert border_mode in ('valid', 'full', 'padding') border_mode = (border_mode, border_mode)
assert isinstance(border_mode, tuple) or \
border_mode in ('valid', 'full'), \
'invalid border_mode {}, which must be either "valid", "full", '\
'an integer or a pair of integers'.format(border_mode)
self.border_mode = border_mode self.border_mode = border_mode
assert len(subsample) == 2 assert len(subsample) == 2
self.subsample = subsample self.subsample = subsample
assert conv_mode in ('conv', 'cross') assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode self.conv_mode = conv_mode
self.padding = padding
def make_node(self, img_shape, kern_shape): def make_node(self, img_shape, kern_shape):
if img_shape.type.ndim != 1 or img_shape.type.dtype != 'int64': if img_shape.type.ndim != 1 or img_shape.type.dtype != 'int64':
...@@ -164,14 +167,13 @@ class GpuDnnConvDesc(GpuOp): ...@@ -164,14 +167,13 @@ class GpuDnnConvDesc(GpuOp):
img_shape, kern_shape = inputs img_shape, kern_shape = inputs
desc, = outputs desc, = outputs
pad_h_spec, pad_w_spec = map(int, self.padding) if isinstance(self.border_mode, tuple):
pad_h_spec, pad_w_spec = map(int, self.border_mode)
assert pad_h_spec >= 0 and pad_w_spec >= 0 assert pad_h_spec >= 0 and pad_w_spec >= 0
if self.border_mode == 'padding':
bmode = 2 bmode = 2
else: else:
assert pad_h_spec == 0 and pad_w_spec == 0, \ pad_h_spec = pad_w_spec = 0
'padding not zero, but border_mode != "padding"'
if self.border_mode == "valid": if self.border_mode == "valid":
bmode = 1 bmode = 1
else: else:
...@@ -464,7 +466,7 @@ class GpuDnnConvGradI(GpuDnnConvBase): ...@@ -464,7 +466,7 @@ class GpuDnnConvGradI(GpuDnnConvBase):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv', padding=(0, 0)): conv_mode='conv'):
""" """
GPU convolution using cuDNN from NVIDIA. GPU convolution using cuDNN from NVIDIA.
...@@ -473,7 +475,8 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -473,7 +475,8 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
:param img: images to do the convolution over :param img: images to do the convolution over
:param kerns: convolution filters :param kerns: convolution filters
:param border_mode: one of 'valid', 'full' or 'padding'(default: 'valid') :param border_mode: one of 'valid', 'full'; additionally, the padding size
could be directly specified by an integer or a pair of integers
:param subsample: perform subsampling of the output (default: (1, 1)) :param subsample: perform subsampling of the output (default: (1, 1))
:param conv_mode: perform convolution (kernels flipped) or cross-correlation. One of 'conv', 'cross'. (default: 'conv') :param conv_mode: perform convolution (kernels flipped) or cross-correlation. One of 'conv', 'cross'. (default: 'conv')
...@@ -484,7 +487,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -484,7 +487,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img = gpu_contiguous(img) img = gpu_contiguous(img)
kerns = gpu_contiguous(kerns) kerns = gpu_contiguous(kerns)
desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample,
conv_mode=conv_mode, padding=padding)(img.shape, kerns.shape) conv_mode=conv_mode)(img.shape, kerns.shape)
return GpuDnnConv()(img, kerns, desc) return GpuDnnConv()(img, kerns, desc)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论