提交 6e25ecf9 authored 作者: carriepl's avatar carriepl

Merge pull request #3938 from carriepl/half_padding

Add 'half' padding to GpuArray dnn convolution
...@@ -13,6 +13,12 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp, ...@@ -13,6 +13,12 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp,
#if NB_DIMS > 2 #if NB_DIMS > 2
pad[2] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 4) - 1; pad[2] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 4) - 1;
#endif #endif
#elif BORDER_MODE == 2
pad[0] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 2) / 2;
pad[1] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 3) / 2;
#if NB_DIMS > 2
pad[2] = *(npy_int64 *)PyArray_GETPTR1(filt_shp, 4) / 2;
#endif
#endif #endif
if (PyArray_DIM(filt_shp, 0) - 2 != NB_DIMS) { if (PyArray_DIM(filt_shp, 0) - 2 != NB_DIMS) {
......
...@@ -263,10 +263,10 @@ class GpuDnnConvDesc(COp): ...@@ -263,10 +263,10 @@ class GpuDnnConvDesc(COp):
assert len(border_mode) == len(subsample) assert len(border_mode) == len(subsample)
border_mode = tuple(map(int, border_mode)) border_mode = tuple(map(int, border_mode))
if not ((isinstance(border_mode, tuple) and min(border_mode) >= 0) or if not ((isinstance(border_mode, tuple) and min(border_mode) >= 0) or
border_mode in ('valid', 'full')): border_mode in ('valid', 'full', 'half')):
raise ValueError( raise ValueError(
'invalid border_mode {}, which must be either ' 'invalid border_mode {}, which must be either '
'"valid", "full", an integer or a pair of' '"valid", "full", "half", an integer or a pair of'
' integers'.format(border_mode)) ' integers'.format(border_mode))
self.border_mode = border_mode self.border_mode = border_mode
assert len(subsample) in (2, 3) assert len(subsample) in (2, 3)
...@@ -294,9 +294,11 @@ class GpuDnnConvDesc(COp): ...@@ -294,9 +294,11 @@ class GpuDnnConvDesc(COp):
pad1 = str(self.border_mode[1]) pad1 = str(self.border_mode[1])
if len(self.border_mode) > 2: if len(self.border_mode) > 2:
pad2 = str(self.border_mode[2]) pad2 = str(self.border_mode[2])
bmode = '2' bmode = '1'
elif self.border_mode == "valid": elif self.border_mode == "valid":
bmode = '1' bmode = '1'
elif self.border_mode == "half":
bmode = '2'
elif self.border_mode == "full": elif self.border_mode == "full":
bmode = '0' bmode = '0'
else: else:
...@@ -781,7 +783,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -781,7 +783,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
kerns kerns
Convolution filters. Convolution filters.
border_mode border_mode
One of 'valid', 'full'; additionally, the padding size One of 'valid', 'full', 'half'; additionally, the padding size
could be directly specified by an integer or a pair of integers. could be directly specified by an integer or a pair of integers.
subsample subsample
Perform subsampling of the output (default: (1, 1)). Perform subsampling of the output (default: (1, 1)).
......
...@@ -389,7 +389,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -389,7 +389,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
for params in product( for params in product(
['valid', 'full'], ['valid', 'full', 'half'],
[(1, 1), (2, 2)], [(1, 1), (2, 2)],
['conv', 'cross'] ['conv', 'cross']
): ):
...@@ -427,7 +427,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -427,7 +427,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
for params in product( for params in product(
['valid', 'full'], ['valid', 'full', 'half'],
[(1, 1)], # strides besides (1, 1) [(1, 1)], # strides besides (1, 1)
['conv', 'cross'] ['conv', 'cross']
): ):
...@@ -590,6 +590,7 @@ def test_dnn_conv_border_mode(): ...@@ -590,6 +590,7 @@ def test_dnn_conv_border_mode():
dnn.dnn_conv(img, kern, border_mode=(2, 3)) dnn.dnn_conv(img, kern, border_mode=(2, 3))
dnn.dnn_conv(img, kern, border_mode='full') dnn.dnn_conv(img, kern, border_mode='full')
dnn.dnn_conv(img, kern, border_mode='valid') dnn.dnn_conv(img, kern, border_mode='valid')
dnn.dnn_conv(img, kern, border_mode='half')
def test_dnn_conv_alpha_output_merge(): def test_dnn_conv_alpha_output_merge():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论