Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
38e03e0f
提交
38e03e0f
authored
8月 24, 2017
作者:
Vikram
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Started asymmetric padding
上级
9592125c
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
61 行增加
和
44 行删除
+61
-44
corr_gemm.c
theano/tensor/nnet/c_code/corr_gemm.c
+21
-18
corr.py
theano/tensor/nnet/corr.py
+40
-26
没有找到文件。
theano/tensor/nnet/c_code/corr_gemm.c
浏览文件 @
38e03e0f
...
@@ -31,23 +31,23 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
...
@@ -31,23 +31,23 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
void
im2col
(
const
%
(
float_type
)
s
*
data_im
,
const
int
channels
,
void
im2col
(
const
%
(
float_type
)
s
*
data_im
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
pad_h
l
,
const
int
pad_hr
,
const
int
pad_wl
,
const
int
pad_wr
,
const
int
stride_h
,
const
int
stride_w
,
const
int
stride_h
,
const
int
stride_w
,
%
(
float_type
)
s
*
data_col
)
{
%
(
float_type
)
s
*
data_col
)
{
// Implicit dilated kernel size
// Implicit dilated kernel size
int
dil_kernel_h
=
(
kernel_h
-
1
)
*
dilation_h
+
1
;
int
dil_kernel_h
=
(
kernel_h
-
1
)
*
dilation_h
+
1
;
int
dil_kernel_w
=
(
kernel_w
-
1
)
*
dilation_w
+
1
;
int
dil_kernel_w
=
(
kernel_w
-
1
)
*
dilation_w
+
1
;
int
height_col
=
(
height
+
2
*
pad_h
-
dil_kernel_h
)
/
stride_h
+
1
;
int
height_col
=
(
height
+
pad_hl
+
pad_hr
-
dil_kernel_h
)
/
stride_h
+
1
;
int
width_col
=
(
width
+
2
*
pad_w
-
dil_kernel_w
)
/
stride_w
+
1
;
int
width_col
=
(
width
+
pad_wl
+
pad_wr
-
dil_kernel_w
)
/
stride_w
+
1
;
int
channels_col
=
channels
*
kernel_h
*
kernel_w
;
int
channels_col
=
channels
*
kernel_h
*
kernel_w
;
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
int
w_offset
=
c
%%
kernel_w
;
int
w_offset
=
c
%%
kernel_w
;
int
h_offset
=
(
c
/
kernel_w
)
%%
kernel_h
;
int
h_offset
=
(
c
/
kernel_w
)
%%
kernel_h
;
int
c_im
=
c
/
kernel_h
/
kernel_w
;
int
c_im
=
c
/
kernel_h
/
kernel_w
;
for
(
int
h
=
0
;
h
<
height_col
;
++
h
)
{
for
(
int
h
=
0
;
h
<
height_col
;
++
h
)
{
int
h_pad
=
h
*
stride_h
-
pad_h
+
h_offset
*
dilation_h
;
int
h_pad
=
h
*
stride_h
-
pad_h
l
+
h_offset
*
dilation_h
;
for
(
int
w
=
0
;
w
<
width_col
;
++
w
)
{
for
(
int
w
=
0
;
w
<
width_col
;
++
w
)
{
int
w_pad
=
w
*
stride_w
-
pad_w
+
w_offset
*
dilation_w
;
int
w_pad
=
w
*
stride_w
-
pad_w
l
+
w_offset
*
dilation_w
;
if
(
h_pad
>=
0
&&
h_pad
<
height
&&
w_pad
>=
0
&&
w_pad
<
width
)
if
(
h_pad
>=
0
&&
h_pad
<
height
&&
w_pad
>=
0
&&
w_pad
<
width
)
data_col
[(
npy_intp
)(
c
*
height_col
+
h
)
*
width_col
+
w
]
=
data_col
[(
npy_intp
)(
c
*
height_col
+
h
)
*
width_col
+
w
]
=
data_im
[(
npy_intp
)(
c_im
*
height
+
h_pad
)
*
width
+
w_pad
];
data_im
[(
npy_intp
)(
c_im
*
height
+
h_pad
)
*
width
+
w_pad
];
...
@@ -64,13 +64,14 @@ void im2col(const %(float_type)s* data_im, const int channels,
...
@@ -64,13 +64,14 @@ void im2col(const %(float_type)s* data_im, const int channels,
void
col2im
(
const
%
(
float_type
)
s
*
data_col
,
const
int
channels
,
void
col2im
(
const
%
(
float_type
)
s
*
data_col
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
patch_h
,
const
int
patch_w
,
const
int
height
,
const
int
width
,
const
int
patch_h
,
const
int
patch_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
pad_hl
,
const
int
pad_hr
,
const
int
pad_wl
,
const
int
pad_wr
,
const
int
stride_w
,
%
(
float_type
)
s
*
data_im
)
{
const
int
stride_h
,
const
int
stride_w
,
%
(
float_type
)
s
*
data_im
)
{
// Implicit dilated patch
// Implicit dilated patch
int
dil_patch_h
=
(
patch_h
-
1
)
*
dilation_h
+
1
;
int
dil_patch_h
=
(
patch_h
-
1
)
*
dilation_h
+
1
;
int
dil_patch_w
=
(
patch_w
-
1
)
*
dilation_w
+
1
;
int
dil_patch_w
=
(
patch_w
-
1
)
*
dilation_w
+
1
;
int
height_col
=
(
height
+
2
*
pad_h
-
dil_patch_h
)
/
stride_h
+
1
;
int
height_col
=
(
height
+
pad_hl
+
pad_hr
-
dil_patch_h
)
/
stride_h
+
1
;
int
width_col
=
(
width
+
2
*
pad_w
-
dil_patch_w
)
/
stride_w
+
1
;
int
width_col
=
(
width
+
pad_wl
+
pad_wr
-
dil_patch_w
)
/
stride_w
+
1
;
int
num_kernels
=
channels
*
height
*
width
;
int
num_kernels
=
channels
*
height
*
width
;
int
channels_col
=
channels
*
patch_h
*
patch_w
;
int
channels_col
=
channels
*
patch_h
*
patch_w
;
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
...
@@ -78,9 +79,9 @@ void col2im(const %(float_type)s* data_col, const int channels,
...
@@ -78,9 +79,9 @@ void col2im(const %(float_type)s* data_col, const int channels,
int
h_offset
=
(
c
/
patch_w
)
%%
patch_h
;
int
h_offset
=
(
c
/
patch_w
)
%%
patch_h
;
int
c_im
=
c
/
patch_h
/
patch_w
;
int
c_im
=
c
/
patch_h
/
patch_w
;
for
(
int
h
=
0
;
h
<
height_col
;
++
h
)
{
for
(
int
h
=
0
;
h
<
height_col
;
++
h
)
{
int
h_pad
=
h
*
stride_h
-
pad_h
+
h_offset
*
dilation_h
;
int
h_pad
=
h
*
stride_h
-
pad_h
l
+
h_offset
*
dilation_h
;
for
(
int
w
=
0
;
w
<
width_col
;
++
w
)
{
for
(
int
w
=
0
;
w
<
width_col
;
++
w
)
{
int
w_pad
=
w
*
stride_w
-
pad_w
+
w_offset
*
dilation_w
;
int
w_pad
=
w
*
stride_w
-
pad_w
l
+
w_offset
*
dilation_w
;
if
(
h_pad
>=
0
&&
h_pad
<
height
&&
w_pad
>=
0
&&
w_pad
<
width
)
if
(
h_pad
>=
0
&&
h_pad
<
height
&&
w_pad
>=
0
&&
w_pad
<
width
)
data_im
[(
npy_intp
)(
c_im
*
height
+
h_pad
)
*
width
+
w_pad
]
+=
data_im
[(
npy_intp
)(
c_im
*
height
+
h_pad
)
*
width
+
w_pad
]
+=
data_col
[(
npy_intp
)(
c
*
height_col
+
h
)
*
width_col
+
w
];
data_col
[(
npy_intp
)(
c
*
height_col
+
h
)
*
width_col
+
w
];
...
@@ -105,8 +106,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
...
@@ -105,8 +106,10 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const
int
dW
=
1
,
const
int
dW
=
1
,
const
int
dilH
=
1
,
const
int
dilH
=
1
,
const
int
dilW
=
1
,
const
int
dilW
=
1
,
const
int
padH
=
0
,
const
int
padH_l
=
0
,
const
int
padW
=
0
,
const
int
padH_r
=
0
,
const
int
padW_l
=
0
,
const
int
padW_r
=
0
,
const
int
numgroups
=
1
,
const
int
numgroups
=
1
,
const
int
unshared
=
0
)
const
int
unshared
=
0
)
{
{
...
@@ -172,8 +175,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
...
@@ -172,8 +175,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const
int
dil_kH
=
(
kH
-
1
)
*
dilH
+
1
;
const
int
dil_kH
=
(
kH
-
1
)
*
dilH
+
1
;
const
int
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
const
int
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
// top: (batchSize, nFilters, topHeight, topWidth)
// top: (batchSize, nFilters, topHeight, topWidth)
const
int
topHeightNoDH
=
(
bottomHeight
+
2
*
padH
-
dil_kH
);
const
int
topHeightNoDH
=
(
bottomHeight
+
padH_l
+
padH_r
-
dil_kH
);
const
int
topWidthNoDW
=
(
bottomWidth
+
2
*
padW
-
dil_kW
);
const
int
topWidthNoDW
=
(
bottomWidth
+
padW_l
+
padW_r
-
dil_kW
);
// the above values might be negative so we need to use Python-like
// the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output.
// flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only
// note: this macro implements Python's // for negative x only
...
@@ -303,7 +306,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
...
@@ -303,7 +306,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
int
tid
=
%
(
omp_get_thread_num
)
s
;
int
tid
=
%
(
omp_get_thread_num
)
s
;
// First, im2col
// First, im2col
im2col
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
batch_bottom_stride
,
nChannels
,
im2col
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
batch_bottom_stride
,
nChannels
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
dilH
,
dilW
,
padH
,
padW
,
dH
,
dW
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
dilH
,
dilW
,
padH
_l
,
padH_r
,
padW_l
,
padW_r
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
)
+
tid
*
col_stride
);
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
)
+
tid
*
col_stride
);
// Second, gemm
// Second, gemm
if
(
unshared
)
{
if
(
unshared
)
{
...
@@ -396,7 +399,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
...
@@ -396,7 +399,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
int
tid
=
%
(
omp_get_thread_num
)
s
;
int
tid
=
%
(
omp_get_thread_num
)
s
;
// First, im2col
// First, im2col
im2col
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
batch_bottom_stride
,
im2col
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
batch_bottom_stride
,
nChannels
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
dilH
,
dilW
,
padH
,
padW
,
dH
,
dW
,
nChannels
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
dilH
,
dilW
,
padH
_l
,
padH_r
,
padW_l
,
padW_r
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
)
+
tid
*
col_stride
);
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
)
+
tid
*
col_stride
);
// Second, gemm
// Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// Note that we accumulate into weight. We do so by setting beta = 0
...
@@ -519,7 +522,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
...
@@ -519,7 +522,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
}
}
// col2im back to the data
// col2im back to the data
col2im
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
)
+
tid
*
col_stride
,
nChannels
,
bottomHeight
,
bottomWidth
,
col2im
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
)
+
tid
*
col_stride
,
nChannels
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
dilH
,
dilW
,
padH
,
padW
,
kH
,
kW
,
dilH
,
dilW
,
padH
_l
,
padH_r
,
padW_l
,
padW_r
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
batch_bottom_stride
);
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
batch_bottom_stride
);
}
}
// Restore to previous blas threads
// Restore to previous blas threads
...
...
theano/tensor/nnet/corr.py
浏览文件 @
38e03e0f
...
@@ -66,20 +66,29 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -66,20 +66,29 @@ class BaseCorrMM(gof.OpenMPOp):
raise
ValueError
(
raise
ValueError
(
'invalid border_mode {}, which must be a '
'invalid border_mode {}, which must be a '
'non-negative integer'
.
format
(
border_mode
))
'non-negative integer'
.
format
(
border_mode
))
border_mode
=
(
border_mode
,
border_mode
)
border_mode
=
(
(
border_mode
,
border_mode
),)
*
2
if
isinstance
(
border_mode
,
tuple
):
el
if
isinstance
(
border_mode
,
tuple
):
if
len
(
border_mode
)
!=
2
or
border_mode
[
0
]
<
0
or
border_mode
[
1
]
<
0
:
if
len
(
border_mode
)
!=
2
:
raise
ValueError
(
raise
ValueError
(
'invalid border_mode {}, which must be a '
'invalid border_mode {} which must be a '
'pair of non-negative integers'
.
format
(
border_mode
))
'tuple of length 2'
.
format
(
border_mode
))
pad_h
,
pad_w
=
map
(
int
,
border_mode
)
border
=
()
border_mode
=
(
pad_h
,
pad_w
)
for
mode
in
border_mode
:
if
not
((
isinstance
(
border_mode
,
tuple
)
and
min
(
border_mode
)
>=
0
)
or
if
isinstance
(
mode
,
integer_types
)
and
mode
>
0
:
border_mode
in
(
'valid'
,
'full'
,
'half'
)):
border
+=
((
mode
,
mode
),)
elif
isinstance
(
mode
,
tuple
)
and
len
(
mode
)
==
2
and
\
min
(
mode
)
>=
0
:
border
=
((
mode
[
0
],
mode
[
1
]),)
else
:
raise
ValueError
(
'invalid border mode {}. The tuple can only contain '
'integers or tuples of length 2'
.
format
(
border_mode
))
border_mode
=
border
elif
border_mode
not
in
(
'valid'
,
'full'
,
'half'
):
raise
ValueError
(
raise
ValueError
(
'invalid border_mode {}, which must be either '
'invalid border_mode {}, which must be either '
'"valid", "full", "half", an integer or a
pair of
'
'"valid", "full", "half", an integer or a
tuple
'
'
integers
'
.
format
(
border_mode
))
'
of length 2
'
.
format
(
border_mode
))
self
.
border_mode
=
border_mode
self
.
border_mode
=
border_mode
if
len
(
subsample
)
!=
2
:
if
len
(
subsample
)
!=
2
:
raise
ValueError
(
"subsample must have two elements"
)
raise
ValueError
(
"subsample must have two elements"
)
...
@@ -110,14 +119,14 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -110,14 +119,14 @@ class BaseCorrMM(gof.OpenMPOp):
@property
@property
def
pad
(
self
):
def
pad
(
self
):
if
self
.
border_mode
==
"half"
:
if
self
.
border_mode
==
"half"
:
return
(
-
1
,
-
1
)
return
(
(
-
1
,
-
1
),)
*
2
elif
self
.
border_mode
==
"full"
:
elif
self
.
border_mode
==
"full"
:
return
(
-
2
,
-
2
)
return
(
(
-
2
,
-
2
),)
*
2
elif
isinstance
(
self
.
border_mode
,
tuple
):
elif
isinstance
(
self
.
border_mode
,
tuple
):
return
self
.
border_mode
return
self
.
border_mode
else
:
else
:
assert
self
.
border_mode
==
"valid"
assert
self
.
border_mode
==
"valid"
return
(
0
,
0
)
return
(
(
0
,
0
),)
*
2
# Direction should be converted to real enum value,
# Direction should be converted to real enum value,
# as it is compared to integer later in c_code_helper().
# as it is compared to integer later in c_code_helper().
...
@@ -129,8 +138,10 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -129,8 +138,10 @@ class BaseCorrMM(gof.OpenMPOp):
dilH
=
property
(
lambda
self
:
self
.
filter_dilation
[
0
])
dilH
=
property
(
lambda
self
:
self
.
filter_dilation
[
0
])
dilW
=
property
(
lambda
self
:
self
.
filter_dilation
[
1
])
dilW
=
property
(
lambda
self
:
self
.
filter_dilation
[
1
])
padH
=
property
(
lambda
self
:
self
.
pad
[
0
])
padH_l
=
property
(
lambda
self
:
self
.
pad
[
0
][
0
])
padW
=
property
(
lambda
self
:
self
.
pad
[
1
])
padH_r
=
property
(
lambda
self
:
self
.
pad
[
0
][
1
])
padW_l
=
property
(
lambda
self
:
self
.
pad
[
1
][
0
])
padW_r
=
property
(
lambda
self
:
self
.
pad
[
1
][
1
])
def
__str__
(
self
):
def
__str__
(
self
):
return
'
%
s{
%
s,
%
s,
%
s,
%
s
%
s}'
%
(
return
'
%
s{
%
s,
%
s,
%
s,
%
s
%
s}'
%
(
...
@@ -271,13 +282,13 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -271,13 +282,13 @@ class BaseCorrMM(gof.OpenMPOp):
if
height
:
if
height
:
height
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
height
height
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
height
else
:
else
:
if
((
self
.
direction
!=
0
)
and
(
self
.
dH
!=
1
))
or
((
self
.
direction
==
1
)
and
(
self
.
padH
==
-
1
)):
if
((
self
.
direction
!=
0
)
and
(
self
.
dH
!=
1
))
or
((
self
.
direction
==
1
)
and
(
self
.
padH
_l
==
-
1
)):
raise
ValueError
(
"height must be given for backprop with vertical sampling or border_mode='half'"
)
raise
ValueError
(
"height must be given for backprop with vertical sampling or border_mode='half'"
)
height
=
'-1'
height
=
'-1'
if
width
:
if
width
:
width
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
width
width
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
width
else
:
else
:
if
((
self
.
direction
!=
0
)
and
(
self
.
dW
!=
1
))
or
((
self
.
direction
==
1
)
and
(
self
.
padW
==
-
1
)):
if
((
self
.
direction
!=
0
)
and
(
self
.
dW
!=
1
))
or
((
self
.
direction
==
1
)
and
(
self
.
padW
_l
==
-
1
)):
raise
ValueError
(
"width must be given for backprop with horizontal sampling or border_mode='half'"
)
raise
ValueError
(
"width must be given for backprop with horizontal sampling or border_mode='half'"
)
width
=
'-1'
width
=
'-1'
...
@@ -290,8 +301,10 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -290,8 +301,10 @@ class BaseCorrMM(gof.OpenMPOp):
int dW =
%(params)
s->dW;
int dW =
%(params)
s->dW;
int dilH =
%(params)
s->dilH;
int dilH =
%(params)
s->dilH;
int dilW =
%(params)
s->dilW;
int dilW =
%(params)
s->dilW;
int padH =
%(params)
s->padH;
int padH_l =
%(params)
s->padH_l;
int padW =
%(params)
s->padW;
int padH_r =
%(params)
s->padH_r;
int padW_l =
%(params)
s->padW_l;
int padW_r =
%(params)
s->padW_r;
int numgroups =
%(params)
s->num_groups;
int numgroups =
%(params)
s->num_groups;
int unshared =
%(params)
s->unshared;
int unshared =
%(params)
s->unshared;
...
@@ -340,7 +353,7 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -340,7 +353,7 @@ class BaseCorrMM(gof.OpenMPOp):
}
}
else {
else {
// explicit padding, we can infer the kernel height
// explicit padding, we can infer the kernel height
kH = (PyArray_DIMS(bottom)[2] +
2*padH
- (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
kH = (PyArray_DIMS(bottom)[2] +
padH_l + padH_r
- (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
}
}
if (
%(width)
s != -1) {
if (
%(width)
s != -1) {
// kernel width is specified (perhaps horizontal subsampling or half padding)
// kernel width is specified (perhaps horizontal subsampling or half padding)
...
@@ -350,7 +363,7 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -350,7 +363,7 @@ class BaseCorrMM(gof.OpenMPOp):
kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
kW = (2 - PyArray_DIMS(bottom)[3] + (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
}
}
else {
else {
kW = (PyArray_DIMS(bottom)[3] +
2*padW
- (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
kW = (PyArray_DIMS(bottom)[3] +
padW_l + padW_r
- (PyArray_DIMS(top)[3] - 1) * dW - 1) / dilW + 1;
}
}
}
}
...
@@ -386,11 +399,11 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -386,11 +399,11 @@ class BaseCorrMM(gof.OpenMPOp):
switch(direction) {
switch(direction) {
case 0: // forward pass
case 0: // forward pass
// output is top: (batchsize, num_filters, height, width)
// output is top: (batchsize, num_filters, height, width)
// height and width: top = (bottom +
2*pad
- ((weight-1)*dil + 1)) / sample + 1
// height and width: top = (bottom +
pad_l + pad_r
- ((weight-1)*dil + 1)) / sample + 1
out_dim[0] = (npy_intp)PyArray_DIMS(bottom)[0];
out_dim[0] = (npy_intp)PyArray_DIMS(bottom)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] +
2*padH
- ((PyArray_DIMS(weights)[wdim-2]-1)*dilH + 1)) / dH + 1);
out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] +
padH_l + padH_r
- ((PyArray_DIMS(weights)[wdim-2]-1)*dilH + 1)) / dH + 1);
out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] +
2*padW
- ((PyArray_DIMS(weights)[wdim-1]-1)*dilW + 1)) / dW + 1);
out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] +
padW_l + padW_r
- ((PyArray_DIMS(weights)[wdim-1]-1)*dilW + 1)) / dW + 1);
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
{
{
if (unshared) {
if (unshared) {
...
@@ -564,7 +577,8 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -564,7 +577,8 @@ class BaseCorrMM(gof.OpenMPOp):
}
}
// Call corrMM code
// Call corrMM code
out2 = corrMM(
%(bottom)
s,
%(weights)
s,
%(top)
s, direction, dH, dW, dilH, dilW, padH, padW, numgroups, unshared);
out2 = corrMM(
%(bottom)
s,
%(weights)
s,
%(top)
s, direction, dH, dW, dilH, dilW,
padH_l, padH_r, padW_l, padW_r, numgroups, unshared);
if (out2==NULL){
if (out2==NULL){
%(fail)
s
%(fail)
s
}
}
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论