Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4c55bc4b
提交
4c55bc4b
authored
8月 04, 2014
作者:
Arjun Jain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
- added some documentaiton
- changed conv to corr as suggested by Fred
上级
1e3de2ce
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
28 行增加
和
15 行删除
+28
-15
conv.txt
doc/library/tensor/nnet/conv.txt
+13
-0
blas.py
theano/sandbox/cuda/blas.py
+4
-4
conv_gemm.cu
theano/sandbox/cuda/conv_gemm.cu
+5
-5
opt.py
theano/sandbox/cuda/opt.py
+2
-2
test_conv_cuda_ndarray.py
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
+4
-4
没有找到文件。
doc/library/tensor/nnet/conv.txt
浏览文件 @
4c55bc4b
...
@@ -51,8 +51,21 @@ TODO: Give examples for how to use these things! They are pretty complicated.
...
@@ -51,8 +51,21 @@ TODO: Give examples for how to use these things! They are pretty complicated.
implementation.
implementation.
Also, there is restrictions on which shape are supported.
Also, there is restrictions on which shape are supported.
- :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`
This is a GPU-only version of a correlation that computes correlations
as `caffe <https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu>`.
For each element in a batch, it first creates a
Toeplitz<http://en.wikipedia.org/wiki/Toeplitz_matrix> matrix in a cuda kernel.
Then, it performs a `gemm` call to multiply this Toeplitz matrix and to the kernel.
It need extra memory for this, which is the size of the Toeplitz matrix. Precisely,
the dimensions of this Toeplitz matrix is equal to
(no of channels * filter width * filter height, output width * output height).
You can enable it for call to conv2d 2d by setting 'THEANO_FLAGS=optimizer_including=conv_gemm'
in your environment. This is not enabled by default because it
uses some extra memory. It don't support strides for now and requires square kernels.
.. autofunction:: theano.tensor.nnet.conv.conv2d
.. autofunction:: theano.tensor.nnet.conv.conv2d
.. autofunction:: theano.tensor.nnet.Conv3D.conv3D
.. autofunction:: theano.tensor.nnet.Conv3D.conv3D
.. autofunction:: theano.tensor.nnet.conv3d2d.conv3d
.. autofunction:: theano.tensor.nnet.conv3d2d.conv3d
.. autofunction:: theano.sandbox.cuda.fftconv.conv2d_fft
.. autofunction:: theano.sandbox.cuda.fftconv.conv2d_fft
.. autofunction:: theano.sandbox.cuda.blas.GpuCorrMM
theano/sandbox/cuda/blas.py
浏览文件 @
4c55bc4b
...
@@ -498,7 +498,7 @@ gpu_ger_no_inplace = GpuGer(inplace=False)
...
@@ -498,7 +498,7 @@ gpu_ger_no_inplace = GpuGer(inplace=False)
gpu_ger_inplace
=
GpuGer
(
inplace
=
True
)
gpu_ger_inplace
=
GpuGer
(
inplace
=
True
)
class
GpuCo
nv
MM
(
GpuOp
):
class
GpuCo
rr
MM
(
GpuOp
):
"""
"""
Author: Arjun Jain
Author: Arjun Jain
Implement the caffe convolution
Implement the caffe convolution
...
@@ -516,10 +516,10 @@ class GpuConvMM(GpuOp):
...
@@ -516,10 +516,10 @@ class GpuConvMM(GpuOp):
self
.
pad
=
pad
self
.
pad
=
pad
if
pad
!=
0
:
if
pad
!=
0
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"GpuCo
nv
MM don't implement the pad parameter"
)
"GpuCo
rr
MM don't implement the pad parameter"
)
if
subsample
!=
(
1
,
1
):
if
subsample
!=
(
1
,
1
):
raise
NotImplementedError
(
raise
NotImplementedError
(
"GpuCo
nv
MM we don't implement the subsample parameter"
)
"GpuCo
rr
MM we don't implement the subsample parameter"
)
def
__eq__
(
self
,
other
):
def
__eq__
(
self
,
other
):
return
type
(
self
)
==
type
(
other
)
\
return
type
(
self
)
==
type
(
other
)
\
...
@@ -658,7 +658,7 @@ class GpuConvMM(GpuOp):
...
@@ -658,7 +658,7 @@ class GpuConvMM(GpuOp):
}
}
out2 =
valid
MM(
%(img)
s,
%(kern)
s,
%(out)
s, pad);
out2 =
corr
MM(
%(img)
s,
%(kern)
s,
%(out)
s, pad);
if (out2==NULL){
if (out2==NULL){
%(fail)
s
%(fail)
s
}
}
...
...
theano/sandbox/cuda/conv_gemm.cu
浏览文件 @
4c55bc4b
...
@@ -105,7 +105,7 @@ CudaNdarray* corrMM(const CudaNdarray *input,
...
@@ -105,7 +105,7 @@ CudaNdarray* corrMM(const CudaNdarray *input,
long batchSize = CudaNdarray_HOST_DIMS(input)[0];
long batchSize = CudaNdarray_HOST_DIMS(input)[0];
if (CudaNdarray_HOST_DIMS(input)[2] != CudaNdarray_HOST_DIMS(input)[3]){
if (CudaNdarray_HOST_DIMS(input)[2] != CudaNdarray_HOST_DIMS(input)[3]){
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"GpuCo
nv
MM support only square images. Got %dx%d images\n",
"GpuCo
rr
MM support only square images. Got %dx%d images\n",
CudaNdarray_HOST_DIMS(input)[2],
CudaNdarray_HOST_DIMS(input)[2],
CudaNdarray_HOST_DIMS(input)[3]
CudaNdarray_HOST_DIMS(input)[3]
);
);
...
@@ -113,14 +113,14 @@ CudaNdarray* corrMM(const CudaNdarray *input,
...
@@ -113,14 +113,14 @@ CudaNdarray* corrMM(const CudaNdarray *input,
}
}
if (kW != kH){
if (kW != kH){
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"GpuCo
nv
MM support only square kernel. Got %dx%d kernel\n",
"GpuCo
rr
MM support only square kernel. Got %dx%d kernel\n",
kW, kH
kW, kH
);
);
return NULL;
return NULL;
}
}
if (CudaNdarray_HOST_DIMS(input)[1] != CudaNdarray_HOST_DIMS(weight)[1]){
if (CudaNdarray_HOST_DIMS(input)[1] != CudaNdarray_HOST_DIMS(weight)[1]){
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"GpuCo
nv
MM images and kernel must have the same stack size\n"
"GpuCo
rr
MM images and kernel must have the same stack size\n"
);
);
return NULL;
return NULL;
}
}
...
@@ -136,7 +136,7 @@ CudaNdarray* corrMM(const CudaNdarray *input,
...
@@ -136,7 +136,7 @@ CudaNdarray* corrMM(const CudaNdarray *input,
outputHeight != CudaNdarray_HOST_DIMS(output)[2] ||
outputHeight != CudaNdarray_HOST_DIMS(output)[2] ||
outputWidth != CudaNdarray_HOST_DIMS(output)[3]){
outputWidth != CudaNdarray_HOST_DIMS(output)[3]){
PyErr_SetString(PyExc_ValueError,
PyErr_SetString(PyExc_ValueError,
"GpuCo
nv
MM outputs parameter don't have the good shape\n"
"GpuCo
rr
MM outputs parameter don't have the good shape\n"
);
);
return NULL;
return NULL;
}
}
...
@@ -182,7 +182,7 @@ CudaNdarray* corrMM(const CudaNdarray *input,
...
@@ -182,7 +182,7 @@ CudaNdarray* corrMM(const CudaNdarray *input,
);
);
if (status != CUBLAS_STATUS_SUCCESS) {
if (status != CUBLAS_STATUS_SUCCESS) {
std::cerr << "!!!! CUBLAS error in GpuCo
nv
MM\n";
std::cerr << "!!!! CUBLAS error in GpuCo
rr
MM\n";
}
}
}
}
...
...
theano/sandbox/cuda/opt.py
浏览文件 @
4c55bc4b
...
@@ -25,7 +25,7 @@ from theano.sandbox.cuda.basic_ops import (
...
@@ -25,7 +25,7 @@ from theano.sandbox.cuda.basic_ops import (
GpuIncSubtensor
,
gpu_alloc
,
GpuAlloc
,
gpu_shape
)
GpuIncSubtensor
,
gpu_alloc
,
GpuAlloc
,
gpu_shape
)
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.blas
import
(
gpu_dot22
,
gpu_dot22scalar
,
from
theano.sandbox.cuda.blas
import
(
gpu_dot22
,
gpu_dot22scalar
,
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
GpuConv
,
GpuCo
nv
MM
)
gpu_gemm_inplace
,
gpu_gemm_no_inplace
,
GpuConv
,
GpuCo
rr
MM
)
from
theano.sandbox.cuda.blas
import
gpu_gemv_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_no_inplace
from
theano.sandbox.cuda.blas
import
gpu_gemv_no_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_inplace
from
theano.sandbox.cuda.blas
import
gpu_ger_inplace
...
@@ -1292,7 +1292,7 @@ def local_conv_gemm(node):
...
@@ -1292,7 +1292,7 @@ def local_conv_gemm(node):
img
=
gpu_contiguous
(
img
)
img
=
gpu_contiguous
(
img
)
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
kern
=
kern
[:,
:,
::
-
1
,
::
-
1
]
kern
=
gpu_contiguous
(
kern
)
kern
=
gpu_contiguous
(
kern
)
return
[
GpuCo
nv
MM
(
node
.
op
.
border_mode
)(
img
,
kern
)]
return
[
GpuCo
rr
MM
(
node
.
op
.
border_mode
)(
img
,
kern
)]
gpu_optimizer
.
register
(
"conv_gemm"
,
local_conv_gemm
)
gpu_optimizer
.
register
(
"conv_gemm"
,
local_conv_gemm
)
...
...
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
浏览文件 @
4c55bc4b
...
@@ -648,7 +648,7 @@ def test_valid():
...
@@ -648,7 +648,7 @@ def test_valid():
shp
[
1
][
2
]
/
shp
[
4
][
0
]
==
shp
[
1
][
3
]
/
shp
[
4
][
1
])]
shp
[
1
][
2
]
/
shp
[
4
][
0
]
==
shp
[
1
][
3
]
/
shp
[
4
][
1
])]
exec_conv
(
version
,
shapes
,
verbose
,
random
,
'valid'
,
exec_conv
(
version
,
shapes
,
verbose
,
random
,
'valid'
,
print_
=
print_
,
ones
=
ones
,
rtol
=
1.1e-5
,
print_
=
print_
,
ones
=
ones
,
rtol
=
1.1e-5
,
theano_mode
=
mode
,
cls
=
cuda
.
blas
.
GpuCo
nv
MM
)
theano_mode
=
mode
,
cls
=
cuda
.
blas
.
GpuCo
rr
MM
)
def
test_full
():
def
test_full
():
...
@@ -713,14 +713,14 @@ def test_full():
...
@@ -713,14 +713,14 @@ def test_full():
# exec_conv(version, shapes, verbose, random, 'full')
# exec_conv(version, shapes, verbose, random, 'full')
# Test the GpuCo
nv
MM version
# Test the GpuCo
rr
MM version
mode
=
theano_mode
.
including
(
"conv_gemm"
)
mode
=
theano_mode
.
including
(
"conv_gemm"
)
shapes
=
[
shp
for
shp
in
shapes
if
shp
[
1
][
2
]
==
shp
[
1
][
3
]]
shapes
=
[
shp
for
shp
in
shapes
if
shp
[
1
][
2
]
==
shp
[
1
][
3
]]
shapes
=
[
shp
for
shp
in
shapes
if
shp
[
0
][
2
]
==
shp
[
0
][
3
]]
shapes
=
[
shp
for
shp
in
shapes
if
shp
[
0
][
2
]
==
shp
[
0
][
3
]]
shapes
=
shapes
[
0
:
10
]
shapes
=
shapes
[
0
:
10
]
exec_conv
(
version
,
shapes
,
verbose
,
random
,
'full'
,
exec_conv
(
version
,
shapes
,
verbose
,
random
,
'full'
,
theano_mode
=
mode
,
cls
=
cuda
.
blas
.
GpuCo
nv
MM
)
theano_mode
=
mode
,
cls
=
cuda
.
blas
.
GpuCo
rr
MM
)
def
test_subsample
():
def
test_subsample
():
...
@@ -856,7 +856,7 @@ def test_gemm():
...
@@ -856,7 +856,7 @@ def test_gemm():
t1
=
time
.
time
()
t1
=
time
.
time
()
op
=
theano
.
sandbox
.
cuda
.
blas
.
GpuCo
nv
MM
(
border_mode
=
mode
)(
i
,
k
)
op
=
theano
.
sandbox
.
cuda
.
blas
.
GpuCo
rr
MM
(
border_mode
=
mode
)(
i
,
k
)
f
=
theano
.
function
([
i
,
k
],
op
,
mode
=
theano_mode
)
f
=
theano
.
function
([
i
,
k
],
op
,
mode
=
theano_mode
)
for
k
in
range
(
npy_kern
.
shape
[
0
]):
for
k
in
range
(
npy_kern
.
shape
[
0
]):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论