testgroup / pytensor / Commits / e76a29d9

Commit e76a29d9, authored Aug 12, 2014 by f0k
Parent: f6bf2943

    Adds caffe's implementation of the full convolution to CorrMM; cleaning up
    and documenting the code on the way

Showing 5 changed files, with 84 additions and 84 deletions.
theano/sandbox/cuda/blas.py                           +68  −33
theano/sandbox/cuda/caffe_common.hpp                   +0  −47
theano/sandbox/cuda/conv_gemm.cu                       +0   −0
theano/sandbox/cuda/opt.py                            +15   −3
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py    +1   −1
theano/sandbox/cuda/blas.py

@@ -501,29 +501,59 @@ gpu_ger_inplace = GpuGer(inplace=True)
 class GpuCorrMM(GpuOp):
-    """GPU correlation implementation using Matrix Multiply.
+    """GPU correlation/convolution implementation using Matrix Multiplication.
 
-    :note: It don't implement the grad. So you should use it by
-        enabling the Theano flag ``optimizer_including=conv_gemm`` and
-        use :func:`conv2d <theano.tensor.nnet.conv.conv2d>`.
+    :note: It doesn't implement the grad. So you shouldn't use it directly, but
+        use :func:`conv2d <theano.tensor.nnet.conv.conv2d>` and then enable the
+        Theano flag ``optimizer_including=conv_gemm`` to automatically replace
+        all convolution operations with `GpuCorrMM`.
     """
     def __init__(self, border_mode,
                  subsample=(1, 1),
-                 pad=0):
+                 pad=(0, 0)):
         """
         :param border_mode: "valid" or "full"
-        :param subsample: the subsample operation applied on each output image.
+        :param subsample: the subsample operation applied to each output image.
             Should be a tuple with 2 elements.
             (sv, sh) is equivalent to GpuCorrMM(...)(...)[:,:,::sv, ::sh]
-        :param pad: not yet supported
+            If border_mode="full", this is instead treated as an upsampling
+            operation applied to each input image.
+            Set to (1, 1) to disable downsampling/upsampling.
+        :param pad: the width of a border of implicit zeros to pad the input
+            image with. Should be a tuple with 2 elements giving the numbers of
+            rows and columns to pad on each side, or "auto" to set the padding
+            to (kernel_rows - 1, kernel_columns - 1) at runtime.
+            If border_mode="full", this is instead treated as the width of a
+            border to crop from the output image.
+            Set to (0, 0) to disable padding/cropping.
+        :note: The border_mode changes the meaning of several parameters.
+            If border_mode="valid", the Op does a valid correlation of a padded
+            input image and subsamples it. (To perform a convolution instead,
+            you will need to flip the kernels.)
+            If border_mode="full", the Op does a full convolution of an
+            upsampled input image and crops it. (This can be used as a backward
+            pass of the valid correlation done with border_mode="valid".)
+            Combined with pad="auto", you can use border_mode="valid" to
+            simulate a full correlation with subsampling, or border_mode="full"
+            to simulate a valid convolution with upsampling.
+        :note: Currently, the Op requires a very specific memory layout.
+            For border_mode="valid", inputs, filters and outputs must be
+            C-contiguous. For border_mode="full", the same applies, except that
+            the strides of the first two dimensions of the filters (output and
+            input channels) must be swapped compared to C-contiguity.
         """
         self.border_mode = border_mode
         self.subsample = subsample
+        #if (border_mode == "full") and (subsample != (1,1)):
+        #    raise NotImplementedError(
+        #        "GpuCorrMM doesn't support subsampling for border_mode='full'")
         self.pad = pad
-        if pad != 0:
-            raise NotImplementedError(
-                "GpuCorrMM don't implement the pad parameter")
+        #if (border_mode == "full") and (pad != (0,0)):
+        #    raise NotImplementedError(
+        #        "GpuCorrMM doesn't support padding for border_mode='full'")
 
     def __eq__(self, other):
         return type(self) == type(other) \
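The docstring above claims that pad="auto" (an implicit zero border of kernel_rows - 1 by kernel_columns - 1) turns a valid correlation into a full correlation. A minimal NumPy sketch of that equivalence for a single channel; the helper names are illustrative, not taken from the Theano code:

```python
import numpy as np

def correlate2d_valid(img, kern):
    # plain valid cross-correlation: slide the kernel over every
    # position where it fully overlaps the image
    H, W = img.shape
    kH, kW = kern.shape
    out = np.empty((H - kH + 1, W - kW + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = (img[i:i + kH, j:j + kW] * kern).sum()
    return out

def correlate2d_full(img, kern):
    # full correlation = valid correlation of an input zero-padded by
    # (kH - 1, kW - 1) on each side, i.e. the meaning of pad="auto" above
    kH, kW = kern.shape
    padded = np.pad(img, ((kH - 1, kH - 1), (kW - 1, kW - 1)),
                    mode='constant')
    return correlate2d_valid(padded, kern)
```

For a 5×5 image and a 3×3 kernel, the valid output is 3×3 and the full output is 7×7, and the central 3×3 block of the full output coincides with the valid output.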
@@ -540,7 +570,7 @@ class GpuCorrMM(GpuOp):
                 ^ hash(self.pad)
 
     def __str__(self):
-        return '%s{%s, %s, pad=%d}' % (
+        return '%s{%s, %s, pad=%r}' % (
             self.__class__.__name__,
             self.border_mode,
             str(self.subsample),
@@ -581,7 +611,7 @@ class GpuCorrMM(GpuOp):
     def c_code_cache_version(self):
         # raise this whenever modifying any of the support_code_files
-        return (0, 22)
+        return (0, 23)
 
     def c_support_code_apply(self, node, nodename):
         # REMEMBER TO RAISE c_code_cache_version when changing any of
@@ -596,13 +626,18 @@ class GpuCorrMM(GpuOp):
         out, = out_
         dx = self.subsample[0]
         dy = self.subsample[1]
+        sub = sub.copy()
+        pad = self.pad
+        if self.pad == "auto":
+            padH = padW = -1
+        else:
+            padH = self.pad[0]
+            padW = self.pad[1]
         if self.border_mode == "valid":
             bmode = 1
-        else:
-            assert self.border_mode == "full"
+        elif self.border_mode == "full":
             bmode = 0
-        sub = sub.copy()
+        else:
+            raise ValueError("mode must be one of 'full' or 'valid'")
         sub.update(locals())
         return """
@@ -612,33 +647,34 @@ class GpuCorrMM(GpuOp):
             //Optional args
             int dx = %(dx)s;
             int dy = %(dy)s;
-            int padH = 0;
-            int padW = 0;
+            int padH = %(padH)s;
+            int padW = %(padW)s;
             CudaNdarray * img = %(img)s;
             CudaNdarray * kern = %(kern)s;
             CudaNdarray * out2 = NULL;
-            //TODO: Send self.pad, stride, etc
+            //Auto-padding if requested
+            if (padH < 0) {
+                padH = CudaNdarray_HOST_DIMS(kern)[2] - 1;
+            }
+            if (padW < 0) {
+                padW = CudaNdarray_HOST_DIMS(kern)[3] - 1;
+            }
             int out_dim[4];
             out_dim[0] = CudaNdarray_HOST_DIMS(img)[0];
             out_dim[1] = CudaNdarray_HOST_DIMS(kern)[0];
-            int logical_rows, logical_cols;
-            if (mode == 1)
+            if (mode == 1) // valid correlation with padding and subsampling
             {
-                logical_rows = CudaNdarray_HOST_DIMS(img)[2] - CudaNdarray_HOST_DIMS(kern)[2] + 1;
-                logical_cols = CudaNdarray_HOST_DIMS(img)[3] - CudaNdarray_HOST_DIMS(kern)[3] + 1;
+                out_dim[2] = ceil_intdiv(CudaNdarray_HOST_DIMS(img)[2] + 2*padH - CudaNdarray_HOST_DIMS(kern)[2] + 1, dx);
+                out_dim[3] = ceil_intdiv(CudaNdarray_HOST_DIMS(img)[3] + 2*padW - CudaNdarray_HOST_DIMS(kern)[3] + 1, dy);
             }
-            else
+            else // full convolution with upsampling and cropping
             {
-                logical_rows = CudaNdarray_HOST_DIMS(img)[2] + CudaNdarray_HOST_DIMS(kern)[2] - 1;
-                logical_cols = CudaNdarray_HOST_DIMS(img)[3] + CudaNdarray_HOST_DIMS(kern)[3] - 1;
-                padH = CudaNdarray_HOST_DIMS(kern)[2] - 1;
-                padW = CudaNdarray_HOST_DIMS(kern)[3] - 1;
+                out_dim[2] = (CudaNdarray_HOST_DIMS(img)[2] - 1) * dx + CudaNdarray_HOST_DIMS(kern)[2] - 2*padH;
+                out_dim[3] = (CudaNdarray_HOST_DIMS(img)[3] - 1) * dy + CudaNdarray_HOST_DIMS(kern)[3] - 2*padW;
             }
-            out_dim[2] = ceil_intdiv(logical_rows, dx);
-            out_dim[3] = ceil_intdiv(logical_cols, dy);
             if ( !(%(out)s
                    && %(out)s->nd==4
@@ -650,10 +686,9 @@ class GpuCorrMM(GpuOp):
         {
             Py_XDECREF(%(out)s);
             %(out)s = (CudaNdarray*)CudaNdarray_NewDims(4,out_dim);
         }
-        out2 = corrMM(%(img)s, %(kern)s, %(out)s, dx, dy, padH, padW);
+        out2 = corrMM(%(img)s, %(kern)s, %(out)s, mode, dx, dy, padH, padW);
         if (out2==NULL){
            %(fail)s
         }
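The output-shape arithmetic in the two branches above can be restated in Python. This is an illustrative sketch, not code from the commit; `ceil_intdiv` mirrors the CUDA-side ceiling-division helper the snippet calls, and the function name is hypothetical:

```python
def ceil_intdiv(a, b):
    # ceiling integer division, mirroring the CUDA helper used above
    return -(-a // b)

def corrmm_out_shape(mode, img_shape, kern_shape, dx=1, dy=1, padH=0, padW=0):
    """Output-shape rules from the C snippet above (illustrative only).

    img_shape:  (batch, channels, rows, cols)
    kern_shape: (nfilters, channels, rows, cols)
    """
    b, _, iH, iW = img_shape
    nf, _, kH, kW = kern_shape
    if padH < 0:  # pad="auto": use kernel size minus one
        padH = kH - 1
    if padW < 0:
        padW = kW - 1
    if mode == "valid":  # valid correlation with padding and subsampling
        oH = ceil_intdiv(iH + 2 * padH - kH + 1, dx)
        oW = ceil_intdiv(iW + 2 * padW - kW + 1, dy)
    else:  # full convolution with upsampling and cropping
        oH = (iH - 1) * dx + kH - 2 * padH
        oW = (iW - 1) * dy + kW - 2 * padW
    return (b, nf, oH, oW)
```

Note how the two branches are inverses of each other: the "full" formula with dx as an upsampling factor undoes the "valid" formula with dx as a subsampling factor, which is why the full mode can serve as the backward pass of the valid one.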
theano/sandbox/cuda/caffe_common.hpp (deleted, 100644 → 0)
/*
Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_

#include <cublas_v2.h>
#include <cuda.h>
#include <driver_types.h>  // cuda driver types

// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
    const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
    const int CAFFE_CUDA_NUM_THREADS = 512;
#endif

// CUDA: number of blocks for threads.
inline int CAFFE_GET_BLOCKS(const int N) {
  return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}

#endif  // CAFFE_COMMON_HPP_
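CAFFE_GET_BLOCKS in the deleted header is ordinary ceiling division: enough blocks of CAFFE_CUDA_NUM_THREADS threads each to cover N elements. A Python sketch of the same computation (names kept from the header for clarity):

```python
CAFFE_CUDA_NUM_THREADS = 1024  # value for sm_2x and above; 512 otherwise

def caffe_get_blocks(n):
    # ceiling division without floating point:
    # (n + threads - 1) // threads blocks suffice to cover n elements
    return (n + CAFFE_CUDA_NUM_THREADS - 1) // CAFFE_CUDA_NUM_THREADS
```

Adding `threads - 1` before the integer division rounds up exactly when `n` is not a multiple of the block size, so the last block is partially filled rather than missing.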
theano/sandbox/cuda/conv_gemm.cu

(Diff collapsed in the original view; contents not shown.)
theano/sandbox/cuda/opt.py

@@ -1351,10 +1351,22 @@ def local_conv_gemm(node):
     if (isinstance(node.op, GpuConv) and
         node.op.border_mode in ['full', 'valid']):
         img, kern = node.inputs
+        border_mode = node.op.border_mode
+        subsample = node.op.subsample
+        pad = (0, 0)
+        if (border_mode == 'full') and ((subsample != (1, 1)) or
+                                        (pad != (0, 0))):
+            # need to simulate this via a padded valid convolution
+            pad = 'auto'
+            border_mode = 'valid'
+        if (border_mode == 'valid'):
+            # need to flip the kernel for valid convolution
+            kern = gpu_contiguous(kern[:, :, ::-1, ::-1])
+        elif (border_mode == 'full'):
+            # need to bring kernel into correct memory layout for full convolution
+            kern = gpu_contiguous(kern.dimshuffle(1, 0, 2, 3)).dimshuffle(1, 0, 2, 3)
+        # need C-contiguous inputs
         img = gpu_contiguous(img)
-        kern = kern[:, :, ::-1, ::-1]
-        kern = gpu_contiguous(kern)
-        return [GpuCorrMM(node.op.border_mode,
-                          node.op.subsample)(img, kern)]
+        return [GpuCorrMM(border_mode, subsample, pad)(img, kern)]
 
 gpu_optimizer.register("conv_gemm", local_conv_gemm)
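The `gpu_contiguous(kern.dimshuffle(1, 0, 2, 3)).dimshuffle(1, 0, 2, 3)` pattern in the 'full' branch produces an array whose first two strides are swapped relative to C-contiguity, which is exactly the filter layout the new GpuCorrMM docstring requires for border_mode="full". A NumPy analogue, with `transpose` standing in for `dimshuffle` and `ascontiguousarray` for `gpu_contiguous`:

```python
import numpy as np

# filters in (output channels, input channels, rows, cols) layout
kern = np.zeros((4, 3, 5, 5), dtype=np.float32)

# make the (input, output, rows, cols) view contiguous in memory,
# then transpose back to the original logical shape
swapped = np.ascontiguousarray(kern.transpose(1, 0, 2, 3)).transpose(1, 0, 2, 3)

# same logical shape, but no longer C-contiguous: the strides of the
# first two axes are exchanged compared to a C-contiguous array
```

This leaves the data physically laid out with input channels as the slowest-varying of the first two axes while the graph still sees the usual (output, input, rows, cols) shape.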
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py

@@ -848,7 +848,7 @@ def test_gemm_directly():
     input: (batch size, channels, rows, columns)
     filters: (number of filters, channels, rows, columns)
     """
-    for mode in ['full', 'valid']:
+    for mode in ['valid']:  # 'full' currently disabled; doesn't allow subsampling
         print 'Testing mode: ' + mode
         for bs in range(1, 5):
             for ch in range(1, 4):
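The test above iterates over batch sizes and channel counts of 4D arrays in (batch, channels, rows, columns) layout. For reference, a minimal NumPy implementation of the valid-mode convolution being checked; this is a hypothetical helper, not code from the test file, and the kernel flip is what distinguishes convolution from the correlation GpuCorrMM computes:

```python
import numpy as np

def conv_valid_reference(inputs, filters):
    # inputs:  (batch size, channels, rows, columns)
    # filters: (number of filters, channels, rows, columns)
    b, c, iH, iW = inputs.shape
    nf, cf, kH, kW = filters.shape
    assert c == cf, "channel counts must match"
    # convolution = correlation with spatially flipped kernels
    flipped = filters[:, :, ::-1, ::-1]
    out = np.zeros((b, nf, iH - kH + 1, iW - kW + 1), dtype=inputs.dtype)
    for n in range(b):
        for f in range(nf):
            for i in range(out.shape[2]):
                for j in range(out.shape[3]):
                    out[n, f, i, j] = (inputs[n, :, i:i + kH, j:j + kW]
                                       * flipped[f]).sum()
    return out
```

Each output value sums over all input channels and the full kernel window, so for all-ones inputs and 2×2 all-ones filters every output entry equals the window size, 4.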