Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
369af1ad
提交
369af1ad
authored
8月 05, 2014
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2002 from stencilman/conv_gemm
caffe conv kernel for theano. tests work, but needs integration and some...
上级
ae47cc39
4c55bc4b
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
540 行增加
和
39 行删除
+540
-39
conv.txt
doc/library/tensor/nnet/conv.txt
+13
-0
blas.py
theano/sandbox/cuda/blas.py
+171
-0
caffe_common.hpp
theano/sandbox/cuda/caffe_common.hpp
+53
-0
conv_gemm.cu
theano/sandbox/cuda/conv_gemm.cu
+196
-0
opt.py
theano/sandbox/cuda/opt.py
+16
-2
test_conv_cuda_ndarray.py
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
+91
-37
没有找到文件。
doc/library/tensor/nnet/conv.txt
浏览文件 @
369af1ad
...
...
@@ -51,8 +51,21 @@ TODO: Give examples for how to use these things! They are pretty complicated.
implementation.
Also, there are restrictions on which shapes are supported.
- :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`
This is a GPU-only version of a correlation that computes correlations
as in `caffe <https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu>`_.
For each element in a batch, it first creates a
`Toeplitz <http://en.wikipedia.org/wiki/Toeplitz_matrix>`_ matrix in a CUDA kernel.
Then, it performs a `gemm` call to multiply this Toeplitz matrix with the kernel.
It needs extra memory for this, namely the size of the Toeplitz matrix. Precisely,
the dimensions of this Toeplitz matrix are
(number of channels * filter width * filter height, output width * output height).
You can enable it for calls to conv2d by setting 'THEANO_FLAGS=optimizer_including=conv_gemm'
in your environment. This is not enabled by default because it
uses some extra memory. It does not support strides for now and requires square kernels.
.. autofunction:: theano.tensor.nnet.conv.conv2d
.. autofunction:: theano.tensor.nnet.Conv3D.conv3D
.. autofunction:: theano.tensor.nnet.conv3d2d.conv3d
.. autofunction:: theano.sandbox.cuda.fftconv.conv2d_fft
.. autofunction:: theano.sandbox.cuda.blas.GpuCorrMM
theano/sandbox/cuda/blas.py
浏览文件 @
369af1ad
...
...
@@ -7,6 +7,7 @@ from theano import tensor
from
theano.compat.six
import
StringIO
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda
import
GpuOp
from
theano.sandbox.cuda
import
as_cuda_ndarray_variable
class
GpuDot22
(
GpuOp
):
...
...
@@ -497,9 +498,179 @@ gpu_ger_no_inplace = GpuGer(inplace=False)
gpu_ger_inplace
=
GpuGer
(
inplace
=
True
)
class GpuCorrMM(GpuOp):
    """GPU correlation implemented caffe-style as im2col + GEMM.

    Author: Arjun Jain
    Implement the caffe convolution: for each batch element a Toeplitz
    ("column") matrix is built by a CUDA kernel (see conv_gemm.cu) and
    multiplied with the filters through a single cublasSgemm call.
    """
    def __init__(self, border_mode, subsample=(1, 1), pad=0):
        """
        :param border_mode: "valid" or "full"
        :param subsample: not yet supported
        :param pad: not yet supported
        """
        self.border_mode = border_mode
        self.subsample = subsample
        self.pad = pad
        if pad != 0:
            raise NotImplementedError(
                "GpuCorrMM don't implement the pad parameter")
        if subsample != (1, 1):
            raise NotImplementedError(
                "GpuCorrMM we don't implement the subsample parameter")

    def __eq__(self, other):
        # Two ops are interchangeable iff all of their parameters match.
        return type(self) == type(other) \
            and self.border_mode == other.border_mode \
            and self.subsample == other.subsample \
            and self.pad == other.pad

    def __hash__(self):
        # don't use hash(self.version) as hash(-1)==-2 and
        # hash(-2)==-2 in python!
        return hash(type(self)) \
            ^ hash(self.border_mode) \
            ^ hash(self.subsample) \
            ^ hash(self.pad)

    def __str__(self):
        return '%s{%s, %s, pad=%d}' % (
            self.__class__.__name__,
            self.border_mode,
            str(self.subsample),
            self.pad)

    def make_node(self, img, kern):
        """Build the Apply node; both inputs must be 4D CudaNdarrays."""
        img = as_cuda_ndarray_variable(img)
        kern = as_cuda_ndarray_variable(kern)
        if img.type.ndim != 4:
            raise TypeError('img must be 4D tensor')
        if kern.type.ndim != 4:
            raise TypeError('kern must be 4D tensor')
        # output: (batch, nb filters, out rows, out cols); only the
        # first two dimensions can inherit broadcastability.
        broadcastable = [img.type.broadcastable[0],
                         kern.type.broadcastable[0],
                         False, False]
        return Apply(self, [img, kern],
                     [CudaNdarrayType(broadcastable)()])

    def flops(self, inputs, outputs):
        """ Useful with the hack in profilemode to print the MFlops"""
        images, kerns = inputs
        out, = outputs
        assert images[1] == kerns[1]
        flops = 0
        if self.border_mode == "valid":
            # nb mul and add by output pixel
            flops = kerns[2] * kerns[3] * 2
            # nb flops by output image
            flops *= out[2] * out[3]
            # nb patch multiplied
            flops *= images[1] * kerns[0] * images[0]
        else:
            # full mode: every image pixel sees the whole kernel once
            flops = (images[0] * kerns[0] * images[1] *
                     kerns[2] * kerns[3] *
                     images[2] * images[3] * 2)
        return flops

    def c_headers(self):
        return ['cuda_ndarray.cuh', '<stdio.h>']

    def c_code_cache_version(self):
        # Caching is deliberately disabled while this op is under
        # development: an empty (falsy) version tells Theano not to
        # cache the compiled module. This replaces the previous bare
        # `return` followed by an unreachable `return (0, 21)` -- both
        # values are falsy, so behavior is unchanged, but the dead code
        # is gone.
        # Raise this whenever modifying any of the support_code_files:
        # return (0, 21)
        return ()

    def c_support_code_apply(self, node, nodename):
        # REMEMBER TO RAISE c_code_cache_version when changing any of
        # these files
        files = ['conv_gemm.cu']
        codes = []
        for f in files:
            # Close the file handle explicitly instead of leaking it
            # (the old code used open(...).read() with no close()).
            code_file = open(os.path.join(os.path.split(__file__)[0], f))
            try:
                codes.append(code_file.read())
            finally:
                code_file.close()
        return reduce(str.__add__, codes)

    def c_code(self, node, nodename, inp, out_, sub):
        img, kern = inp
        out, = out_
        dx = self.subsample[0]
        dy = self.subsample[1]
        border_mode = self.border_mode
        sub = sub.copy()
        pad = self.pad
        # Expose img/kern/out/dx/dy/pad/border_mode to the C template.
        sub.update(locals())
        return """
    //Mandatory args
    const char *mode_str = "%(border_mode)s";
    //Optional args
    int dx = %(dx)s;
    int dy = %(dy)s;
    int pad = 0;
    CudaNdarray * img = %(img)s;
    CudaNdarray * kern = %(kern)s;
    CudaNdarray * out2 = NULL;
    int mode;
    if (strcmp(mode_str, "full") == 0)
    {
        mode = 0;
    }
    else if (strcmp(mode_str, "valid") == 0)
    {
        mode = 1;
    }
    else
    {
        PyErr_SetString(PyExc_ValueError,
                        "mode must be one of 'full' or 'valid'");
        %(fail)s;
    }
    //TODO: Send self.pad, stride, etc
    int out_dim[4];
    out_dim[0] = CudaNdarray_HOST_DIMS(img)[0];
    out_dim[1] = CudaNdarray_HOST_DIMS(kern)[0];
    int logical_rows, logical_cols;
    if (mode == 1)
    {
        logical_rows = CudaNdarray_HOST_DIMS(img)[2] - CudaNdarray_HOST_DIMS(kern)[2] + 1;
        logical_cols = CudaNdarray_HOST_DIMS(img)[3] - CudaNdarray_HOST_DIMS(kern)[3] + 1;
    }
    else
    {
        // full mode: output grows by kernel size - 1, implemented as
        // a valid correlation with implicit zero padding.
        logical_rows = CudaNdarray_HOST_DIMS(img)[2] + CudaNdarray_HOST_DIMS(kern)[2] - 1;
        logical_cols = CudaNdarray_HOST_DIMS(img)[3] + CudaNdarray_HOST_DIMS(kern)[3] - 1;
        pad = CudaNdarray_HOST_DIMS(kern)[2] - 1;
    }
    out_dim[2] = ceil_intdiv(logical_rows, dx);
    out_dim[3] = ceil_intdiv(logical_cols, dy);
    // Reuse the output buffer only when it already has the right
    // shape and is c-contiguous; otherwise reallocate it.
    if ( !(%(out)s
           && %(out)s->nd==4
           && CudaNdarray_is_c_contiguous(%(out)s)
           && CudaNdarray_HOST_DIMS(%(out)s)[0]==out_dim[0]
           && CudaNdarray_HOST_DIMS(%(out)s)[1]==out_dim[1]
           && CudaNdarray_HOST_DIMS(%(out)s)[2]==out_dim[2]
           && CudaNdarray_HOST_DIMS(%(out)s)[3]==out_dim[3]))
    {
        Py_XDECREF(%(out)s);
        %(out)s = (CudaNdarray*)CudaNdarray_NewDims(4,out_dim);
        // Bug fix: allocation failure was previously not checked and
        // corrMM would have dereferenced a NULL output.
        if (NULL == %(out)s)
        {
            %(fail)s;
        }
    }
    out2 = corrMM(%(img)s, %(kern)s, %(out)s, pad);
    if (out2==NULL){
        %(fail)s
    }
    assert (out2 == %(out)s);
""" % sub
##
# Not really a BLAS operation, but whatever.
#
class
GpuConv
(
GpuOp
):
"""
Implement the batched and stacked 2d convolution on the gpu.
...
...
theano/sandbox/cuda/caffe_common.hpp
0 → 100644
浏览文件 @
369af1ad
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_

#include <cublas_v2.h>
#include <cuda.h>
#include <driver_types.h>  // cuda driver types

// CUDA: grid stride looping
// Each thread starts at its global index and strides by the total
// number of launched threads, so any n is covered by any launch size.
#define CUDA_KERNEL_LOOP(i, n) \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
       i < (n); \
       i += blockDim.x * gridDim.x)

// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
// NOTE(review): __CUDA_ARCH__ is only defined during device
// compilation passes, so host code always takes the #else branch.
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif

// CUDA: number of blocks for threads.
// Ceiling division: the smallest block count covering N work items.
inline int CAFFE_GET_BLOCKS(const int N) {
  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}

#endif  // CAFFE_COMMON_HPP_
theano/sandbox/cuda/conv_gemm.cu
0 → 100644
浏览文件 @
369af1ad
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#undef _GLIBCXX_ATOMIC_BUILTINS
#include <Python.h>
#include "cuda_ndarray.cuh"
#include "caffe_common.hpp"
// Kernel for fast unfold+copy
// (borrowed from Caffe: https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu)
// Reference code: https://github.com/torch/cunn/blob/master/SpatialConvolutionMM.cu
// One work item per element of the output grid per channel
// (n = channels * height_col * width_col). Each iteration unfolds one
// (ksize x ksize) receptive field into one column of data_col.
__global__ void im2col_kernel(const int n, const float* data_im,
    const int height, const int width, const int ksize, const int pad,
    const int stride, const int height_col, const int width_col,
    float* data_col) {
  CUDA_KERNEL_LOOP(index, n) {
    // Decompose the linear index into (channel_in, h_out, w_out).
    int w_out = index % width_col;
    index /= width_col;
    int h_out = index % height_col;
    int channel_in = index / height_col;
    // Each input channel expands into ksize*ksize rows of data_col.
    int channel_out = channel_in * ksize * ksize;
    // Top-left corner of the receptive field in the (padded) image;
    // may be negative when pad > 0.
    int h_in = h_out * stride - pad;
    int w_in = w_out * stride - pad;
    data_col += (channel_out * height_col + h_out) * width_col + w_out;
    data_im += (channel_in * height + h_in) * width + w_in;
    for (int i = 0; i < ksize; ++i) {
      for (int j = 0; j < ksize; ++j) {
        int h = h_in + i;
        int w = w_in + j;
        // Zero-fill where the receptive field falls outside the image.
        *data_col = (h >= 0 && w >= 0 && h < height && w < width) ?
            data_im[i * width + j] : 0;
        // Successive patch elements live height_col*width_col apart:
        // data_col rows are laid out contiguously, one row per patch
        // element, one column per output position.
        data_col += height_col * width_col;
      }
    }
  }
}
// Host-side launcher for im2col_kernel.
//
// data_im:  device pointer to one image, (channels, height, width),
//           assumed C-contiguous.
// data_col: device pointer to the destination column buffer of shape
//           (channels * ksize * ksize, height_col * width_col).
void im2col(const float* data_im, const int channels,
    const int height, const int width, const int ksize, const int pad,
    const int stride, float* data_col) {
  // We are going to launch channels * height_col * width_col kernels, each
  // kernel responsible for copying a single-channel grid.
  int height_col = (height + 2 * pad - ksize) / stride + 1;
  int width_col = (width + 2 * pad - ksize) / stride + 1;
  int num_kernels = channels * height_col * width_col;
  // Launch
  // NOTE(review): there is no cudaGetLastError() check after this
  // launch, so a bad configuration would go unnoticed until later.
  im2col_kernel <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>> (
      num_kernels, data_im, height, width, ksize,
      pad, stride,
      height_col, width_col, data_col
  );
}
// Author: Arjun Jain
// Author: Arjun Jain
// Batched correlation: for each batch element, unfold the image with
// im2col and multiply the resulting column ("Toeplitz") matrix with
// the filters via one cublasSgemm call.
//
//   input:  (batchSize, nInputPlane, inputHeight, inputWidth)
//   weight: (nOutputPlane, nInputPlane, kH, kW)
//   output: (batchSize, nOutputPlane, outputHeight, outputWidth),
//           preallocated by the caller with exactly this shape.
//
// Only square images and square kernels are supported; stride is
// fixed to 1. Returns `output` on success, or NULL with a Python
// error set on failure.
CudaNdarray* corrMM(const CudaNdarray *input,
                    CudaNdarray *weight,
                    CudaNdarray *output,
                    int padding = 0)
{
    cublasStatus_t status;
    if (input->nd != 4)
    {
        PyErr_SetString(PyExc_ValueError, "required input of 4D");
        // Bug fix: execution previously fell through with the error
        // flag set instead of returning.
        return NULL;
    }
    if (weight->nd != 4)
    {
        PyErr_SetString(PyExc_ValueError, "required weight of 4D");
        // Bug fix: same missing early return as above.
        return NULL;
    }
    // TODO: stride(dW, dH) and padding as function parameter
    int dH = 1;
    int dW = 1;
    int kH = CudaNdarray_HOST_DIMS(weight)[2];
    int kW = CudaNdarray_HOST_DIMS(weight)[3];
    int nInputPlane = CudaNdarray_HOST_DIMS(input)[1];
    // filters: (number of filters, nInputPlane, rows, columns)
    int nOutputPlane = CudaNdarray_HOST_DIMS(weight)[0];
    long batchSize = CudaNdarray_HOST_DIMS(input)[0];
    if (CudaNdarray_HOST_DIMS(input)[2] != CudaNdarray_HOST_DIMS(input)[3]){
        PyErr_Format(PyExc_ValueError,
                     "GpuCorrMM support only square images. Got %dx%d images\n",
                     CudaNdarray_HOST_DIMS(input)[2],
                     CudaNdarray_HOST_DIMS(input)[3]
        );
        return NULL;
    }
    if (kW != kH){
        PyErr_Format(PyExc_ValueError,
                     "GpuCorrMM support only square kernel. Got %dx%d kernel\n",
                     kW, kH
        );
        return NULL;
    }
    if (CudaNdarray_HOST_DIMS(input)[1] != CudaNdarray_HOST_DIMS(weight)[1]){
        PyErr_SetString(PyExc_ValueError,
                        "GpuCorrMM images and kernel must have the same stack size\n"
        );
        return NULL;
    }
    long inputHeight = CudaNdarray_HOST_DIMS(input)[2];
    long inputWidth = CudaNdarray_HOST_DIMS(input)[3];
    long outputWidth = (inputWidth + 2*padding - kW) / dW + 1;
    long outputHeight = (inputHeight + 2*padding - kH) / dH + 1;
    // check output, size (batchSize, nOutputPlane,
    // outputHeight, outputWidth);
    if (batchSize != CudaNdarray_HOST_DIMS(output)[0] ||
        nOutputPlane != CudaNdarray_HOST_DIMS(output)[1] ||
        outputHeight != CudaNdarray_HOST_DIMS(output)[2] ||
        outputWidth != CudaNdarray_HOST_DIMS(output)[3]){
        PyErr_SetString(PyExc_ValueError,
                        "GpuCorrMM outputs parameter don't have the good shape\n"
        );
        return NULL;
    }
    // Create temporary columns: the unfolded image, one column per
    // output position, shape (nInputPlane*kW*kH, outputHeight*outputWidth).
    int col_dim[2];
    col_dim[0] = nInputPlane*kW*kH;
    col_dim[1] = outputHeight*outputWidth;
    CudaNdarray* columns = (CudaNdarray*)CudaNdarray_NewDims(2, col_dim);
    if (NULL == columns)
    {
        // Bug fix: the allocation was previously not checked.
        // CudaNdarray_NewDims already set a Python error.
        return NULL;
    }
    int ip_stride = CudaNdarray_HOST_STRIDES(input)[0];
    int op_stride = CudaNdarray_HOST_STRIDES(output)[0];
    // For each elt in batch, do:
    for (int elt = 0; elt < batchSize; elt ++) {
        // 1. Extract columns.
        // Bug fix: width and height were passed swapped relative to
        // im2col's (height, width) signature; harmless only because
        // square images are enforced above.
        im2col(
            input->devdata + elt*ip_stride,
            nInputPlane, inputHeight, inputWidth, kW, padding, dW,
            columns->devdata
        );
        // 2. GEMM. cuBLAS is column-major: multiplying the (implicitly
        // transposed) columns matrix (m x k) by the weights (k x n)
        // writes the row-major (nOutputPlane, outH*outW) output slice
        // directly, with no extra transpose.
        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
        float alpha = 1.0f; float beta = 0.0f;
        int m = CudaNdarray_HOST_DIMS(columns)[1];
        int n = CudaNdarray_HOST_DIMS(weight)[0];
        int k = CudaNdarray_HOST_DIMS(columns)[0];
        status = cublasSgemm(handle,
                             CUBLAS_OP_N, CUBLAS_OP_N,
                             m, n, k,
                             &alpha,
                             columns->devdata, m,
                             weight->devdata, k,
                             &beta,
                             output->devdata + elt * op_stride, m
        );
        if (status != CUBLAS_STATUS_SUCCESS) {
            // Bug fix: the error used to be printed to stderr and then
            // ignored, returning a silently-wrong output. Report it to
            // Python and fail instead.
            PyErr_SetString(PyExc_RuntimeError,
                            "GpuCorrMM: CUBLAS error in cublasSgemm");
            Py_DECREF(columns);
            return NULL;
        }
    }
    Py_DECREF(columns);
    return output;
}
theano/sandbox/cuda/opt.py
浏览文件 @
369af1ad
...
...
@@ -17,7 +17,7 @@ from theano.gof import (local_optimizer, EquilibriumDB, SequenceDB, ProxyDB,
Optimizer
,
toolbox
)
from
theano.gof.python25
import
all
,
any
from
theano.sandbox.cuda.basic_ops
import
(
device_properties
,
gpu_eye
,
device_properties
,
gpu_eye
,
gpu_contiguous
,
gpu_from_host
,
host_from_gpu
,
GpuFromHost
,
HostFromGpu
,
GpuElemwise
,
GpuDimShuffle
,
GpuReshape
,
GpuCAReduce
,
GpuFlatten
,
GpuSubtensor
,
GpuAdvancedSubtensor1
,
...
...
@@ -25,7 +25,7 @@ from theano.sandbox.cuda.basic_ops import (
GpuIncSubtensor
,
gpu_alloc
,
GpuAlloc
,
gpu_shape
)
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.blas
import
(
gpu_dot22
,
gpu_dot22scalar
,
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
GpuConv
)
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
GpuConv
,
GpuCorrMM
)
from
theano.sandbox.cuda.blas
import
gpu_gemv_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_no_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_inplace
...
...
@@ -1259,6 +1259,7 @@ gpu_optimizer.register("conv_fft_full", local_conv_fft_full)
import
theano.tensor.signal.downsample
as
downsample
@register_opt
()
@local_optimizer
([
downsample
.
DownsampleFactorMax
])
def
local_gpu_downsample_factor_max
(
node
):
...
...
@@ -1282,6 +1283,19 @@ def local_gpu_downsample_factor_max_grad(node):
gpu_from_host
(
gz
)))]
@local_optimizer([GpuConv])
def local_conv_gemm(node):
    """Swap an eligible GpuConv for the im2col+GEMM op GpuCorrMM.

    Applies only to 'full'/'valid' border modes with unit subsampling.
    GpuCorrMM computes a correlation, so the kernel is flipped along its
    two spatial axes to keep the convolution semantics. Registered under
    the optional 'conv_gemm' optimizer tag (not enabled by default).
    """
    op = node.op
    eligible = (isinstance(op, GpuConv)
                and op.border_mode in ['full', 'valid']
                and op.subsample == (1, 1))
    if not eligible:
        return
    img, kern = node.inputs
    # The C code requires c-contiguous inputs.
    img = gpu_contiguous(img)
    # Flip the kernel: correlation + flipped kernel == convolution.
    flipped = gpu_contiguous(kern[:, :, ::-1, ::-1])
    return [GpuCorrMM(op.border_mode)(img, flipped)]
gpu_optimizer.register("conv_gemm", local_conv_gemm)
from
theano.sandbox.cuda.basic_ops
import
gpu_join
,
GpuJoin
...
...
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
浏览文件 @
369af1ad
...
...
@@ -21,9 +21,9 @@ from theano import tensor
from
theano.gof.python25
import
any
from
theano.tests.unittest_tools
import
seed_rng
# Skip test if cuda
_ndarray
is not available.
import
theano.sandbox.cuda
as
cuda_ndarray
if
cuda
_ndarray
.
cuda_available
==
False
:
# Skip test if cuda is not available.
from
theano.sandbox
import
cuda
if
cuda
.
cuda_available
==
False
:
raise
SkipTest
(
'Optional package cuda disabled'
)
#needed as the gpu conv don't have a perform implementation.
...
...
@@ -32,11 +32,11 @@ if theano.config.mode == 'FAST_COMPILE':
else
:
theano_mode
=
theano
.
compile
.
mode
.
get_default_mode
()
.
including
(
'gpu'
)
cuda_tensor4
=
cuda
_ndarray
.
CudaNdarrayType
([
False
]
*
4
)
cuda_tensor4
=
cuda
.
CudaNdarrayType
([
False
]
*
4
)
device_id
=
theano
.
sandbox
.
cuda
.
use
.
device_number
if
device_id
is
None
:
cuda
_ndarray
.
shared_constructor
(
numpy
.
zeros
(
2
,
dtype
=
'float32'
))
cuda
.
shared_constructor
(
numpy
.
zeros
(
2
,
dtype
=
'float32'
))
device_id
=
theano
.
sandbox
.
cuda
.
use
.
device_number
if
device_id
is
None
:
cuda
.
use
(
"gpu"
,
...
...
@@ -46,6 +46,7 @@ if device_id is None:
enable_cuda
=
False
,
test_driver
=
True
)
device_id
=
theano
.
sandbox
.
cuda
.
use
.
device_number
cuda_ndarray
=
theano
.
sandbox
.
cuda
.
cuda_ndarray
.
cuda_ndarray
device_prop
=
cuda_ndarray
.
device_properties
(
device_id
)
...
...
@@ -114,8 +115,8 @@ def py_conv_scipy(img, kern, mode, subsample):
for
k
in
xrange
(
out
.
shape
[
1
]):
for
s
in
xrange
(
img
.
shape
[
1
]):
out
[
b
,
k
,
:,
:]
+=
convolve2d
(
img
[
b
,
s
,
:,
:],
kern
[
k
,
s
,
:,
:],
mode
)
kern
[
k
,
s
,
:,
:],
mode
)
return
out
[:,
:,
::
subsample
[
0
],
::
subsample
[
1
]]
...
...
@@ -126,7 +127,8 @@ def _params_allgood_header():
def
_params_allgood
(
ishape
,
kshape
,
mode
,
subsample
=
(
1
,
1
),
img_stride
=
(
1
,
1
),
kern_stride
=
(
1
,
1
),
version
=-
1
,
verbose
=
0
,
random
=
True
,
print_
=
None
,
id
=
None
,
rtol
=
1e-5
,
atol
=
1e-8
,
nb_iter
=
0
,
ones
=
False
,
compile_kshp
=
None
):
nb_iter
=
0
,
ones
=
False
,
compile_kshp
=
None
,
theano_mode
=
None
,
cls
=
None
):
#
# This function is the core of several of the big unit-test drivers,
# but it can also be used very directly on its own to test a specific
...
...
@@ -181,6 +183,9 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1),
verbose
=
verbose
,
kshp
=
compile_kshp
)(
i
,
k
)
f
=
theano
.
function
([
i
,
k
],
op
,
mode
=
theano_mode
)
if
cls
is
not
None
:
assert
any
([
isinstance
(
node
.
op
,
cls
)
for
node
in
f
.
maker
.
fgraph
.
toposort
()]),
f
.
maker
.
fgraph
.
toposort
()
gpuval
=
f
(
img
,
kern
)
t2
=
time
.
time
()
for
i
in
range
(
nb_iter
):
...
...
@@ -195,7 +200,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1),
rval
=
False
if
rval
:
rval
=
numpy
.
allclose
(
cpuval
,
gpuval
,
rtol
=
rtol
)
assert
numpy
.
all
(
numpy
.
isfinite
(
gpuval
))
assert
numpy
.
all
(
numpy
.
isfinite
(
gpuval
))
,
gpuval
except
NotImplementedError
,
e
:
print
>>
sys
.
stdout
,
'_params_allgood Failed allclose'
,
e
rval
=
False
...
...
@@ -247,7 +252,8 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1),
def
exec_conv
(
version
,
shapes
,
verbose
,
random
,
mode
,
print_
=
None
,
rtol
=
1e-5
,
ones
=
False
):
print_
=
None
,
rtol
=
1e-5
,
ones
=
False
,
theano_mode
=
theano_mode
,
cls
=
None
):
if
verbose
>
0
:
_params_allgood_header
()
nb_failed
=
0
...
...
@@ -273,7 +279,9 @@ def exec_conv(version, shapes, verbose, random, mode,
id
=
id
,
print_
=
print_
,
rtol
=
rtol
,
ones
=
ones
)
ones
=
ones
,
theano_mode
=
theano_mode
,
cls
=
cls
)
except
Exception
,
e
:
print
ver
,
id
,
(
ishape
,
kshape
,
subshape
,
istride
,
kstride
)
print
e
...
...
@@ -624,8 +632,23 @@ def test_valid():
if
ones
:
random
=
False
# exec_conv(version, shapes, verbose, random, 'valid',
# print_=print_, ones=ones, rtol=1.1e-5)
mode
=
theano_mode
.
including
(
"conv_gemm"
)
version
=
[
-
1
]
# Remove case not supported
# Add tests with strided inputs by still square images and filters.
shapes
+=
get_shapes2
(
scales_img
=
(
2
,
2
),
img_stride
=
(
2
,
2
))
shapes
+=
get_shapes2
(
scales_kern
=
(
2
,
2
),
kern_stride
=
(
2
,
2
))
# Keep only tests with square images and filters even with inputs strides
shapes
=
[
shp
for
shp
in
shapes
if
(
shp
[
0
][
2
]
/
shp
[
3
][
0
]
==
shp
[
0
][
3
]
/
shp
[
3
][
1
]
and
shp
[
1
][
2
]
/
shp
[
4
][
0
]
==
shp
[
1
][
3
]
/
shp
[
4
][
1
])]
exec_conv
(
version
,
shapes
,
verbose
,
random
,
'valid'
,
print_
=
print_
,
ones
=
ones
,
rtol
=
1.1e-5
)
print_
=
print_
,
ones
=
ones
,
rtol
=
1.1e-5
,
theano_mode
=
mode
,
cls
=
cuda
.
blas
.
GpuCorrMM
)
def
test_full
():
...
...
@@ -688,7 +711,16 @@ def test_full():
# version=[4]
random
=
True
exec_conv
(
version
,
shapes
,
verbose
,
random
,
'full'
)
# exec_conv(version, shapes, verbose, random, 'full')
# Test the GpuCorrMM version
mode
=
theano_mode
.
including
(
"conv_gemm"
)
shapes
=
[
shp
for
shp
in
shapes
if
shp
[
1
][
2
]
==
shp
[
1
][
3
]]
shapes
=
[
shp
for
shp
in
shapes
if
shp
[
0
][
2
]
==
shp
[
0
][
3
]]
shapes
=
shapes
[
0
:
10
]
exec_conv
(
version
,
shapes
,
verbose
,
random
,
'full'
,
theano_mode
=
mode
,
cls
=
cuda
.
blas
.
GpuCorrMM
)
def
test_subsample
():
...
...
@@ -792,31 +824,53 @@ class TestConv2DGPU(unittest.TestCase):
theano_mode
=
theano_mode_orig
def
_test_dummy
():
ishape
=
(
1
,
1
,
5
,
5
)
kshape
=
(
1
,
1
,
3
,
3
)
mode
=
'valid'
subsample
=
(
1
,
1
)
npy_img
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
ishape
),
dtype
=
'float32'
)
npy_kern
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
kshape
),
dtype
=
'float32'
)
img
=
cuda_ndarray
.
CudaNdarray
(
npy_img
)
kern
=
cuda_ndarray
.
CudaNdarray
(
npy_kern
)
#print >> sys.stdout, '_params_allgood trying ', ishape, kshape, mode
t2
=
None
rval
=
True
t0
=
time
.
time
()
cpuval
=
py_conv
(
npy_img
,
npy_kern
,
mode
,
subsample
)
t1
=
time
.
time
()
gpuval
=
cuda_ndarray
.
conv
(
img
,
kern
,
mode
,
subsample
)
t2
=
time
.
time
()
gpuval
=
numpy
.
asarray
(
gpuval
)
print
gpuval
print
cpuval
def
test_gemm
():
"""
input: (batch size, channels, rows, columns)
filters: (number of filters, channels, rows, columns)
"""
for
mode
in
[
'valid'
,
'full'
]:
print
'Testing mode: '
+
mode
for
bs
in
range
(
1
,
5
):
for
ch
in
range
(
1
,
4
):
for
nf
in
range
(
1
,
4
):
for
rImg
in
range
(
5
,
9
):
for
rFlt
in
range
(
2
,
4
):
ishape
=
(
bs
,
ch
,
rImg
,
rImg
)
kshape
=
(
nf
,
ch
,
rFlt
,
rFlt
)
print
"ishape: "
,
ishape
print
"kshape: "
,
kshape
subsample
=
(
1
,
1
)
npy_img
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
ishape
),
dtype
=
'float32'
)
npy_kern
=
theano
.
_asarray
(
numpy
.
random
.
rand
(
*
kshape
),
dtype
=
'float32'
)
i
=
cuda_tensor4
()
k
=
cuda_tensor4
()
t2
=
None
t0
=
time
.
time
()
cpuval
=
py_conv
(
npy_img
,
npy_kern
,
mode
,
subsample
)
t1
=
time
.
time
()
op
=
theano
.
sandbox
.
cuda
.
blas
.
GpuCorrMM
(
border_mode
=
mode
)(
i
,
k
)
f
=
theano
.
function
([
i
,
k
],
op
,
mode
=
theano_mode
)
for
k
in
range
(
npy_kern
.
shape
[
0
]):
for
s
in
range
(
npy_kern
.
shape
[
1
]):
npy_kern
[
k
,
s
,:,:]
=
numpy
.
rot90
(
npy_kern
[
k
,
s
,:,:],
2
)
gpuval
=
f
(
npy_img
,
npy_kern
)
t2
=
time
.
time
()
gpuval
=
numpy
.
asarray
(
gpuval
)
rval
=
numpy
.
allclose
(
cpuval
,
gpuval
,
rtol
=
1e-4
)
assert
(
rval
==
True
)
print
'Test Passed'
def
benchmark
():
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论