Commit cfc493d1, authored Sep 01, 2014 by Frédéric Bastien
Merge pull request #2033 from f0k/corrmm-faster-fullconv
Faster algorithms and gradients for GpuCorrMM
Parents: a81b5cdc, 372bab54
Showing 6 changed files with 1062 additions and 360 deletions (+1062 -360)
doc/library/tensor/nnet/conv.txt (+68 -45)
theano/sandbox/cuda/blas.py (+395 -91)
theano/sandbox/cuda/caffe_common.hpp (+0 -47)
theano/sandbox/cuda/conv_gemm.cu (+414 -129)
theano/sandbox/cuda/opt.py (+72 -14)
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py (+113 -34)
doc/library/tensor/nnet/conv.txt
@@ -22,23 +22,28 @@
.. moduleauthor:: LISA
-TODO: Give examples for how to use these things! They are pretty complicated.
+TODO: Give examples on how to use these things! They are pretty complicated.
- - Conv implemented
-    - :func:`signal.conv2d <theano.tensor.signal.conv.conv2d>`.
+ - Convolution operators implemented:
+    - :func:`signal.conv2d <theano.tensor.signal.conv.conv2d>`.
       See note above.
     - :func:`nnet.conv2d <theano.tensor.nnet.conv.conv2d>`.
+      This is the standard operator for convolutional neural networks working
+      with batches of multi-channel 2D images, available for CPU and GPU.
+      Most of the more efficient GPU implementations listed below can be used
+      as an automatic replacement for nnet.conv2d by enabling specific graph
+      optimizations.
     - :func:`conv2d_fft <theano.sandbox.cuda.fftconv.conv2d_fft>`
       This is a GPU-only version of nnet.conv2d that uses an FFT transform
-      to perform the work. conv2d_fft should not be used directly as it
-      does not implement a grad function. Instead, you should use
-      nnet.conv2d and enable the fft optimization by setting
-      'THEANO_FLAGS=optimizer_including=conv_fft_valid:conv_fft_full'
+      to perform the work. conv2d_fft should not be called directly as it
+      does not provide a gradient. Instead, use nnet.conv2d and allow
+      Theano's graph optimizer to replace it by the FFT version by setting
+      ``THEANO_FLAGS=optimizer_including=conv_fft_valid:conv_fft_full``
       in your environement. This is not enabled by default because it
-      has some restrictions on input and uses more memory. Also note
+      has some restrictions on input and uses a lot more memory. Also note
       that it requires CUDA >= 5.0, scikits.cuda >= 0.5.0 and PyCUDA to run.
-      To desactivate the fft optimization on a specific nnet.conv2d
-      while the optimization flags are active, you can set its parameters
-      version to 'no_fft'. To enable for just one Theano function:
+      To deactivate the FFT optimization on a specific nnet.conv2d
+      while the optimization flags are active, you can set its ``version``
+      parameter to ``'no_fft'``. To enable it for just one Theano function:
.. code-block:: python
...
...
@@ -47,17 +52,58 @@ TODO: Give examples for how to use these things! They are pretty complicated.
f = theano.function(..., mode=mode)
+    - `cuda-convnet wrapper for 2d correlation <http://deeplearning.net/software/pylearn2/library/alex.html>`_
+      Wrapper for an open-source GPU-only implementation of conv2d by Alex
+      Krizhevsky, very fast, but with several restrictions on input and kernel
+      shapes, and with a different memory layout for the input.
+      This is in Pylearn2, where it is normally called from the `linear transform
+      <http://deeplearning.net/software/pylearn2/library/linear.html>`_
+      implementation, but it can also be used `directly from within Theano
+      <http://benanne.github.io/2014/04/03/faster-convolutions-in-theano.html>`_
+      as a manual replacement for nnet.conv2d.
+    - :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`
+      This is a GPU-only 2d correlation implementation taken from
+      `caffe <https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu>`_
+      and also used by Torch.
+      For each element in a batch, it first creates a
+      `Toeplitz <http://en.wikipedia.org/wiki/Toeplitz_matrix>`_ matrix in a CUDA kernel.
+      Then, it performs a ``gemm`` call to multiply this Toeplitz matrix and the filters
+      (hence the name: MM is for matrix multiplication).
+      It needs extra memory for the Toeplitz matrix, which is a 2D matrix of shape
+      ``(no of channels * filter width * filter height, output width * output height)``.
+      As it provides a gradient, you can use it as a replacement for nnet.conv2d.
+      Alternatively, you can use nnet.conv2d and allow Theano's graph optimizer
+      to replace it by the GEMM version by setting
+      ``THEANO_FLAGS=optimizer_including=conv_gemm`` in your environment.
+      This is not enabled by default because it uses some extra memory, but the
+      overhead is small compared to conv2d_fft, there are no restrictions on
+      input or kernel shapes and it is sometimes still faster than cuda-convnet.
+      If using it, please see the warning about a bug in CUDA 5.0 to 6.0 below.
+      To enable it for just one Theano function:
+      .. code-block:: python
+          mode = theano.compile.get_default_mode()
+          mode = mode.including('conv_gemm')
+          f = theano.function(..., mode=mode)
     - :func:`conv3D <theano.tensor.nnet.Conv3D.conv3D>`
-      3D Convolution. Doesn't work on the GPU.
+      3D Convolution applying multi-channel 3D filters to batches of
+      multi-channel 3D images.
     - :func:`conv3d_fft <theano.sandbox.cuda.fftconv.conv3d_fft>`
       GPU-only version of conv3D using FFT transform. conv3d_fft should
-      not be call directly as it does not implement a grad function.
-      You can enable it by setting THEANO_FLAGS to
-      'optimizer_including=conv3d_fft:convgrad3d_fft:convtransp3d_fft'
-      It does not support strides.
-      This is not enabled by default because it uses more memory.
-      Also note that it requires CUDA >= 5.0, scikits.cuda >= 0.5.0 and PyCUDA to run.
+      not be called directly as it does not provide a gradient.
+      Instead, use conv3D and allow Theano's graph optimizer to replace it by
+      the FFT version by setting
+      ``THEANO_FLAGS=optimizer_including=conv3d_fft:convgrad3d_fft:convtransp3d_fft``
+      in your environment. This is not enabled by default because it does not
+      support strides and uses more memory. Also note that it requires
+      CUDA >= 5.0, scikits.cuda >= 0.5.0 and PyCUDA to run.
To enable for just one Theano function:
.. code-block:: python
...
...
@@ -70,33 +116,10 @@ TODO: Give examples for how to use these things! They are pretty complicated.
- :func:`conv3d2d <theano.tensor.nnet.conv3d2d.conv3d>`
Another conv3d implementation that uses the conv2d with data reshaping.
It is faster in some cases than conv3d, specifically on the GPU.
-    - `Faster conv2d <http://deeplearning.net/software/pylearn2/library/alex.html>`_
-      This is in Pylearn2, not very documented and uses a different
-      memory layout for the input. It is important to have the input
-      in the native memory layout, and not use dimshuffle on the
-      inputs, otherwise you lose most of the speed up. So this is not
-      a drop in replacement of conv2d.
-      Normally those are called from the `linear transform
-      <http://deeplearning.net/software/pylearn2/library/linear.html>`_
-      implementation.
-      Also, there is restrictions on which shape are supported.
-    - :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`
-      This is a GPU-only version of a correlation that computes correlations
-      as `caffe <https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu>`_.
-      For each element in a batch, it first creates a
-      `Toeplitz <http://en.wikipedia.org/wiki/Toeplitz_matrix>`_ matrix in a cuda kernel.
-      Then, it performs a ``gemm`` call to multiply this Toeplitz matrix and the kernel.
-      It need extra memory equal to the size of the Toeplitz matrix. Precisely,
-      the dimensions of this 2D Toeplitz matrix is equal to
-      ``(no of channels * filter width * filter height, output width * output height)``.
-      You can enable it for call to conv2d 2d by setting ``THEANO_FLAGS=optimizer_including=conv_gemm``
-      in your environment. This is not enabled by default because it
-      uses some extra memory. MM mean matrix multiply.
 .. autofunction:: theano.tensor.nnet.conv.conv2d
+.. autofunction:: theano.sandbox.cuda.fftconv.conv2d_fft
+.. autofunction:: theano.sandbox.cuda.blas.GpuCorrMM
 .. autofunction:: theano.tensor.nnet.Conv3D.conv3D
+.. autofunction:: theano.sandbox.cuda.fftconv.conv3d_fft
 .. autofunction:: theano.tensor.nnet.conv3d2d.conv3d
-.. autofunction:: theano.sandbox.cuda.fftconv.conv2d_fft
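The GpuCorrMM documentation in this diff describes unfolding each image into a Toeplitz-style matrix of shape `(channels * filter height * filter width, output height * output width)` and then multiplying it with the filters. The following is a minimal pure-Python sketch of that idea, not the CUDA implementation from this commit. The helper names `im2col` and `corr_mm` are hypothetical, and the sketch assumes valid mode, unit stride, no padding, and a single image given as nested lists.

```python
def im2col(img, kh, kw):
    """Unfold one image, given as nested lists [channel][row][col], into a
    matrix with channels*kh*kw rows and out_h*out_w columns
    (valid mode, unit stride, no padding)."""
    channels, height, width = len(img), len(img[0]), len(img[0][0])
    out_h, out_w = height - kh + 1, width - kw + 1
    cols = []
    for c in range(channels):
        for i in range(kh):
            for j in range(kw):
                # one row per (channel, filter row, filter column) offset;
                # each entry is the pixel that offset reads at one output position
                cols.append([img[c][y + i][x + j]
                             for y in range(out_h) for x in range(out_w)])
    return cols


def corr_mm(img, filters):
    """Correlate one image with filters [filter][channel][row][col] using a
    single matrix product: (n_filters, c*kh*kw) x (c*kh*kw, out_h*out_w)."""
    kh, kw = len(filters[0][0]), len(filters[0][0][0])
    cols = im2col(img, kh, kw)
    out = []
    for f in filters:
        # flatten the filter in the same (channel, row, column) order as im2col
        flat = [v for chan in f for row in chan for v in row]
        out.append([sum(a * col[p] for a, col in zip(flat, cols))
                    for p in range(len(cols[0]))])
    return out  # n_filters rows, each holding out_h*out_w values


image = [[[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]]           # 1 channel, 3x3
filters = [[[[1, 0],
             [0, 1]]]]          # 1 filter, 1 channel, 2x2
print(len(im2col(image, 2, 2)), "x", len(im2col(image, 2, 2)[0]))
print(corr_mm(image, filters))
```

Because this is a correlation, the filter is applied without flipping; as the class docstring in this commit notes, a convolution is obtained by flipping the filters first.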
theano/sandbox/cuda/blas.py
@@ -8,6 +8,7 @@ from theano.compat.six import StringIO
 from theano.sandbox.cuda.type import CudaNdarrayType
 from theano.sandbox.cuda import GpuOp
 from theano.sandbox.cuda import as_cuda_ndarray_variable
+from theano.sandbox.cuda.basic_ops import gpu_contiguous
 class GpuDot22(GpuOp):
...
...
@@ -500,30 +501,22 @@ gpu_ger_no_inplace = GpuGer(inplace=False)
 gpu_ger_inplace = GpuGer(inplace=True)
-class GpuCorrMM(GpuOp):
-    """GPU correlation implementation using Matrix Multiply.
+class BaseGpuCorrMM(GpuOp):
+    """Base class for `GpuCorrMM`, `GpuCorrMM_gradWeights` and
+    `GpuCorrMM_gradInputs`. Cannot be used directly."""
-    :note: It don't implement the grad. So you should use it by
-        enabling the Theano flag ``optimizer_including=conv_gemm`` and
-        use :func:`conv2d <theano.tensor.nnet.conv.conv2d>`.
-    """
-    def __init__(self, border_mode,
+    def __init__(self, border_mode="valid",
                  subsample=(1, 1),
-                 pad=0):
-        """
-        :param border_mode: "valid" or "full"
-        :param subsample: the subsample operation applied on each output image.
-            Should be a tuple with 2 elements.
-            (sv, sh) is equivalent to GpuCorrMM(...)(...)[:,:,::sv, ::sh]
-        :param pad: not yet supported
-        """
+                 pad=(0, 0)):
+        if border_mode != "valid":
+            raise ValueError("border_mode must be 'valid'")
         self.border_mode = border_mode
+        if len(subsample) != 2:
+            raise ValueError("subsample must have two elements")
         self.subsample = subsample
+        if (pad not in ("half", "full")) and (len(pad) != 2):
+            raise ValueError("pad must be 'half', 'full', or have two elements")
         self.pad = pad
-        if pad != 0:
-            raise NotImplementedError(
-                "GpuCorrMM don't implement the pad parameter")
     def __eq__(self, other):
         return type(self) == type(other) \
...
...
@@ -540,40 +533,25 @@ class GpuCorrMM(GpuOp):
             ^ hash(self.pad)
     def __str__(self):
-        return '%s{%s, %s, pad=%d}' % (
+        return '%s{%s, %s, pad=%r}' % (
             self.__class__.__name__,
             self.border_mode,
             str(self.subsample),
             self.pad)
-    def make_node(self, img, kern):
-        img = as_cuda_ndarray_variable(img)
-        kern = as_cuda_ndarray_variable(kern)
-        if img.type.ndim != 4:
-            raise TypeError('img must be 4D tensor')
-        if kern.type.ndim != 4:
-            raise TypeError('kern must be 4D tensor')
-        broadcastable = [img.type.broadcastable[0], kern.type.broadcastable[0],
-                         False, False]
-        return Apply(self, [img, kern], [CudaNdarrayType(broadcastable)()])
-    def flops(self, inputs, outputs):
-        images, kerns = inputs
-        out, = outputs
-        assert images[1] == kerns[1]
-        flops = 0
-        if self.border_mode == "valid":
-            # nb mul and add by output pixel
-            flops = kerns[2] * kerns[3] * 2
-            # nb flops by output image
-            flops *= out[2] * out[3]
-            # nb patch multiplied
-            flops *= images[1] * kerns[0] * images[0]
-        else:
-            flops = (images[0] * kerns[0] * images[1] *
-                     kerns[2] * kerns[3] *
-                     images[2] * images[3] * 2)
+    def flops(self, inp, outp):
+        """ Useful with the hack in profilemode to print the MFlops"""
+        # if the output shape is correct, then this gives the correct
+        # flops for any direction, sampling, padding, and border mode
+        inputs, filters = inp
+        outputs, = outp
+        assert inputs[1] == filters[1]
+        # nb mul and add by output pixel
+        flops = filters[2] * filters[3] * 2
+        # nb flops by output image
+        flops *= outputs[2] * outputs[3]
+        # nb patch multiplied
+        flops *= inputs[1] * filters[0] * inputs[0]
         return flops
     def c_headers(self):
...
...
@@ -581,7 +559,7 @@ class GpuCorrMM(GpuOp):
     def c_code_cache_version(self):
         # raise this whenever modifying any of the support_code_files
-        return (0, 22)
+        return (0, 23)
     def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of
...
...
@@ -591,55 +569,182 @@ class GpuCorrMM(GpuOp):
                  for f in files]
         return reduce(str.__add__, codes)
-    def c_code(self, node, nodename, inp, out_, sub):
-        img, kern = inp
-        out, = out_
-        dx = self.subsample[0]
-        dy = self.subsample[1]
-        sub = sub.copy()
-        pad = self.pad
-        if self.border_mode == "valid":
-            bmode = 1
+    def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width=None):
+        """
+        This generates the C code for GpuCorrMM (direction="forward"),
+        GpuCorrMM_gradWeights (direction="backprop weights"), and
+        GpuCorrMM_gradInputs (direction="backprop inputs").
+        Depending on the direction, one of bottom, weights, top will
+        receive the output, while the other two serve as inputs.
+        :param bottom: Variable name of the input images in the forward pass,
+            or the gradient of the input images in backprop wrt. inputs
+        :param weights: Variable name of the filters in the forward pass,
+            or the gradient of the filters in backprop wrt. weights
+        :param top: Variable name of the output images / feature maps in the
+            forward pass, or the gradient of the outputs in the backprop passes
+        :param direction: "forward" to correlate bottom with weights and store
+            results in top,
+            "backprop weights" to do a valid convolution of bottom with top
+            (swapping the first two dimensions) and store results in weights,
+            and "backprop inputs" to do a full convolution of top with weights
+            (swapping the first two dimensions) and store results in bottom.
+        :param sub: Dictionary of substitutions useable to help generating the
+            C code.
+        :param height: If self.subsample[0] != 1, a variable giving the height
+            of the filters for direction="backprop weights" or the height of the
+            input images for direction="backprop inputs".
+            If self.pad == 'half', a variable giving the height of the filters
+            for direction="backprop weights".
+            Ignored otherwise.
+        :param width: If self.subsample[1] != 1, a variable giving the width
+            of the filters for direction="backprop weights" or the width of the
+            input images for direction="backprop inputs".
+            If self.pad == 'half', a variable giving the width of the filters
+            for direction="backprop weights".
+            Ignored otherwise.
+        """
+        if self.border_mode != "valid":
+            raise ValueError("mode must be 'valid'")
+        dH, dW = self.subsample
+        if self.pad == "half":
+            padH = padW = -1
+        elif self.pad == "full":
+            padH = padW = -2
         else:
-            assert self.border_mode == "full"
-            bmode = 0
+            padH, padW = self.pad
+        if direction == "forward":
+            direction = 0
+            out = top
+        elif direction == "backprop weights":
+            direction = 1
+            out = weights
+        elif direction == "backprop inputs":
+            direction = 2
+            out = bottom
+        else:
+            raise ValueError("direction must be one of 'forward', "
+                             "'backprop weights', 'backprop inputs'")
+        # When subsampling, we cannot unambiguously infer the height and width
+        # of bottom and weights from top, so we require them to be given.
+        # Similarly, when pad="half", we cannot infer the weight size.
+        if ((direction != 0) and (dH != 1)) or ((direction == 1) and (padH == -1)):
+            if not height:
+                raise ValueError("height must be given for backprop with vertical sampling or pad='half'")
+            height = '(*(npy_int*)(PyArray_DATA(%s)))' % height
+        else:
+            height = 'NULL'
+        if ((direction != 0) and (dW != 1)) or ((direction == 1) and (padW == -1)):
+            if not width:
+                raise ValueError("width must be given for backprop with horizontal sampling or pad='half'")
+            width = '(*(npy_int*)(PyArray_DATA(%s)))' % width
+        else:
+            width = 'NULL'
+        sub = sub.copy()
         sub.update(locals())
         return """
-    //Mandatory args
-    int mode = %(bmode)s;
-    //Optional args
-    int dx = %(dx)s;
-    int dy = %(dy)s;
-    int padH = 0;
-    int padW = 0;
+    // Mandatory args
+    int direction = %(direction)s;  // forward, bprop weights, bprop inputs
+    // Optional args
+    int dH = %(dH)s;
+    int dW = %(dW)s;
+    int padH = %(padH)s;
+    int padW = %(padW)s;
-    CudaNdarray * img = %(img)s;
-    CudaNdarray * kern = %(kern)s;
+    CudaNdarray * bottom = %(bottom)s;
+    CudaNdarray * weights = %(weights)s;
+    CudaNdarray * top = %(top)s;
     CudaNdarray * out2 = NULL;
-    //TODO: Send self.pad, stride, etc
-    int out_dim[4];
-    out_dim[0] = CudaNdarray_HOST_DIMS(img)[0];
-    out_dim[1] = CudaNdarray_HOST_DIMS(kern)[0];
-    int logical_rows, logical_cols;
-    if (mode == 1)
-    {
-        logical_rows = CudaNdarray_HOST_DIMS(img)[2] - CudaNdarray_HOST_DIMS(kern)[2] + 1;
-        logical_cols = CudaNdarray_HOST_DIMS(img)[3] - CudaNdarray_HOST_DIMS(kern)[3] + 1;
+    // Obtain or infer kernel width and height
+    // (we need to know it early to be able to handle auto-padding)
+    int kH, kW;
+    if (direction != 1) {
+        // weight is an input variable, we can just read its shape
+        kH = CudaNdarray_HOST_DIMS(weights)[2];
+        kW = CudaNdarray_HOST_DIMS(weights)[3];
     }
-    else
-    {
-        logical_rows = CudaNdarray_HOST_DIMS(img)[2] + CudaNdarray_HOST_DIMS(kern)[2] - 1;
-        logical_cols = CudaNdarray_HOST_DIMS(img)[3] + CudaNdarray_HOST_DIMS(kern)[3] - 1;
-        padH = CudaNdarray_HOST_DIMS(kern)[2] - 1;
-        padW = CudaNdarray_HOST_DIMS(kern)[3] - 1;
+    else {
+        if ((dH != 1) || (padH == -1)) {
+            // vertical subsampling or half padding, kernel height is specified
+            kH = %(height)s;
+        }
+        else if (padH == -2) {
+            // vertical full padding, we can infer the kernel height
+            kH = 2 - CudaNdarray_HOST_DIMS(bottom)[2] + (CudaNdarray_HOST_DIMS(top)[2] - 1) * dH;
+        }
+        else {
+            // explicit padding, we can infer the kernel height
+            kH = CudaNdarray_HOST_DIMS(bottom)[2] + 2*padH - (CudaNdarray_HOST_DIMS(top)[2] - 1) * dH;
+        }
+        if ((dW != 1) || (padW == -1)) {
+            kW = %(width)s;
+        }
+        else if (padW == -2) {
+            kW = 2 - CudaNdarray_HOST_DIMS(bottom)[3] + (CudaNdarray_HOST_DIMS(top)[3] - 1) * dW;
+        }
+        else {
+            kW = CudaNdarray_HOST_DIMS(bottom)[3] + 2*padW - (CudaNdarray_HOST_DIMS(top)[3] - 1) * dW;
+        }
     }
-    out_dim[2] = ceil_intdiv(logical_rows, dx);
-    out_dim[3] = ceil_intdiv(logical_cols, dy);
+    // Auto-padding if requested
+    if (padH == -1) {  // vertical half padding
+        padH = kH / 2;
+    }
+    else if (padH == -2) {  // vertical full padding
+        padH = kH - 1;
+    }
+    else if (padH < 0) {
+        PyErr_SetString(PyExc_ValueError, "BaseGpuCorrMM: padH must be >= -2");
+        %(fail)s
+    }
+    if (padW == -1) {  // horizontal half padding
+        padW = kW / 2;
+    }
+    else if (padW == -2) {  // horizontal full padding
+        padW = kW - 1;
+    }
+    else if (padW < 0) {
+        PyErr_SetString(PyExc_ValueError, "BaseGpuCorrMM: padW must be >= -2");
+        %(fail)s
+    }
+    // Infer output shape
+    int out_dim[4];
+    switch(direction) {
+    case 0:  // forward pass
+        // output is top: (batchsize, num_filters, height, width)
+        // height and width: top = (bottom + 2*pad - weight) / sample + 1
+        out_dim[0] = CudaNdarray_HOST_DIMS(bottom)[0];
+        out_dim[1] = CudaNdarray_HOST_DIMS(weights)[0];
+        out_dim[2] = (CudaNdarray_HOST_DIMS(bottom)[2] + 2*padH - CudaNdarray_HOST_DIMS(weights)[2]) / dH + 1;
+        out_dim[3] = (CudaNdarray_HOST_DIMS(bottom)[3] + 2*padW - CudaNdarray_HOST_DIMS(weights)[3]) / dW + 1;
+        break;
+    case 1:  // backprop wrt. weights
+        // output is weights: (num_filters, num_channels, height, width)
+        // height and width: weights = bottom + 2*pad - (top - 1) * sample
+        out_dim[0] = CudaNdarray_HOST_DIMS(top)[1];
+        out_dim[1] = CudaNdarray_HOST_DIMS(bottom)[1];
+        out_dim[2] = kH;  // already inferred further above
+        out_dim[3] = kW;  // how convenient
+        break;
+    case 2:  // backprop wrt. inputs
+        // output is bottom: (batchsize, num_channels, height, width)
+        // height and width: bottom = (top - 1) * sample + weights - 2*pad
+        out_dim[0] = CudaNdarray_HOST_DIMS(top)[0];
+        out_dim[1] = CudaNdarray_HOST_DIMS(weights)[1];
+        out_dim[2] = (dH != 1) ? %(height)s : (CudaNdarray_HOST_DIMS(top)[2] - 1) * dH + CudaNdarray_HOST_DIMS(weights)[2] - 2*padH;
+        out_dim[3] = (dW != 1) ? %(width)s : (CudaNdarray_HOST_DIMS(top)[3] - 1) * dW + CudaNdarray_HOST_DIMS(weights)[3] - 2*padW;
+        break;
+    default:
+        PyErr_SetString(PyExc_ValueError, "BaseGpuCorrMM: direction must be 0, 1, or 2\\n");
+        %(fail)s
+    }
+    // Prepare output array
     if ( !(%(out)s
            && %(out)s->nd==4
            && CudaNdarray_is_c_contiguous(%(out)s)
...
...
@@ -650,10 +755,17 @@ class GpuCorrMM(GpuOp):
     {
         Py_XDECREF(%(out)s);
         %(out)s = (CudaNdarray*)CudaNdarray_NewDims(4,out_dim);
+        if (NULL == %(out)s)
+        {
+            PyErr_Format(PyExc_RuntimeError,
+                    "BaseGpuCorrMM: Failed to allocate output of %%d x %%d x %%d x %%d",
+                    out_dim[0], out_dim[1], out_dim[2], out_dim[3]);
+            %(fail)s
+        }
     }
-    out2 = corrMM(%(img)s, %(kern)s, %(out)s, dx, dy, padH, padW);
+    // Call CUDA code
+    out2 = corrMM(%(bottom)s, %(weights)s, %(top)s, direction, dH, dW, padH, padW);
     if (out2==NULL){
         %(fail)s
     }
...
...
@@ -662,6 +774,186 @@ class GpuCorrMM(GpuOp):
"""
%
sub
+class GpuCorrMM(BaseGpuCorrMM):
+    """GPU correlation implementation using Matrix Multiplication.
+    :note: You can either enable the Theano flag `optimizer_including=conv_gemm`
+        to automatically replace all convolution operations with `GpuCorrMM`
+        or one of its gradients, or you can use it as a replacement for
+        :func:`conv2d <theano.tensor.nnet.conv.conv2d>`, called as
+        `GpuCorrMM(subsample=...)(image, filters)`. The latter is currently
+        faster, but note that it computes a correlation -- if you need to
+        compute a convolution, flip the filters as `filters[:,:,::-1,::-1]`.
+    :warning: For 700 series Nvidia GPUs of compute capability 3.5 and CUDA 5.0
+        to 6.0, there is a bug in CUBLAS' matrix multiplication function that
+        can make GpuCorrMM or its gradients crash for some input and filter
+        shapes. So if you have a Tesla K20, Tesla K40, Quadro K6000, GeForce GT
+        640 (DDR5), GeForce GTX 780 (or Ti), GeForce GTX TITAN (or Black or Z)
+        and experience a crash, switching to CUDA 6.5 or CUDA 4.2 should fix it.
+        If this is not possible, changing the input or filter shapes (e.g., the
+        batchsize or number of filters) may also work around the CUBLAS bug.
+    """
+    def __init__(self, border_mode="valid", subsample=(1, 1), pad=(0, 0)):
+        """
+        :param border_mode: currently supports "valid" only; "full" can be
+            simulated by setting `pad="full"` (at the cost of performance), or
+            by using `GpuCorrMM_gradInputs`
+        :param subsample: the subsample operation applied to each output image.
+            Should be a tuple with 2 elements.
+            `(sv, sh)` is equivalent to `GpuCorrMM(...)(...)[:,:,::sv, ::sh]`,
+            but faster.
+            Set to `(1, 1)` to disable subsampling.
+        :param pad: the width of a border of implicit zeros to pad the input
+            image with. Should be a tuple with 2 elements giving the numbers of
+            rows and columns to pad on each side, or "half" to set the padding
+            to `(kernel_rows // 2, kernel_columns // 2)`, or "full" to set the
+            padding to `(kernel_rows - 1, kernel_columns - 1)` at runtime.
+            Set to `(0, 0)` to disable padding.
+        :note: Currently, the Op requires the inputs, filters and outputs to be
+            C-contiguous. Use :func:`gpu_contiguous
+            <theano.sandbox.cuda.basic_ops.gpu_contiguous>` on these arguments
+            if needed.
+        """
+        super(GpuCorrMM, self).__init__(border_mode, subsample, pad)
+    def make_node(self, img, kern):
+        img = as_cuda_ndarray_variable(img)
+        kern = as_cuda_ndarray_variable(kern)
+        if img.type.ndim != 4:
+            raise TypeError('img must be 4D tensor')
+        if kern.type.ndim != 4:
+            raise TypeError('kern must be 4D tensor')
+        broadcastable = [img.type.broadcastable[0], kern.type.broadcastable[0],
+                         False, False]
+        return Apply(self, [img, kern], [CudaNdarrayType(broadcastable)()])
+    def c_code(self, node, nodename, inp, out_, sub):
+        bottom, weights = inp
+        top, = out_
+        direction = "forward"
+        return super(GpuCorrMM, self).c_code_helper(bottom, weights, top, direction, sub)
+    def grad(self, inp, grads):
+        bottom, weights = inp
+        top, = grads
+        top = gpu_contiguous(top)
+        d_bottom = GpuCorrMM_gradInputs(self.border_mode, self.subsample, self.pad)(
+            weights, top, bottom.shape[-2:])
+        d_weights = GpuCorrMM_gradWeights(self.border_mode, self.subsample, self.pad)(
+            bottom, top, weights.shape[-2:])
+        return d_bottom, d_weights
+class GpuCorrMM_gradWeights(BaseGpuCorrMM):
+    """Gradient wrt. filters for `GpuCorrMM`.
+    :note: You will not want to use this directly, but rely on Theano's
+        automatic differentiation or graph optimization to use it as needed."""
+    def __init__(self, border_mode="valid", subsample=(1, 1), pad=(0, 0)):
+        super(GpuCorrMM_gradWeights, self).__init__(border_mode, subsample, pad)
+    def make_node(self, img, topgrad, shape=None):
+        img = as_cuda_ndarray_variable(img)
+        topgrad = as_cuda_ndarray_variable(topgrad)
+        if img.type.ndim != 4:
+            raise TypeError('img must be 4D tensor')
+        if topgrad.type.ndim != 4:
+            raise TypeError('topgrad must be 4D tensor')
+        if self.subsample != (1, 1) or self.pad == "half":
+            if shape is None:
+                raise ValueError('shape must be given if subsample != (1, 1) or pad == "half"')
+            height_width = [shape[0], shape[1]]
+        else:
+            height_width = []
+        broadcastable = [topgrad.type.broadcastable[1], img.type.broadcastable[1],
+                         False, False]
+        return Apply(self, [img, topgrad] + height_width, [CudaNdarrayType(broadcastable)()])
+    def c_code(self, node, nodename, inp, out_, sub):
+        bottom, top = inp[:2]
+        height, width = inp[2:] or (None, None)
+        weights, = out_
+        direction = "backprop weights"
+        return super(GpuCorrMM_gradWeights, self).c_code_helper(bottom, weights, top, direction, sub, height, width)
+    def grad(self, inp, grads):
+        bottom, top = inp[:2]
+        weights, = grads
+        weights = gpu_contiguous(weights)
+        d_bottom = GpuCorrMM_gradInputs(self.border_mode, self.subsample, self.pad)(
+            weights, top, bottom.shape[-2:])
+        d_top = GpuCorrMM(self.border_mode, self.subsample, self.pad)(
+            bottom, weights)
+        d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else ()
+        return (d_bottom, d_top) + d_height_width
+    def connection_pattern(self, node):
+        if node.nin == 2:
+            return [[1], [1]]
+        else:
+            return [[1], [1], [0], [0]]  # no connection to height, width
+class GpuCorrMM_gradInputs(BaseGpuCorrMM):
+    """Gradient wrt. inputs for `GpuCorrMM`.
+    :note: You will not want to use this directly, but rely on Theano's
+        automatic differentiation or graph optimization to use it as needed."""
+    def __init__(self, border_mode="valid", subsample=(1, 1), pad=(0, 0)):
+        super(GpuCorrMM_gradInputs, self).__init__(border_mode, subsample, pad)
+    def make_node(self, kern, topgrad, shape=None):
+        kern = as_cuda_ndarray_variable(kern)
+        topgrad = as_cuda_ndarray_variable(topgrad)
+        if kern.type.ndim != 4:
+            raise TypeError('kern must be 4D tensor')
+        if topgrad.type.ndim != 4:
+            raise TypeError('topgrad must be 4D tensor')
+        if self.subsample != (1, 1) and shape is None:
+            raise ValueError('shape must be given if subsample != (1, 1)')
+        height_width = [shape[0], shape[1]] if self.subsample != (1, 1) else []
+        broadcastable = [topgrad.type.broadcastable[0], kern.type.broadcastable[1],
+                         False, False]
+        return Apply(self, [kern, topgrad] + height_width, [CudaNdarrayType(broadcastable)()])
+    def c_code(self, node, nodename, inp, out_, sub):
+        weights, top = inp[:2]
+        height, width = inp[2:] or (None, None)
+        bottom, = out_
+        direction = "backprop inputs"
+        return super(GpuCorrMM_gradInputs, self).c_code_helper(bottom, weights, top, direction, sub, height, width)
+    def grad(self, inp, grads):
+        weights, top = inp[:2]
+        bottom, = grads
+        bottom = gpu_contiguous(bottom)
+        d_weights = GpuCorrMM_gradWeights(self.border_mode, self.subsample, self.pad)(
+            bottom, top, weights.shape[-2:])
+        d_top = GpuCorrMM(self.border_mode, self.subsample, self.pad)(
+            bottom, weights)
+        d_height_width = (theano.gradient.DisconnectedType()(),) * 2 if len(inp) == 4 else ()
+        return (d_weights, d_top) + d_height_width
+    def connection_pattern(self, node):
+        if node.nin == 2:
+            return [[1], [1]]
+        else:
+            return [[1], [1], [0], [0]]  # no connection to height, width
##
# Not really a BLAS operation, but whatever.
#
...
...
@@ -688,6 +980,8 @@ class GpuConv(GpuOp):
             kshp=None,
             imshp=None,
             max_threads_dim0=None,
+            nkern=None,
+            bsize=None,
             fft_opt=True):
"""
:param version: each version of c_code implements many kernel for the
...
...
@@ -707,7 +1001,15 @@ class GpuConv(GpuOp):
         :param max_threads_dim0: The maximum number of threads for the
                 block size dimensions 0 (blockDim.x) used by the
                 GPU function.
-        :param fft_opt: desactivate fft_opt optimization at the op level when
+        :param nkern: The number of kernels. Not used for this op, but can be
+                used by graph optimizers to select a more optimal
+                convolution implementation. If the GpuConv op is inserted
+                automatically, we take its value from the Conv op.
+        :param bsize: The batch size. Not used for this op, but can be
+                used by graph optimizers to select a more optimal
+                convolution implementation. If the GpuConv op is inserted
+                automatically, we take its value from the Conv op.
+        :param fft_opt: deactivate fft_opt optimization at the op level when
                 set to False. Note that by default fft optimization
                 aren't enabled. See
                 :ref:`convolution documentation <libdoc_tensor_nnet_conv>`
...
...
@@ -735,6 +1037,8 @@ class GpuConv(GpuOp):
         self.kshp = kshp
         self.imshp = imshp
         self.max_threads_dim0 = max_threads_dim0
+        self.nkern = nkern
+        self.bsize = bsize
         self.fft_opt = fft_opt
     def __eq__(self, other):
...
...
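The shape formulas quoted in the C comments of `c_code_helper` (top, weights and bottom sizes for the three directions) can be checked with a small pure-Python sketch. The function names below are hypothetical helpers for illustration only and are not part of Theano. They handle one spatial dimension, and `weights_size` assumes explicit integer padding, matching the only case in which the generated C code infers the kernel size from this formula.

```python
def resolve_pad(pad, kernel):
    """Turn 'half' / 'full' / an explicit integer into a padding width."""
    if pad == "half":
        return kernel // 2
    if pad == "full":
        return kernel - 1
    return pad


def top_size(bottom, kernel, stride=1, pad=0):
    """Forward pass: top = (bottom + 2*pad - weight) / sample + 1."""
    return (bottom + 2 * resolve_pad(pad, kernel) - kernel) // stride + 1


def weights_size(bottom, top, stride=1, pad=0):
    """Backprop wrt. weights, explicit integer padding only:
    weights = bottom + 2*pad - (top - 1) * sample."""
    return bottom + 2 * pad - (top - 1) * stride


def bottom_size(top, kernel, stride=1, pad=0):
    """Backprop wrt. inputs: bottom = (top - 1) * sample + weights - 2*pad."""
    return (top - 1) * stride + kernel - 2 * resolve_pad(pad, kernel)


# Unit stride: each size can be recovered from the other two.
assert top_size(10, 3) == 8
assert weights_size(10, 8) == 3
assert bottom_size(8, 3) == 10

# Stride 2: inputs of height 7 and 8 both give an output of height 3,
# so the input height cannot be inferred from the output alone.
assert top_size(7, 3, stride=2) == top_size(8, 3, stride=2) == 3
```

The last assertion illustrates why the gradient ops take an explicit `shape` argument when `subsample != (1, 1)`: the integer division in the forward formula discards information that the backward pass would need.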
theano/sandbox/cuda/caffe_common.hpp
deleted (file mode 100644 → 0)
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_
#include <cublas_v2.h>
#include <cuda.h>
#include <driver_types.h> // cuda driver types
// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
    const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
    const int CAFFE_CUDA_NUM_THREADS = 512;
#endif
// CUDA: number of blocks for threads.
inline int CAFFE_GET_BLOCKS(const int N) {
    return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
#endif // CAFFE_COMMON_HPP_
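The block-count helper in this header is a ceiling division: it returns the smallest number of blocks whose threads cover `N` work items. A minimal Python sketch of the same arithmetic follows; `get_blocks` is a hypothetical name, and the default of 1024 threads per block is taken from the constant in the header.

```python
def get_blocks(n, threads_per_block=1024):
    """Smallest number of blocks such that blocks * threads_per_block >= n."""
    return (n + threads_per_block - 1) // threads_per_block


for n in (1, 1024, 1025, 5000):
    print(n, "items ->", get_blocks(n), "block(s)")
```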
theano/sandbox/cuda/conv_gemm.cu
// This uses a lot of code from Caffe (http://caffe.berkeleyvision.org/);
// sources are clearly marked. Below we reproduce the original license of
// the Caffe software.
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
...
...
@@ -22,176 +25,458 @@ ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
// Reference code: https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
#undef _GLIBCXX_ATOMIC_BUILTINS
#include <Python.h>
#include "cuda_ndarray.cuh"
#include "caffe_common.hpp"
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/caffe_common.hpp)
// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)
// Use 1024 threads per block, which requires cuda sm_2x or above
const int CUDA_NUM_THREADS = 1024;
// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
const int CUDA_NUM_THREADS = 1024;
#else
const int CUDA_NUM_THREADS = 512;
#endif
// CUDA: number of blocks for threads.
inline int GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
// Kernel for fast unfold+copy
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu)
// Kernels for fast unfold + copy
__global__ void im2col_kernel(const int n, const float* data_im,
const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w, const int height_col, const int width_col,
float* data_col) {
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int height_col, const int width_col,
float* data_col) {
CUDA_KERNEL_LOOP(index, n) {
int w_out = index % width_col;
index /= width_col;
int h_out = index % height_col;
int channel_in = index / height_col;
int channel_out = channel_in * ksize_h * ksize_w;
int h_index = index / width_col;
int h_out = h_index % height_col;
int channel_in = h_index / height_col;
int channel_out = channel_in * kernel_h * kernel_w;
int h_in = h_out * stride_h - pad_h;
int w_in = w_out * stride_w - pad_w;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
data_im += (channel_in * height + h_in) * width + w_in;
for (int i = 0; i < ksize_h; ++i) {
for (int j = 0; j < ksize_w; ++j) {
float* data_col_ptr = data_col;
data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out;
const float* data_im_ptr = data_im;
data_im_ptr += (channel_in * height + h_in) * width + w_in;
for (int i = 0; i < kernel_h; ++i) {
for (int j = 0; j < kernel_w; ++j) {
int h = h_in + i;
int w = w_in + j;
*data_col = (h >= 0 && w >= 0 && h < height && w < width) ?
data_im[i * width + j] : 0;
data_col += height_col * width_col;
*data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
data_im_ptr[i * width + j] : 0;
data_col_ptr += height_col * width_col;
}
}
}
}
void im2col(const float* data_im, const int channels,
const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h,
const int pad_w, const int stride_h, const int stride_w, float* data_col) {
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
float* data_col) {
// We are going to launch channels * height_col * width_col kernels, each
// kernel responsible for copying a single-channel grid.
int height_col = (height + 2 * pad_h - ksize_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - ksize_w) / stride_w + 1;
int height_col = (height + 2 * pad_h - kernel_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - kernel_w) / stride_w + 1;
int num_kernels = channels * height_col * width_col;
// Launch
im2col_kernel <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>> (
num_kernels, data_im, height, width, ksize_h, ksize_w,
pad_h, pad_w, stride_h, stride_w,
height_col, width_col, data_col
);
im2col_kernel<<<GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels, data_im, height, width, kernel_h, kernel_w, pad_h,
pad_w, stride_h, stride_w, height_col,
width_col, data_col);
}
__global__ void col2im_kernel(const int n, const float* data_col,
const int height, const int width, const int channels,
const int patch_h, const int patch_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int height_col, const int width_col,
float* data_im) {
CUDA_KERNEL_LOOP(index, n) {
float val = 0;
int w = index % width + pad_w;
int h = (index / width) % height + pad_h;
int c = index / (width * height);
// compute the start and end of the output
int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1;
int w_col_end = min(w / stride_w + 1, width_col);
int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1;
int h_col_end = min(h / stride_h + 1, height_col);
/*
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = c * patch_h * patch_w + (h - h_col * stride_h) * ksize
+ (w - w_col * stride_w);
val += data_col[(c_col * height_col + h_col) * width_col + w_col];
}
}
*/
// equivalent implementation
int offset =
(c * patch_h * patch_w + h * patch_w + w) * height_col * width_col;
int coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col;
int coeff_w_col = (1 - stride_w * height_col * width_col);
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
}
}
data_im[index] = val;
}
}
void col2im(const float* data_col, const int channels,
const int height, const int width, const int patch_h, const int patch_w,
const int pad_h, const int pad_w, const int stride_h,
const int stride_w, float* data_im) {
int height_col = (height + 2 * pad_h - patch_h) / stride_h + 1;
int width_col = (width + 2 * pad_w - patch_w) / stride_w + 1;
int num_kernels = channels * height * width;
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im_kernel<<<GET_BLOCKS(num_kernels),
CUDA_NUM_THREADS>>>(
num_kernels, data_col, height, width, channels, patch_h, patch_w,
pad_h, pad_w, stride_h, stride_w,
height_col, width_col, data_im);
}
// Author: Arjun Jain
CudaNdarray* corrMM(const CudaNdarray *input,
CudaNdarray *weight,
CudaNdarray *output,
int dH = 1,
int dW = 1,
int padH = 0,
int padW = 0)
{
cublasStatus_t status;
if (input->nd != 4)
// Theano op code
// Authors: Arjun Jain, Frédéric Bastien, Jan Schlüter
// Reference code: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
// and https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
CudaNdarray* corrMM(CudaNdarray *const bottom,
CudaNdarray *const weight,
CudaNdarray *const top,
const int direction,
const int dH = 1,
const int dW = 1,
const int padH = 0,
const int padW = 0)
{
if (bottom->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "required input of 4D");
PyErr_SetString(PyExc_ValueError, "GpuCorrMM requires bottom of 4D");
return NULL;
}
if (!CudaNdarray_is_c_contiguous(bottom))
{
PyErr_Format(PyExc_ValueError,
"GpuCorrMM requires bottom to be C-contiguous, "
"but strides are: %d %d %d %d\n",
CudaNdarray_HOST_STRIDES(bottom)[0],
CudaNdarray_HOST_STRIDES(bottom)[1],
CudaNdarray_HOST_STRIDES(bottom)[2],
CudaNdarray_HOST_STRIDES(bottom)[3]);
return NULL;
}
if (weight->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "required weight of 4D");
PyErr_SetString(PyExc_ValueError, "GpuCorrMM requires weight of 4D");
return NULL;
}
if (!CudaNdarray_is_c_contiguous(weight))
{
PyErr_Format(PyExc_ValueError,
"GpuCorrMM requires weight to be C-contiguous, "
"but strides are: %d %d %d %d\n",
CudaNdarray_HOST_STRIDES(weight)[0],
CudaNdarray_HOST_STRIDES(weight)[1],
CudaNdarray_HOST_STRIDES(weight)[2],
CudaNdarray_HOST_STRIDES(weight)[3]);
return NULL;
}
int kH = CudaNdarray_HOST_DIMS(weight)[2];
int kW = CudaNdarray_HOST_DIMS(weight)[3];
int nInputPlane = CudaNdarray_HOST_DIMS(input)[1];
// filters: (number of filters, nInputPlane, rows, columns)
int nOutputPlane = CudaNdarray_HOST_DIMS(weight)[0];
long batchSize = CudaNdarray_HOST_DIMS(input)[0];
if (CudaNdarray_HOST_DIMS(input)[1] != CudaNdarray_HOST_DIMS(weight)[1]){
PyErr_SetString(PyExc_ValueError,
"GpuCorrMM images and kernel must have the same stack size\n"
);
return NULL;
}
long inputHeight = CudaNdarray_HOST_DIMS(input)[2];
long inputWidth = CudaNdarray_HOST_DIMS(input)[3];
long outputWidth = (inputWidth + 2*padW - kW) / dW + 1;
long outputHeight = (inputHeight + 2*padH - kH) / dH + 1;
// check output, size (batchSize, nOutputPlane,
// outputHeight, outputWidth);
if (batchSize != CudaNdarray_HOST_DIMS(output)[0] ||
nOutputPlane != CudaNdarray_HOST_DIMS(output)[1] ||
outputHeight != CudaNdarray_HOST_DIMS(output)[2] ||
outputWidth != CudaNdarray_HOST_DIMS(output)[3]){
PyErr_Format(
PyExc_ValueError,
"GpuCorrMM outputs parameter don't have the good shape %d %d %d %d, %d %d %d %d\n",
batchSize, nOutputPlane, outputHeight, outputWidth,
CudaNdarray_HOST_DIMS(output)[0], CudaNdarray_HOST_DIMS(output)[1],
CudaNdarray_HOST_DIMS(output)[2], CudaNdarray_HOST_DIMS(output)[3]);
return NULL;
}
// Create temporary columns
int col_dim[2];
col_dim[0] = nInputPlane*kW*kH;
col_dim[1]= outputHeight*outputWidth;
CudaNdarray* columns = (CudaNdarray*)CudaNdarray_NewDims(2,col_dim);
int ip_stride = CudaNdarray_HOST_STRIDES(input)[0];
int op_stride = CudaNdarray_HOST_STRIDES(output)[0];
// For each elt in batch, do:
for (int elt = 0; elt < batchSize; elt ++) {
// Matrix mulitply per output:
// 1. Extract columns:
im2col(
input->devdata + elt*ip_stride,
nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
columns->devdata
);
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
float alpha = 1.0f; float beta = 0.0f;
int m = CudaNdarray_HOST_DIMS(columns)[1];
int n = CudaNdarray_HOST_DIMS(weight)[0];
int k = CudaNdarray_HOST_DIMS(columns)[0];
status = cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
m, n, k,
&alpha,
columns->devdata, m,
weight->devdata, k,
&beta,
output->devdata + elt * op_stride, m
);
if (status != CUBLAS_STATUS_SUCCESS) {
std::cerr << "!!!! CUBLAS error: ";
std::cerr << cublasGetErrorString(status) << "\n";
}
}
if (top->nd != 4)
{
PyErr_SetString(PyExc_ValueError, "GpuCorrMM requires top of 4D");
return NULL;
}
if (!CudaNdarray_is_c_contiguous(top))
{
PyErr_Format(PyExc_ValueError,
"GpuCorrMM requires top to be C-contiguous, "
"but strides are: %d %d %d %d\n",
CudaNdarray_HOST_STRIDES(top)[0],
CudaNdarray_HOST_STRIDES(top)[1],
CudaNdarray_HOST_STRIDES(top)[2],
CudaNdarray_HOST_STRIDES(top)[3]);
return NULL;
}
Py_DECREF(columns);
return output;
// Extract some shape information for later and check shape consistency
// bottom: (batchSize, nChannels, bottomHeight, bottomWidth)
const int batchSize = CudaNdarray_HOST_DIMS(bottom)[0];
const int nChannels = CudaNdarray_HOST_DIMS(bottom)[1];
const int bottomHeight = CudaNdarray_HOST_DIMS(bottom)[2];
const int bottomWidth = CudaNdarray_HOST_DIMS(bottom)[3];
// weights: (nFilters, nChannels, rows, columns)
const int nFilters = CudaNdarray_HOST_DIMS(weight)[0];
const int kH = CudaNdarray_HOST_DIMS(weight)[2];
const int kW = CudaNdarray_HOST_DIMS(weight)[3];
if (nChannels != CudaNdarray_HOST_DIMS(weight)[1]) {
PyErr_SetString(PyExc_ValueError,
"GpuCorrMM images and kernel must have the same stack size\n");
return NULL;
}
// top: (batchSize, nFilters, topHeight, topWidth)
const int topHeight = (bottomHeight + 2*padH - kH) / dH + 1;
const int topWidth = (bottomWidth + 2*padW - kW) / dW + 1;
if (batchSize != CudaNdarray_HOST_DIMS(top)[0] ||
nFilters != CudaNdarray_HOST_DIMS(top)[1] ||
topHeight != CudaNdarray_HOST_DIMS(top)[2] ||
topWidth != CudaNdarray_HOST_DIMS(top)[3]) {
PyErr_Format(PyExc_ValueError,
"GpuCorrMM shape inconsistency:\n"
" bottom shape: %d %d %d %d\n"
" weight shape: %d %d %d %d\n"
" top shape: %d %d %d %d (expected %d %d %d %d)\n",
batchSize, nChannels, bottomHeight, bottomWidth,
nFilters, nChannels, kH, kW,
CudaNdarray_HOST_DIMS(top)[0], CudaNdarray_HOST_DIMS(top)[1],
CudaNdarray_HOST_DIMS(top)[2], CudaNdarray_HOST_DIMS(top)[3],
batchSize, nFilters, topHeight, topWidth);
return NULL;
}
// Create temporary columns
int col_dim[2];
col_dim[0] = nChannels * kW * kH;
col_dim[1] = topHeight * topWidth;
CudaNdarray* col = (CudaNdarray*)CudaNdarray_NewDims(2, col_dim);
if (NULL == col)
{
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM failed to allocate working memory of %d x %d\n",
col_dim[0], col_dim[1]);
return NULL;
}
// Define some useful variables
const int bottom_stride = CudaNdarray_HOST_STRIDES(bottom)[0];
const int top_stride = CudaNdarray_HOST_STRIDES(top)[0];
const int K_ = col_dim[0];
const int N_ = col_dim[1];
const int M_ = nFilters;
const float one = 1.0f;
const float zero = 0.0f;
CudaNdarray *output;
if (direction == 0) { // forward pass
output = top;
// valid correlation: im2col, then gemm
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
// First, im2col
im2col(bottom->devdata + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, padH, padW, dH, dW, col->devdata);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUDA error in im2col: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
return NULL;
}
// Second, gemm
cublasStatus_t status = cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
N_, M_, K_,
&one,
col->devdata, N_,
weight->devdata, K_,
&zero,
top->devdata + n * top_stride, N_);
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUBLAS error: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
return NULL;
}
}
/*
// Original caffe code for comparison
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
// Note that this is for grouped convolution; we can ignore groups here,
// but the group-related offsets help explain what M_, N_ and K_ are
int weight_offset = M_ * K_;
int col_offset = K_ * N_;
int top_offset = M_ * N_;
for (int n = 0; n < num_; ++n) {
// First, im2col
im2col_gpu(bottom_data + bottom[i]->offset(n), channels_, height_,
width_, kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
col_data);
// Second, innerproduct with groups
for (int g = 0; g < group_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
(Dtype)1., weight + weight_offset * g, col_data + col_offset * g,
(Dtype)0., top_data + (*top)[i]->offset(n) + top_offset * g);
== (see https://github.com/BVLC/caffe/blob/master/src/caffe/util/math_functions.cu#L16)
cublasSgemm(CUBLAS_OP_N, CUBLAS_OP_N,
N_, M_, K_,
1.,
col_data + col_offset * g, N_,
weight + weight_offset * g, K_,
0.,
top_data + (*top)[i]->offset(n) + top_offset * g, N_);
}
}
*/
}
else if (direction == 1) { // backprop wrt. weights
output = weight;
// valid convolution: im2col, then gemm
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
// First, im2col
im2col(bottom->devdata + n * bottom_stride, nChannels, bottomHeight,
bottomWidth, kH, kW, padH, padW, dH, dW, col->devdata);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUDA error in im2col: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
return NULL;
}
// Second, gemm
// Note that we accumulate into weight. We do so by setting beta = 0
// for the first iteration and beta = 1 for subsequent ones. (This
// is faster than setting weight to all zeros before the loop.)
cublasStatus_t status = cublasSgemm(handle,
CUBLAS_OP_T, CUBLAS_OP_N,
K_, M_, N_,
&one,
col->devdata, N_,
top->devdata + n * top_stride, N_,
(n == 0) ? &zero : &one,
weight->devdata, K_);
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUBLAS error: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
return NULL;
}
}
/*
// Original caffe code for comparison
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
// Note that this is for grouped convolution; we can ignore groups
for (int n = 0; n < num_; ++n) {
// Since we saved memory in the forward pass by not storing all col
// data, we will need to recompute them.
im2col_gpu(bottom_data + (*bottom)[i]->offset(n), channels_, height_,
width_, kernel_h_, kernel_w_, pad_h_, pad_w_,
stride_h_, stride_w_, col_data);
// gradient w.r.t. weight. Note that we will accumulate diffs.
for (int g = 0; g < group_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
(Dtype)1., top_diff + top[i]->offset(n) + top_offset * g,
col_data + col_offset * g, (Dtype)1.,
weight_diff + weight_offset * g);
== (see https://github.com/BVLC/caffe/blob/master/src/caffe/util/math_functions.cu#L16)
cublasSgemm(CUBLAS_OP_T, CUBLAS_OP_N, K_, M_, N_,
1.0,
col_data + col_offset * g, N_,
top_diff + top[i]->offset(n) + top_offset * g, N_,
1.0,
weight_diff + weight_offset * g, K_);
}
}
*/
}
else if (direction == 2) { // backprop wrt. inputs
output = bottom;
// full convolution: gemm, then col2im
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
// gemm into columns
cublasStatus_t status = cublasSgemm(handle,
CUBLAS_OP_N, CUBLAS_OP_T,
N_, K_, M_,
&one,
top->devdata + n * top_stride, N_,
weight->devdata, K_,
&zero,
col->devdata, N_);
if (status != CUBLAS_STATUS_SUCCESS) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUBLAS error: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cublasGetErrorString(status));
return NULL;
}
// col2im back to the data
col2im(col->devdata, nChannels, bottomHeight, bottomWidth,
kH, kW, padH, padW, dH, dW, bottom->devdata + n * bottom_stride);
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM encountered a CUDA error in col2im: %s\n"
"This could be a known bug in CUDA, please see the "
"GpuCorrMM() documentation.\n",
cudaGetErrorString(err));
return NULL;
}
}
/*
// Original caffe code for comparison
// https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu
for (int n = 0; n < num_; ++n) {
// gradient w.r.t. bottom data, if necessary
if (propagate_down[i]) {
for (int g = 0; g < group_; ++g) {
caffe_gpu_gemm<Dtype>(CblasTrans, CblasNoTrans, K_, N_, M_,
(Dtype)1., weight + weight_offset * g,
top_diff + top[i]->offset(n) + top_offset * g,
(Dtype)0., col_diff + col_offset * g);
== (see https://github.com/BVLC/caffe/blob/master/src/caffe/util/math_functions.cu#L16)
cublasSgemm(CUBLAS_OP_N, CUBLAS_OP_T, N_, K_, M_,
1.,
top_diff + top[i]->offset(n) + top_offset * g, N_,
weight + weight_offset * g, K_,
0.,
col_diff + col_offset * g, N_);
}
// col2im back to the data
col2im_gpu(col_diff, channels_, height_, width_,
kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
bottom_diff + (*bottom)[i]->offset(n));
}
}
*/
}
// Free temporary columns
Py_DECREF(col);
// Note that we don't change the refcount of the output matrix here. Output
// (re)allocation and refcounting is done in BaseGpuCorrMM.c_code_helper();
// in here output is just aliased to one of bottom, weights, or top.
return output;
}
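The forward pass in `corrMM` unfolds each image into a column matrix and multiplies it by the flattened filters. The sketch below replays that idea on the host in plain Python, under the assumption that one image is a nested list of shape channels x height x width. The helper names `out_size`, `im2col` and `corr_forward` are illustrative and are not the Theano or Caffe functions of the same name. In the terms used above, the column matrix has `K_ = nChannels * kH * kW` rows and `N_ = topHeight * topWidth` columns, and there are `M_ = nFilters` flattened filters.

```python
def out_size(size, kernel, pad=0, stride=1):
    """Output length along one axis, as used for topHeight and topWidth."""
    return (size + 2 * pad - kernel) // stride + 1


def im2col(img, kh, kw, pad=(0, 0), stride=(1, 1)):
    """Unfold one image into a (channels*kh*kw) x (out_h*out_w) matrix.

    Positions that fall into the zero padding are read as 0.0.
    """
    channels, height, width = len(img), len(img[0]), len(img[0][0])
    out_h = out_size(height, kh, pad[0], stride[0])
    out_w = out_size(width, kw, pad[1], stride[1])
    col = []
    for c in range(channels):
        for i in range(kh):
            for j in range(kw):
                row = []
                for y in range(out_h):
                    for x in range(out_w):
                        h = y * stride[0] - pad[0] + i
                        w = x * stride[1] - pad[1] + j
                        inside = 0 <= h < height and 0 <= w < width
                        row.append(img[c][h][w] if inside else 0.0)
                col.append(row)
    return col, out_h, out_w


def corr_forward(img, weight, pad=(0, 0), stride=(1, 1)):
    """Correlate one image with filters of shape n_filters x channels x kh x kw."""
    kh, kw = len(weight[0][0]), len(weight[0][0][0])
    col, out_h, out_w = im2col(img, kh, kw, pad, stride)
    top = []
    for filt in weight:
        # Flatten the filter in the same (channel, row, column) order as col.
        flat = [v for chan in filt for krow in chan for v in krow]
        plane = [sum(f * r[n] for f, r in zip(flat, col))
                 for n in range(out_h * out_w)]
        top.append([plane[y * out_w:(y + 1) * out_w] for y in range(out_h)])
    return top
```

The GPU code performs the same product with one `cublasSgemm` call per batch element; this sketch spells the product out as explicit sums so the shapes are easy to follow.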
theano/sandbox/cuda/opt.py
View file @
cfc493d1
...
...
@@ -25,7 +25,8 @@ from theano.sandbox.cuda.basic_ops import (
    GpuIncSubtensor, gpu_alloc, GpuAlloc, gpu_shape)
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda.blas import (gpu_dot22, gpu_dot22scalar,
        gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv, GpuCorrMM)
        gpu_gemm_inplace, gpu_gemm_no_inplace, GpuConv,
        GpuCorrMM, GpuCorrMM_gradInputs, GpuCorrMM_gradWeights)
from theano.sandbox.cuda.blas import gpu_gemv_inplace
from theano.sandbox.cuda.blas import gpu_gemv_no_inplace
from theano.sandbox.cuda.blas import gpu_ger_inplace
...
...
@@ -1121,6 +1122,8 @@ def local_gpu_conv(node):
                   version=op.version,
                   verbose=op.verbose,
                   imshp=op.imshp,
                   nkern=op.nkern,
                   bsize=op.bsize,
                   fft_opt=op.fft_opt
                   )
        if op.imshp_logical is not None:
...
...
@@ -1206,15 +1209,25 @@ def _gpu_conv_to_fftconv(node):
            node.op.imshp[-1] is not None and
            node.op.imshp[-1] % 2 == 1):
        kwargs['pad_last_dim'] = True
    # TODO: If the user supplied the full nonsymbolic image_shape and
    # filter_shape in conv2d(), we could pass it on to conv2d_fft(). However,
    # information on batch size and channel counts is currently discarded
    # when a ConvOp is replaced by a GpuConv, so this would need more changes.
    #if (node.op.imshp is not None) and (None not in node.op.imshp):
    #    kwargs['image_shape'] = (bsize, inchannels) + node.op.imshp
    #if (node.op.kshp is not None) and (None not in node.op.kshp):
    #    kwargs['filter_shape'] = (outchannels, inchannels) + node.op.kshp
    return conv2d_fft(node.inputs[0], node.inputs[1], **kwargs)
    # If the user supplied the full nonsymbolic image_shape and
    # filter_shape in conv2d(), we can pass it on to conv2d_fft().
    if ((node.op.imshp is not None) and
            (len(node.op.imshp) == 3) and
            (None not in node.op.imshp) and
            (node.op.bsize is not None)):
        kwargs['image_shape'] = (node.op.bsize,) + node.op.imshp
    if ((node.op.kshp is not None) and
            (None not in node.op.kshp) and
            (node.op.nkern is not None) and
            (len(node.op.imshp) == 3) and
            (node.op.imshp[0] is not None)):
        kwargs['filter_shape'] = (node.op.nkern, node.op.imshp[0]) + node.op.kshp
    rval = conv2d_fft(node.inputs[0], node.inputs[1], **kwargs)
    if ('image_shape' in kwargs) or ('filter_shape' in kwargs):
        # With given shape information, conv2d_fft may return a different
        # broadcast pattern than GpuConv. This is forbidden, so we fix it.
        rval = tensor.patternbroadcast(
            rval, node.outputs[0].type.broadcastable)
    return rval
@local_optimizer([GpuConv])
...
...
@@ -1351,10 +1364,55 @@ def local_conv_gemm(node):
    if (isinstance(node.op, GpuConv) and
            node.op.border_mode in ['full', 'valid']):
        img, kern = node.inputs
        img = gpu_contiguous(img)
        kern = kern[:, :, ::-1, ::-1]
        kern = gpu_contiguous(kern)
        return [GpuCorrMM(node.op.border_mode, node.op.subsample)(img, kern)]
        border_mode = node.op.border_mode
        subsample = node.op.subsample
        pad = (0, 0)
        if (border_mode == 'full') and (subsample != (1, 1)):
            # need to simulate this via a padded valid convolution
            pad = 'full'
            border_mode = 'valid'
        if (border_mode == 'valid'):
            # need to flip the kernel for valid convolution
            kern = kern[:, :, ::-1, ::-1]
            # call GpuCorrMM or GpuCorrMM_gradWeights
            # (the latter is faster if batchsize * kernelHeight * kernelWidth
            # is larger than inputChannels * outputHeight * outputWidth.
            # GpuConv does not always store information on the batchsize and
            # channels, though, so we only use what information we have.)
            if ((subsample == (1, 1)) and
                    (node.op.imshp is not None) and
                    (None not in node.op.imshp[-2:]) and
                    (node.op.kshp is not None) and
                    (None not in node.op.kshp)):
                # we know the kernel and output size
                prod1 = node.op.kshp[0] * node.op.kshp[1]
                prod2 = ((node.op.imshp[-2] - node.op.kshp[0] + 1) *
                         (node.op.imshp[-1] - node.op.kshp[1] + 1))
                if ((node.op.bsize is not None) and
                        (len(node.op.imshp) == 3) and
                        (node.op.imshp[0] is not None)):
                    # we also know batchsize and input channels
                    prod1 *= node.op.bsize
                    prod2 *= node.op.imshp[0]
                # compare to decide
                if prod1 > prod2:
                    # (we need to wrap the result in as_cuda_ndarray_variable,
                    # because we are not allowed to replace a CudaNdarray with
                    # a DimShuffle instance in a graph optimization)
                    return [theano.sandbox.cuda.as_cuda_ndarray_variable(
                            GpuCorrMM_gradWeights('valid', subsample, pad)(
                                gpu_contiguous(img.dimshuffle(1, 0, 2, 3)),
                                gpu_contiguous(kern.dimshuffle(1, 0, 2, 3))
                            ).dimshuffle(1, 0, 2, 3))]
            # use GpuCorrMM if we did not choose GpuCorrMM_gradWeights above
            return [GpuCorrMM('valid', subsample, pad)(
                    gpu_contiguous(img), gpu_contiguous(kern))]
        elif (border_mode == 'full'):
            # need to dimshuffle the kernel for full convolution
            kern = kern.dimshuffle(1, 0, 2, 3)
            # call GpuCorrMM_gradInputs
            return [GpuCorrMM_gradInputs('valid', subsample, pad)(
                    gpu_contiguous(kern), gpu_contiguous(img))]
gpu_optimizer.register("conv_gemm", local_conv_gemm)
...
...
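The choice between `GpuCorrMM` and `GpuCorrMM_gradWeights` in `local_conv_gemm` comes down to comparing two products built from whatever shape information is known. The sketch below isolates that comparison as a standalone function. The name `prefer_grad_weights` and its argument layout are hypothetical; in the real optimizer the values come from `node.op.kshp`, `node.op.imshp` and `node.op.bsize`.

```python
def prefer_grad_weights(kshp, imshp, bsize=None, subsample=(1, 1)):
    """Return True when the gradWeights-style algorithm is expected to win.

    kshp:  (kernel rows, kernel cols), or None if unknown.
    imshp: (rows, cols) or (channels, rows, cols); entries may be None.
    bsize: batch size, or None if unknown.
    """
    if subsample != (1, 1):
        return False
    if imshp is None or None in imshp[-2:]:
        return False
    if kshp is None or None in kshp:
        return False
    # Kernel area versus output area of a valid correlation.
    prod1 = kshp[0] * kshp[1]
    prod2 = (imshp[-2] - kshp[0] + 1) * (imshp[-1] - kshp[1] + 1)
    # Refine with batch size and channel count only when both are known.
    if bsize is not None and len(imshp) == 3 and imshp[0] is not None:
        prod1 *= bsize
        prod2 *= imshp[0]
    return prod1 > prod2


if __name__ == "__main__":
    # Large batch, small images: the kernel-side product dominates.
    print(prefer_grad_weights((3, 3), (3, 8, 8), bsize=128))
```

When any needed shape is missing the function falls back to `False`, mirroring how the optimizer falls through to the plain `GpuCorrMM` call.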
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
View file @
cfc493d1
...
...
@@ -186,7 +186,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1),
    f = theano.function([i, k], op, mode=theano_mode)
    if cls is not None:
        assert any([isinstance(node.op, cls)
                    for node in f.maker.fgraph.toposort()]), f.maker.fgraph.toposort()
                    for node in f.maker.fgraph.toposort()]), "Cannot find class %r in %r" % (cls, f.maker.fgraph.toposort())
    gpuval = f(img, kern)
    t2 = time.time()
    for i in range(nb_iter):
...
...
@@ -284,7 +284,7 @@ def exec_conv(version, shapes, verbose, random, mode,
                                   cls=cls)
        except Exception, e:
            print ver, id, (ishape, kshape, subshape, istride, kstride)
            print e
            print "Exception", type(e), e
            pass
        if not ret:
            failed_version.add(ver)
...
...
@@ -634,7 +634,7 @@ def test_valid(conv_gemm=False):
    if conv_gemm:
        # Test the GpuCorrMM version
        mode = theano_mode.including("conv_gemm")
        cls = cuda.blas.GpuCorrMM
        cls = cuda.blas.BaseGpuCorrMM
        # dummy version; not used by GpuCorrMM so one version is enough
        version = [-1]
        # Add tests with strided inputs by still square images and filters.
...
...
@@ -713,7 +713,7 @@ def test_full(conv_gemm=False):
    if conv_gemm:
        # Test the GpuCorrMM version
        mode = theano_mode.including("conv_gemm")
        cls = cuda.blas.GpuCorrMM
        cls = cuda.blas.BaseGpuCorrMM
        # dummy version; not used by GpuCorrMM so one version is enough
        version = [-1]
    else:
...
...
@@ -753,7 +753,7 @@ def test_subsample(conv_gemm=False):
    if conv_gemm:
        # Test the GpuCorrMM version
        mode = theano_mode.including("conv_gemm")
        cls = cuda.blas.GpuCorrMM
        cls = cuda.blas.BaseGpuCorrMM
        # dummy version; not used by GpuCorrMM so one version is enough
        version_valid = version_full = [-1]
    else:
...
...
@@ -842,14 +842,9 @@ class TestConv2DGPU(unittest.TestCase):
        theano_mode = theano_mode_orig
def test_gemm_directly():
    """
    input: (batch size, channels, rows, columns)
    filters: (number of filters, channels, rows, columns)
    """
    for mode in ['full', 'valid']:
        print 'Testing mode: ' + mode
    for direction in ['fprop', 'bprop img', 'bprop kern']:
        print 'Testing direction: ' + direction
        for bs in range(1, 5):
            for ch in range(1, 4):
                for nf in range(1, 4):
...
...
@@ -857,35 +852,119 @@ def test_gemm_directly():
                        for rImg2 in range(5, 9):
                            for rFlt1 in range(2, 4):
                                for rFlt2 in range(2, 4):
                                    for subsx in range(1, 3):
                                        for subsy in range(1, 3):
                                    for subsx in range(1, 3) if direction == 'fprop' else [1]:
                                        for subsy in range(1, 3) if direction == 'fprop' else [1]:
                                            ishape = (bs, ch, rImg1, rImg2)
                                            kshape = (nf, ch, rFlt1, rFlt2)
                                            print "ishape: ", ishape
                                            print "kshape: ", kshape
                                            subsample = (subsx, subsy)
                                            print "subsample: ", subsample
                                            npy_img = theano._asarray(numpy.random.rand(*ishape), dtype='float32')
                                            npy_kern = theano._asarray(numpy.random.rand(*kshape), dtype='float32')
                                            i = cuda_tensor4()
                                            k = cuda_tensor4()
                                            cpuval = py_conv(npy_img, npy_kern, mode, subsample)
                                            op = theano.sandbox.cuda.blas.GpuCorrMM(border_mode=mode, \
                                                    subsample=subsample)(i, k)
                                            f = theano.function([i, k], op, mode=theano_mode)
                                            npy_kern = npy_kern[:,:,::-1,::-1]
                                            gpuval = f(npy_img, npy_kern)
                                            gpuval = numpy.asarray(gpuval)
                                            rval = numpy.allclose(cpuval, gpuval, rtol=1e-4)
                                            assert (rval == True)
                                            print 'Test Passed'
                                            if direction == 'fprop':
                                                cpuval = py_conv(npy_img, npy_kern, 'valid', subsample)
                                                op = theano.sandbox.cuda.blas.GpuCorrMM(
                                                        border_mode='valid', subsample=subsample)(i, k)
                                                f = theano.function([i, k], op, mode=theano_mode)
                                                gpuval = f(npy_img, npy_kern[:,:,::-1,::-1])
                                            elif direction == 'bprop img':
                                                cpuval = py_conv(npy_img, npy_kern, 'full', subsample)
                                                op = theano.sandbox.cuda.blas.GpuCorrMM_gradInputs(
                                                        border_mode='valid', subsample=subsample)(i, k)
                                                f = theano.function([i, k], op, mode=theano_mode)
                                                gpuval = f(npy_kern.transpose(1, 0, 2, 3), npy_img)
                                            elif direction == 'bprop kern':
                                                cpuval = py_conv(npy_img, npy_kern, 'valid', subsample)
                                                op = theano.sandbox.cuda.blas.GpuCorrMM_gradWeights(
                                                        border_mode='valid', subsample=subsample)(i, k)
                                                f = theano.function([i, k], op, mode=theano_mode)
                                                gpuval = numpy.array(f(
                                                        npy_img.transpose(1, 0, 2, 3),
                                                        npy_kern.transpose(1, 0, 2, 3)[:,:,::-1,::-1])).transpose(1, 0, 2, 3)
                                            if not numpy.allclose(cpuval, gpuval, rtol=1e-4):
                                                print "Test failed for"
                                                print "direction: ", direction
                                                print "ishape: ", ishape
                                                print "kshape: ", kshape
                                                print "subsample: ", subsample
                                                assert False
def test_gemm_grads():
    for mode in 'valid', 'full':
        for bs in [1, 5]:
            for ch in [4]:
                for nf in [3]:
                    for rImg1 in [2, 5]:
                        for rImg2 in [2, 8]:
                            for rFlt1 in [1, 2]:
                                for rFlt2 in [1, 2]:
                                    for subsx in [1, 2]:
                                        for subsy in [1, 2] if subsx == 1 else [2]:
                                            ishape = (bs, ch, rImg1, rImg2)
                                            kshape = (nf, ch, rFlt1, rFlt2)
                                            subsample = (subsx, subsy)
                                            npy_img = theano._asarray(numpy.random.rand(*ishape), dtype='float32')
                                            npy_kern = theano._asarray(numpy.random.rand(*kshape), dtype='float32')
                                            i = cuda_tensor4()
                                            k = cuda_tensor4()
                                            pad = 'full' if mode == 'full' else (0, 0)
                                            # TODO: also test custom pad values
                                            corr_op = theano.sandbox.cuda.blas.GpuCorrMM(
                                                    'valid', subsample, pad)(i, k)
                                            # try to compile reference implementation without shape,
                                            # so we don't have to compile hundreds of versions
                                            conv_op = tensor.nnet.conv2d(i, k[:,:,::-1,::-1],
                                                    border_mode=mode, subsample=subsample)
                                            try:
                                                conv_op_di = theano.grad(conv_op.sum(), i)
                                                conv_op_dk = theano.grad(conv_op.sum(), k)
                                            except Exception:
                                                # compile with shape information only when needed
                                                conv_op = tensor.nnet.conv2d(i, k[:,:,::-1,::-1],
                                                        ishape, kshape, mode, subsample)
                                                conv_op_di = theano.grad(conv_op.sum(), i)
                                                conv_op_dk = theano.grad(conv_op.sum(), k)
                                            corr_op_di = theano.grad(corr_op.sum(), i)
                                            corr_op_dk = theano.grad(corr_op.sum(), k)
                                            outputs = [corr_op, conv_op,
                                                       corr_op_di, conv_op_di,
                                                       corr_op_dk, conv_op_dk]
                                            try:
                                                conv_op_dik = theano.grad(conv_op_di.sum(), k)
                                                conv_op_dki = theano.grad(conv_op_dk.sum(), i)
                                            except Exception:
                                                # skip if the reference implementation can't do it
                                                print ".",
                                            else:
                                                corr_op_dik = theano.grad(corr_op_di.sum(), k)
                                                corr_op_dki = theano.grad(corr_op_dk.sum(), i)
                                                outputs.extend([corr_op_dik, conv_op_dik,
                                                                corr_op_dki, conv_op_dki])
                                                print ":",
                                            f = theano.function([i, k], outputs, mode=theano_mode)
                                            allvals = f(npy_img, npy_kern)
                                            for a, b, p in zip(allvals[::2], allvals[1::2],
                                                    ('top', 'dtop/dbottom', 'dtop/dweight',
                                                     'dtop/dbottom/dweight', 'dtop/dweight/dbottom')):
                                                if (a.shape != b.shape) or not numpy.allclose(a, b, rtol=1e-4):
                                                    print "Test failed for", p
                                                    print "mode: ", mode
                                                    print "ishape: ", ishape
                                                    print "kshape: ", kshape
                                                    print "subsample: ", subsample
                                                    assert False
                                            sys.stdout.flush()
def benchmark():
...
...
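The tests above compare a convolution reference against correlation ops by flipping the kernel, and they treat full mode as a padded valid correlation. A small self-contained sketch of that relationship for a single 2D plane follows. It assumes that full padding means kernel size minus one on each side, which is the usual definition but is not spelled out in this diff. The helpers `correlate2d`, `flip` and `convolve2d_full` are illustrative names, not Theano or NumPy APIs.

```python
def correlate2d(img, kern, pad=(0, 0)):
    """Valid cross-correlation of a zero-padded 2D list with a 2D kernel."""
    h, w = len(img), len(img[0])
    kh, kw = len(kern), len(kern[0])
    out = []
    for y in range(h + 2 * pad[0] - kh + 1):
        row = []
        for x in range(w + 2 * pad[1] - kw + 1):
            acc = 0
            for i in range(kh):
                for j in range(kw):
                    r, c = y - pad[0] + i, x - pad[1] + j
                    if 0 <= r < h and 0 <= c < w:  # outside means zero padding
                        acc += img[r][c] * kern[i][j]
            row.append(acc)
        out.append(row)
    return out


def flip(kern):
    """Reverse both axes, like kern[::-1, ::-1] on an array."""
    return [row[::-1] for row in kern[::-1]]


def convolve2d_full(img, kern):
    """Full convolution written directly from its definition."""
    h, w = len(img), len(img[0])
    kh, kw = len(kern), len(kern[0])
    out = [[0] * (w + kw - 1) for _ in range(h + kh - 1)]
    for y in range(h):
        for x in range(w):
            for i in range(kh):
                for j in range(kw):
                    out[y + i][x + j] += img[y][x] * kern[i][j]
    return out


if __name__ == "__main__":
    img = [[1, 2], [3, 4]]
    kern = [[1, 2, 3], [4, 5, 6]]
    padded = correlate2d(img, flip(kern), pad=(len(kern) - 1, len(kern[0]) - 1))
    print(padded == convolve2d_full(img, kern))
```

The same identity is why a correlation-only kernel can serve both border modes once the caller flips the filters and chooses the padding.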