Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a668c6c5
提交
a668c6c5
authored
7月 01, 2016
作者:
Pascal Lamblin
提交者:
GitHub
7月 01, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4587 from niasla/dilated_convolution
Implementation of 2D dilated convolution/correlation.
上级
d78f44f6
2dcf3753
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
13 个修改的文件
包含
146 行增加
和
68 行删除
+146
-68
dnn.py
theano/gpuarray/dnn.py
+3
-0
blas.py
theano/sandbox/cuda/blas.py
+0
-0
corr_gemm.cu
theano/sandbox/cuda/corr_gemm.cu
+0
-0
dnn.py
theano/sandbox/cuda/dnn.py
+3
-0
opt.py
theano/sandbox/cuda/opt.py
+21
-9
test_abstractconv.py
theano/sandbox/cuda/tests/test_abstractconv.py
+22
-15
__init__.py
theano/tensor/nnet/__init__.py
+7
-2
abstract_conv.py
theano/tensor/nnet/abstract_conv.py
+0
-0
corr.py
theano/tensor/nnet/corr.py
+0
-0
corr_gemm.c
theano/tensor/nnet/corr_gemm.c
+44
-28
opt.py
theano/tensor/nnet/opt.py
+13
-6
test_abstract_conv.py
theano/tensor/nnet/tests/test_abstract_conv.py
+0
-0
test_corr.py
theano/tensor/nnet/tests/test_corr.py
+33
-8
没有找到文件。
theano/gpuarray/dnn.py
浏览文件 @
a668c6c5
...
...
@@ -1393,6 +1393,9 @@ def local_abstractconv_cudnn(node):
inp1
=
node
.
inputs
[
0
]
inp2
=
node
.
inputs
[
1
]
if
(
node
.
op
.
filter_dilation
!=
(
1
,
1
)):
return
None
if
not
isinstance
(
inp1
.
type
,
GpuArrayType
):
return
None
...
...
theano/sandbox/cuda/blas.py
浏览文件 @
a668c6c5
差异被折叠。
点击展开。
theano/sandbox/cuda/corr_gemm.cu
浏览文件 @
a668c6c5
差异被折叠。
点击展开。
theano/sandbox/cuda/dnn.py
浏览文件 @
a668c6c5
...
...
@@ -3116,6 +3116,8 @@ def local_abstractconv_cudnn(node):
AbstractConv2d_gradWeights
,
AbstractConv2d_gradInputs
))):
return
None
if
(
node
.
op
.
filter_dilation
!=
(
1
,
1
)):
return
None
inp1
=
node
.
inputs
[
0
]
inp2
=
node
.
inputs
[
1
]
...
...
@@ -3123,6 +3125,7 @@ def local_abstractconv_cudnn(node):
if
(
not
isinstance
(
inp1
.
type
,
CudaNdarrayType
)
or
not
isinstance
(
inp2
.
type
,
CudaNdarrayType
)):
return
None
if
not
dnn_available
():
return
None
...
...
theano/sandbox/cuda/opt.py
浏览文件 @
a668c6c5
...
...
@@ -1622,7 +1622,8 @@ def local_conv_gemm(node):
# because we are not allowed to replace a CudaNdarray with
# a DimShuffle instance in a graph optimization)
rval
=
theano
.
sandbox
.
cuda
.
as_cuda_ndarray_variable
(
GpuCorrMM_gradWeights
(
border_mode
,
subsample
)(
GpuCorrMM_gradWeights
(
border_mode
,
subsample
)(
gpu_contiguous
(
img
.
dimshuffle
(
1
,
0
,
2
,
3
)),
gpu_contiguous
(
kern
.
dimshuffle
(
1
,
0
,
2
,
3
))
)
.
dimshuffle
(
1
,
0
,
2
,
3
))
...
...
@@ -2769,28 +2770,33 @@ def local_abstractconv_gemm(node):
border_mode
=
node
.
op
.
border_mode
subsample
=
node
.
op
.
subsample
if
(
border_mode
==
'full'
)
and
(
subsample
==
(
1
,
1
)):
filter_dilation
=
node
.
op
.
filter_dilation
if
((
border_mode
==
'full'
)
and
(
subsample
==
(
1
,
1
))):
if
not
node
.
op
.
filter_flip
:
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
# need to dimshuffle the kernel for full convolution
kern
=
kern
.
dimshuffle
(
1
,
0
,
2
,
3
)
# call GpuCorrMM_gradInputs
rval
=
GpuCorrMM_gradInputs
(
'valid'
,
subsample
)(
rval
=
GpuCorrMM_gradInputs
(
'valid'
,
subsample
,
filter_dilation
)(
gpu_contiguous
(
kern
),
gpu_contiguous
(
img
))
else
:
# need to flip the kernel if necessary
if
node
.
op
.
filter_flip
:
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
# By default use GpuCorrMM
rval
=
GpuCorrMM
(
border_mode
,
subsample
)(
gpu_contiguous
(
img
),
gpu_contiguous
(
kern
))
rval
=
GpuCorrMM
(
border_mode
,
subsample
,
filter_dilation
)(
gpu_contiguous
(
img
),
gpu_contiguous
(
kern
))
# call GpuCorrMM_gradWeights if good
# (the latter is faster if batchsize * kernelHeight * kernelWidth
# is larger than inputChannels * outputHeight * outputWidth.
# GpuConv does not always store information on the batchsize and
# channels, though, so we only use what information we have.)
if
((
subsample
==
(
1
,
1
))
and
if
((
subsample
==
(
1
,
1
))
and
(
filter_dilation
==
(
1
,
1
))
and
(
node
.
op
.
imshp
is
not
None
)
and
(
None
not
in
node
.
op
.
imshp
[
-
2
:])
and
(
node
.
op
.
kshp
is
not
None
)
and
...
...
@@ -2810,7 +2816,9 @@ def local_abstractconv_gemm(node):
# because we are not allowed to replace a CudaNdarray with
# a DimShuffle instance in a graph optimization)
rval
=
theano
.
sandbox
.
cuda
.
as_cuda_ndarray_variable
(
GpuCorrMM_gradWeights
(
border_mode
,
subsample
)(
GpuCorrMM_gradWeights
(
border_mode
,
subsample
,
filter_dilation
)(
gpu_contiguous
(
img
.
dimshuffle
(
1
,
0
,
2
,
3
)),
gpu_contiguous
(
kern
.
dimshuffle
(
1
,
0
,
2
,
3
))
)
.
dimshuffle
(
1
,
0
,
2
,
3
))
...
...
@@ -2827,7 +2835,8 @@ def local_abstractconv_gradweight_gemm(node):
return
None
rval
=
GpuCorrMM_gradWeights
(
border_mode
=
node
.
op
.
border_mode
,
subsample
=
node
.
op
.
subsample
)(
subsample
=
node
.
op
.
subsample
,
filter_dilation
=
node
.
op
.
filter_dilation
)(
gpu_contiguous
(
img
),
gpu_contiguous
(
topgrad
),
shape
)
if
node
.
op
.
filter_flip
:
rval
=
rval
[:,
:,
::
-
1
,
::
-
1
]
...
...
@@ -2849,7 +2858,8 @@ def local_abstractconv_gradinputs_gemm(node):
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
rval
=
GpuCorrMM_gradInputs
(
border_mode
=
node
.
op
.
border_mode
,
subsample
=
node
.
op
.
subsample
)(
subsample
=
node
.
op
.
subsample
,
filter_dilation
=
node
.
op
.
filter_dilation
)(
gpu_contiguous
(
kern
),
gpu_contiguous
(
topgrad
),
shape
)
return
[
rval
]
...
...
@@ -2870,10 +2880,12 @@ conv_groupopt.register('local_abstractconv_dnn',
conv_groupopt
.
register
(
'local_abstractconv_gemm'
,
local_abstractconv_gemm
,
30
,
'conv_gemm'
,
'gpu'
,
'fast_compile'
,
'fast_run'
)
conv_groupopt
.
register
(
'local_abstractconv_gradweight_gemm'
,
local_abstractconv_gradweight_gemm
,
30
,
'conv_gemm'
,
'gpu'
,
'fast_compile'
,
'fast_run'
)
conv_groupopt
.
register
(
'local_abstractconv_gradinputs_gemm'
,
local_abstractconv_gradinputs_gemm
,
30
,
'conv_gemm'
,
...
...
theano/sandbox/cuda/tests/test_abstractconv.py
浏览文件 @
a668c6c5
...
...
@@ -29,25 +29,30 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
self
.
provide_shape
=
[
False
]
self
.
shared
=
gpu_shared
def
tcase
(
self
,
i
,
f
,
s
,
b
,
flip
,
provide_shape
):
def
tcase
(
self
,
i
,
f
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
)):
if
fd
!=
(
1
,
1
):
raise
SkipTest
(
"No dilation implementation for cuDNN ConvOp."
)
if
not
dnn_available
():
raise
SkipTest
(
cuda
.
dnn
.
dnn_available
.
msg
)
mode
=
mode_with_gpu
o
=
self
.
get_output_shape
(
i
,
f
,
s
,
b
)
o
=
self
.
get_output_shape
(
i
,
f
,
s
,
b
,
fd
)
self
.
run_fwd
(
inputs_shape
=
i
,
filters_shape
=
f
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConv
)
filter_flip
=
flip
,
target_op
=
GpuDnnConv
,
filter_dilation
=
fd
)
self
.
run_gradweight
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradW
)
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradW
,
filter_dilation
=
fd
)
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
)
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
filter_dilation
=
fd
)
class
TestCorrMMConv2d
(
test_abstract_conv
.
BaseTestConv2d
):
...
...
@@ -56,28 +61,30 @@ class TestCorrMMConv2d(test_abstract_conv.BaseTestConv2d):
self
.
shared
=
gpu_shared
self
.
mode
=
mode_with_gpu
.
excluding
(
'cudnn'
)
def
tcase
(
self
,
i
,
f
,
s
,
b
,
flip
,
provide_shape
):
def
tcase
(
self
,
i
,
f
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
)
):
mode
=
self
.
mode
o
=
self
.
get_output_shape
(
i
,
f
,
s
,
b
)
self
.
run_fwd
(
inputs_shape
=
i
,
filters_shape
=
f
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
o
=
self
.
get_output_shape
(
i
,
f
,
s
,
b
,
fd
)
self
.
run_fwd
(
inputs_shape
=
i
,
filters_shape
=
f
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
(
GpuCorrMM
,
GpuCorrMM_gradWeights
,
GpuCorrMM_gradInputs
)
)
filter_flip
=
flip
,
target_op
=
(
GpuCorrMM
,
GpuCorrMM_gradWeights
,
GpuCorrMM_gradInputs
)
,
filter_dilation
=
fd
)
self
.
run_gradweight
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorrMM_gradWeights
)
target_op
=
GpuCorrMM_gradWeights
,
filter_dilation
=
fd
)
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorrMM_gradInputs
)
target_op
=
GpuCorrMM_gradInputs
,
filter_dilation
=
fd
)
class
TestDnnConvTypes
(
test_abstract_conv
.
TestConvTypes
):
...
...
theano/tensor/nnet/__init__.py
浏览文件 @
a668c6c5
...
...
@@ -35,7 +35,7 @@ from .abstract_conv import conv2d as abstract_conv2d
def
conv2d
(
input
,
filters
,
input_shape
=
None
,
filter_shape
=
None
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
filter_flip
=
True
,
image_shape
=
None
,
**
kwargs
):
image_shape
=
None
,
filter_dilation
=
(
1
,
1
),
**
kwargs
):
"""
This function will build the symbolic graph for convolving a mini-batch of a
stack of 2D inputs with a set of 2D filters. The implementation is modelled
...
...
@@ -95,6 +95,10 @@ def conv2d(input, filters, input_shape=None, filter_shape=None,
image_shape: None, tuple/list of len 4 of int or Constant variable
Deprecated alias for input_shape.
filter_dilation: tuple of len 2
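Factor by which to dilate the filters, i.e. the spacing between
adjacent filter elements when the filter is applied to the input
((1, 1) means no dilation). This is not the output stride; see
subsample for that. Also called dilation elsewhere.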
kwargs: Any other keyword arguments are accepted for backwards
compatibility, but will be ignored.
...
...
@@ -140,4 +144,5 @@ def conv2d(input, filters, input_shape=None, filter_shape=None,
" be provided at the same time."
)
return
abstract_conv2d
(
input
,
filters
,
input_shape
,
filter_shape
,
border_mode
,
subsample
,
filter_flip
)
border_mode
,
subsample
,
filter_flip
,
filter_dilation
)
theano/tensor/nnet/abstract_conv.py
浏览文件 @
a668c6c5
差异被折叠。
点击展开。
theano/tensor/nnet/corr.py
浏览文件 @
a668c6c5
差异被折叠。
点击展开。
theano/tensor/nnet/corr_gemm.c
浏览文件 @
a668c6c5
...
...
@@ -6,13 +6,13 @@ Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
...
...
@@ -31,20 +31,24 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Loops for fast unfold + copy
void
im2col
(
const
%
(
float_type
)
s
*
data_im
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
kernel_h
,
const
int
kernel_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
%
(
float_type
)
s
*
data_col
)
{
int
height_col
=
(
height
+
2
*
pad_h
-
kernel_h
)
/
stride_h
+
1
;
int
width_col
=
(
width
+
2
*
pad_w
-
kernel_w
)
/
stride_w
+
1
;
// Implicit dilated kernel size
int
dil_kernel_h
=
(
kernel_h
-
1
)
*
dilation_h
+
1
;
int
dil_kernel_w
=
(
kernel_w
-
1
)
*
dilation_w
+
1
;
int
height_col
=
(
height
+
2
*
pad_h
-
dil_kernel_h
)
/
stride_h
+
1
;
int
width_col
=
(
width
+
2
*
pad_w
-
dil_kernel_w
)
/
stride_w
+
1
;
int
channels_col
=
channels
*
kernel_h
*
kernel_w
;
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
int
w_offset
=
c
%%
kernel_w
;
int
h_offset
=
(
c
/
kernel_w
)
%%
kernel_h
;
int
c_im
=
c
/
kernel_h
/
kernel_w
;
for
(
int
h
=
0
;
h
<
height_col
;
++
h
)
{
int
h_pad
=
h
*
stride_h
-
pad_h
+
h_offset
*
dilation_h
;
for
(
int
w
=
0
;
w
<
width_col
;
++
w
)
{
int
h_pad
=
h
*
stride_h
-
pad_h
+
h_offset
;
int
w_pad
=
w
*
stride_w
-
pad_w
+
w_offset
;
int
w_pad
=
w
*
stride_w
-
pad_w
+
w_offset
*
dilation_w
;
if
(
h_pad
>=
0
&&
h_pad
<
height
&&
w_pad
>=
0
&&
w_pad
<
width
)
data_col
[(
npy_intp
)(
c
*
height_col
+
h
)
*
width_col
+
w
]
=
data_im
[(
npy_intp
)(
c_im
*
height
+
h_pad
)
*
width
+
w_pad
];
...
...
@@ -60,10 +64,14 @@ void im2col(const %(float_type)s* data_im, const int channels,
// accumulated into data_im.
void
col2im
(
const
%
(
float_type
)
s
*
data_col
,
const
int
channels
,
const
int
height
,
const
int
width
,
const
int
patch_h
,
const
int
patch_w
,
const
int
dilation_h
,
const
int
dilation_w
,
const
int
pad_h
,
const
int
pad_w
,
const
int
stride_h
,
const
int
stride_w
,
%
(
float_type
)
s
*
data_im
)
{
int
height_col
=
(
height
+
2
*
pad_h
-
patch_h
)
/
stride_h
+
1
;
int
width_col
=
(
width
+
2
*
pad_w
-
patch_w
)
/
stride_w
+
1
;
// Implicit dilated patch
int
dil_patch_h
=
(
patch_h
-
1
)
*
dilation_h
+
1
;
int
dil_patch_w
=
(
patch_w
-
1
)
*
dilation_w
+
1
;
int
height_col
=
(
height
+
2
*
pad_h
-
dil_patch_h
)
/
stride_h
+
1
;
int
width_col
=
(
width
+
2
*
pad_w
-
dil_patch_w
)
/
stride_w
+
1
;
int
num_kernels
=
channels
*
height
*
width
;
int
channels_col
=
channels
*
patch_h
*
patch_w
;
for
(
int
c
=
0
;
c
<
channels_col
;
++
c
)
{
...
...
@@ -71,9 +79,9 @@ void col2im(const %(float_type)s* data_col, const int channels,
int
h_offset
=
(
c
/
patch_w
)
%%
patch_h
;
int
c_im
=
c
/
patch_h
/
patch_w
;
for
(
int
h
=
0
;
h
<
height_col
;
++
h
)
{
int
h_pad
=
h
*
stride_h
-
pad_h
+
h_offset
*
dilation_h
;
for
(
int
w
=
0
;
w
<
width_col
;
++
w
)
{
int
h_pad
=
h
*
stride_h
-
pad_h
+
h_offset
;
int
w_pad
=
w
*
stride_w
-
pad_w
+
w_offset
;
int
w_pad
=
w
*
stride_w
-
pad_w
+
w_offset
*
dilation_w
;
if
(
h_pad
>=
0
&&
h_pad
<
height
&&
w_pad
>=
0
&&
w_pad
<
width
)
data_im
[(
npy_intp
)(
c_im
*
height
+
h_pad
)
*
width
+
w_pad
]
+=
data_col
[(
npy_intp
)(
c
*
height_col
+
h
)
*
width_col
+
w
];
...
...
@@ -91,13 +99,15 @@ void col2im(const %(float_type)s* data_col, const int channels,
// CPU version author: Jesse Livezey
// CPU version adapted from GPU version
PyArrayObject
*
corrMM
(
PyArrayObject
*
bottom
,
PyArrayObject
*
weight
,
PyArrayObject
*
top
,
const
int
direction
,
const
int
dH
=
1
,
const
int
dW
=
1
,
const
int
padH
=
0
,
const
int
padW
=
0
)
PyArrayObject
*
weight
,
PyArrayObject
*
top
,
const
int
direction
,
const
int
dH
=
1
,
const
int
dW
=
1
,
const
int
dilH
=
1
,
const
int
dilW
=
1
,
const
int
padH
=
0
,
const
int
padW
=
0
)
{
if
(
PyArray_NDIM
(
bottom
)
!=
4
)
{
...
...
@@ -109,7 +119,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
PyErr_SetString
(
PyExc_ValueError
,
"CorrMM received bottom with wrong type."
);
return
NULL
;
}
if
(
PyArray_NDIM
(
weight
)
!=
4
)
{
PyErr_SetString
(
PyExc_ValueError
,
"CorrMM requires weight of 4D"
);
...
...
@@ -151,9 +161,12 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
"CorrMM images and kernel must have the same stack size
\n
"
);
return
NULL
;
}
// implicit dilated filter
const
int
dil_kH
=
(
kH
-
1
)
*
dilH
+
1
;
const
int
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
// top: (batchSize, nFilters, topHeight, topWidth)
const
int
topHeight
=
(
bottomHeight
+
2
*
padH
-
kH
)
/
dH
+
1
;
const
int
topWidth
=
(
bottomWidth
+
2
*
padW
-
kW
)
/
dW
+
1
;
const
int
topHeight
=
(
bottomHeight
+
2
*
padH
-
dil_
kH
)
/
dH
+
1
;
const
int
topWidth
=
(
bottomWidth
+
2
*
padW
-
dil_
kW
)
/
dW
+
1
;
if
(
batchSize
!=
PyArray_DIMS
(
top
)[
0
]
||
nFilters
!=
PyArray_DIMS
(
top
)[
1
]
||
topHeight
!=
PyArray_DIMS
(
top
)[
2
]
||
...
...
@@ -176,9 +189,9 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
col_dim
[
0
]
=
(
npy_intp
)(
nChannels
*
kW
*
kH
);
col_dim
[
1
]
=
(
npy_intp
)(
topHeight
*
topWidth
);
PyArrayObject
*
col
=
(
PyArrayObject
*
)
PyArray_EMPTY
(
2
,
col_dim
,
PyArray_TYPE
(
top
),
0
);
col_dim
,
PyArray_TYPE
(
top
),
0
);
if
(
NULL
==
col
)
{
PyErr_Format
(
PyExc_RuntimeError
,
...
...
@@ -206,7 +219,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
for
(
int
n
=
0
;
n
<
batchSize
;
n
++
)
{
// First, im2col
im2col
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
bottom_stride
,
nChannels
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
padH
,
padW
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
));
bottomWidth
,
kH
,
kW
,
dilH
,
dilW
,
padH
,
padW
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
));
// Second, gemm
%
(
gemm
)
s
(
&
NTrans
,
&
NTrans
,
&
N_
,
&
M_
,
&
K_
,
...
...
@@ -255,7 +269,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
for
(
int
n
=
0
;
n
<
batchSize
;
n
++
)
{
// First, im2col
im2col
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
bottom_stride
,
nChannels
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
padH
,
padW
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
));
bottomWidth
,
kH
,
kW
,
dilH
,
dilW
,
padH
,
padW
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
));
// Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This
...
...
@@ -299,7 +314,7 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
}
else
if
(
direction
==
2
)
{
// backprop wrt. inputs
output
=
bottom
;
// bottom is set to zero here rather than inside of col2im
// bottom is set to zero here rather than inside of col2im
PyArray_FILLWBYTE
(
bottom
,
0
);
// full convolution: gemm, then col2im
// Iterate over batch
...
...
@@ -314,7 +329,8 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
),
&
N_
);
// col2im back to the data
col2im
((
%
(
float_type
)
s
*
)
PyArray_DATA
(
col
),
nChannels
,
bottomHeight
,
bottomWidth
,
kH
,
kW
,
padH
,
padW
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
bottom_stride
);
kH
,
kW
,
dilH
,
dilW
,
padH
,
padW
,
dH
,
dW
,
(
%
(
float_type
)
s
*
)
PyArray_DATA
(
bottom
)
+
n
*
bottom_stride
);
}
/*
// Original caffe code for comparison
...
...
theano/tensor/nnet/opt.py
浏览文件 @
a668c6c5
...
...
@@ -79,7 +79,8 @@ def local_abstractconv_gemm(node):
if
node
.
op
.
filter_flip
:
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
rval
=
CorrMM
(
border_mode
=
node
.
op
.
border_mode
,
subsample
=
node
.
op
.
subsample
)(
img
,
kern
)
subsample
=
node
.
op
.
subsample
,
filter_dilation
=
node
.
op
.
filter_dilation
)(
img
,
kern
)
copy_stack_trace
(
node
.
outputs
[
0
],
rval
)
return
[
rval
]
...
...
@@ -97,7 +98,8 @@ def local_abstractconv_gradweight_gemm(node):
return
None
rval
=
CorrMM_gradWeights
(
border_mode
=
node
.
op
.
border_mode
,
subsample
=
node
.
op
.
subsample
)(
img
,
topgrad
,
shape
)
subsample
=
node
.
op
.
subsample
,
filter_dilation
=
node
.
op
.
filter_dilation
)(
img
,
topgrad
,
shape
)
copy_stack_trace
(
node
.
outputs
[
0
],
rval
)
# need to flip the kernel if necessary
...
...
@@ -124,8 +126,9 @@ def local_abstractconv_gradinputs_gemm(node):
if
node
.
op
.
filter_flip
:
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
rval
=
CorrMM_gradInputs
(
border_mode
=
node
.
op
.
border_mode
,
subsample
=
node
.
op
.
subsample
)(
kern
,
topgrad
,
shape
)
subsample
=
node
.
op
.
subsample
,
filter_dilation
=
node
.
op
.
filter_dilation
)(
kern
,
topgrad
,
shape
)
copy_stack_trace
(
node
.
outputs
[
0
],
rval
)
return
[
rval
]
...
...
@@ -221,7 +224,9 @@ def local_conv2d_gradweight_cpu(node):
assert
len
(
op_imshp
)
==
4
and
len
(
op_kshp
)
==
4
outshp
=
get_conv_output_shape
(
op_imshp
,
op_kshp
,
node
.
op
.
border_mode
,
node
.
op
.
subsample
)[
2
:]
node
.
op
.
border_mode
,
node
.
op
.
subsample
,
node
.
op
.
filter_dilation
)[
2
:]
fulloutshp
=
get_conv_output_shape
(
op_imshp
,
op_kshp
,
node
.
op
.
border_mode
,
(
1
,
1
))[
2
:]
...
...
@@ -334,7 +339,9 @@ def local_conv2d_gradinputs_cpu(node):
filters
=
filters
[:,
:,
::
-
1
,
::
-
1
]
outshp
=
get_conv_output_shape
(
op_imshp
,
op_kshp
,
node
.
op
.
border_mode
,
node
.
op
.
subsample
)[
2
:]
node
.
op
.
border_mode
,
node
.
op
.
subsample
,
node
.
op
.
filter_dilation
)[
2
:]
fulloutshp
=
get_conv_output_shape
(
op_imshp
,
op_kshp
,
node
.
op
.
border_mode
,
(
1
,
1
))[
2
:]
...
...
theano/tensor/nnet/tests/test_abstract_conv.py
浏览文件 @
a668c6c5
差异被折叠。
点击展开。
theano/tensor/nnet/tests/test_corr.py
浏览文件 @
a668c6c5
...
...
@@ -32,8 +32,8 @@ class TestCorr2D(utt.InferShapeTester):
def
validate
(
self
,
image_shape
,
filter_shape
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
input
=
None
,
filters
=
None
,
verify_grad
=
True
,
non_contiguous
=
False
):
input
=
None
,
filters
=
None
,
verify_grad
=
True
,
non_contiguous
=
False
,
filter_dilation
=
(
1
,
1
)
):
"""
:param image_shape: The constant shape info passed to corrMM.
:param filter_shape: The constant shape info passed to corrMM.
...
...
@@ -55,7 +55,8 @@ class TestCorr2D(utt.InferShapeTester):
# define theano graph and function
input
.
name
=
'input'
filters
.
name
=
'filters'
rval
=
corr
.
CorrMM
(
border_mode
,
subsample
)(
input
,
filters
)
rval
=
corr
.
CorrMM
(
border_mode
,
subsample
,
filter_dilation
)(
input
,
filters
)
rval
.
name
=
'corr_output'
return
rval
...
...
@@ -86,20 +87,22 @@ class TestCorr2D(utt.InferShapeTester):
orig_image_data
=
image_data
img_shape2d
=
numpy
.
array
(
N_image_shape
[
-
2
:])
fil_shape2d
=
numpy
.
array
(
N_filter_shape
[
-
2
:])
dil_shape2d
=
numpy
.
array
(
filter_dilation
)
dil_fil_shape2d
=
(
fil_shape2d
-
1
)
*
dil_shape2d
+
1
subsample2d
=
numpy
.
array
(
subsample
)
if
border_mode
==
'full'
:
padHW
=
(
fil_shape2d
-
1
)
padHW
=
(
dil_
fil_shape2d
-
1
)
elif
border_mode
==
'valid'
:
padHW
=
numpy
.
array
([
0
,
0
])
elif
border_mode
==
'half'
:
padHW
=
numpy
.
floor
(
fil_shape2d
/
2
)
.
astype
(
'int32'
)
padHW
=
numpy
.
floor
(
dil_
fil_shape2d
/
2
)
.
astype
(
'int32'
)
elif
isinstance
(
border_mode
,
tuple
):
padHW
=
numpy
.
array
(
border_mode
)
elif
isinstance
(
border_mode
,
integer_types
):
padHW
=
numpy
.
array
([
border_mode
,
border_mode
])
else
:
raise
NotImplementedError
(
'Unsupported border_mode {}'
.
format
(
border_mode
))
out_shape2d
=
numpy
.
floor
((
img_shape2d
+
2
*
(
padHW
)
-
fil_shape2d
)
/
subsample2d
)
+
1
out_shape2d
=
numpy
.
floor
((
img_shape2d
+
2
*
(
padHW
)
-
dil_
fil_shape2d
)
/
subsample2d
)
+
1
# avoid numpy deprecation
out_shape2d
=
out_shape2d
.
astype
(
'int32'
)
out_shape
=
(
N_image_shape
[
0
],
N_filter_shape
[
0
])
+
tuple
(
out_shape2d
)
...
...
@@ -124,8 +127,8 @@ class TestCorr2D(utt.InferShapeTester):
for
col
in
range
(
ref_output
.
shape
[
3
]):
icol
=
col
*
subsample
[
1
]
# image col
ref_output
[
bb
,
nn
,
row
,
col
]
+=
(
image2d
[
irow
:
irow
+
N_filter_shape
[
2
],
icol
:
icol
+
N_filter_shape
[
3
]]
*
filter2d
[::
-
1
,
::
-
1
]
irow
:
irow
+
dil_fil_shape2d
[
0
]:
filter_dilation
[
0
],
icol
:
icol
+
dil_fil_shape2d
[
1
]:
filter_dilation
[
1
]]
*
filter2d
[::
-
1
,
::
-
1
]
)
.
sum
()
self
.
assertTrue
(
_allclose
(
theano_output
,
ref_output
))
...
...
@@ -186,6 +189,28 @@ class TestCorr2D(utt.InferShapeTester):
self
.
validate
((
1
,
1
,
6
,
6
),
(
1
,
1
,
3
,
3
),
1
,
subsample
=
(
3
,
3
))
def
test_filter_dilation
(
self
):
"""
Tests correlation where filter dilation != (1,1)
"""
self
.
validate
((
3
,
2
,
7
,
5
),
(
5
,
2
,
2
,
3
),
'valid'
,
filter_dilation
=
(
2
,
2
))
self
.
validate
((
3
,
2
,
14
,
10
),
(
5
,
2
,
2
,
3
),
'valid'
,
filter_dilation
=
(
3
,
1
))
self
.
validate
((
1
,
1
,
14
,
14
),
(
1
,
1
,
3
,
3
),
'valid'
,
filter_dilation
=
(
2
,
3
))
self
.
validate
((
3
,
2
,
7
,
5
),
(
5
,
2
,
2
,
3
),
'full'
,
filter_dilation
=
(
2
,
2
))
self
.
validate
((
3
,
2
,
7
,
5
),
(
5
,
2
,
2
,
3
),
'full'
,
filter_dilation
=
(
3
,
1
))
self
.
validate
((
1
,
1
,
6
,
6
),
(
1
,
1
,
3
,
3
),
'full'
,
filter_dilation
=
(
2
,
3
))
self
.
validate
((
3
,
2
,
7
,
5
),
(
5
,
2
,
2
,
3
),
'half'
,
filter_dilation
=
(
2
,
2
))
self
.
validate
((
3
,
2
,
7
,
5
),
(
5
,
2
,
2
,
3
),
'half'
,
filter_dilation
=
(
3
,
1
))
self
.
validate
((
1
,
1
,
6
,
6
),
(
1
,
1
,
3
,
3
),
'half'
,
filter_dilation
=
(
2
,
3
))
self
.
validate
((
3
,
2
,
7
,
5
),
(
5
,
2
,
2
,
3
),
(
1
,
1
),
filter_dilation
=
(
2
,
2
))
self
.
validate
((
3
,
2
,
7
,
5
),
(
5
,
2
,
2
,
3
),
(
2
,
1
),
filter_dilation
=
(
2
,
1
))
self
.
validate
((
1
,
1
,
6
,
6
),
(
1
,
1
,
3
,
3
),
(
1
,
2
),
filter_dilation
=
(
1
,
2
))
self
.
validate
((
1
,
1
,
6
,
6
),
(
1
,
1
,
3
,
3
),
1
,
subsample
=
(
3
,
3
),
filter_dilation
=
(
2
,
2
))
@attr
(
'slow'
)
def
test_shape_Constant_tensor
(
self
):
"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论