Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c072d669
提交
c072d669
authored
12月 19, 2016
作者:
Frédéric Bastien
提交者:
GitHub
12月 19, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5267 from gvtulder/f-abstractconv-differences
Minor inconsistency in AbstractConv_gradInput implementations
上级
1a42bf9b
7f1c3677
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
24 个修改的文件
包含
921 行增加
和
210 行删除
+921
-210
blas.py
theano/gpuarray/blas.py
+0
-0
corr3d_gemm.c
theano/gpuarray/corr3d_gemm.c
+53
-3
corr_gemm.c
theano/gpuarray/corr_gemm.c
+42
-2
dnn.py
theano/gpuarray/dnn.py
+30
-23
dnn_fwd.c
theano/gpuarray/dnn_fwd.c
+14
-5
dnn_gi.c
theano/gpuarray/dnn_gi.c
+56
-5
dnn_gw.c
theano/gpuarray/dnn_gw.c
+56
-5
test_abstractconv.py
theano/gpuarray/tests/test_abstractconv.py
+95
-0
test_dnn.py
theano/gpuarray/tests/test_dnn.py
+19
-24
blas.py
theano/sandbox/cuda/blas.py
+0
-0
corr3d_gemm.cu
theano/sandbox/cuda/corr3d_gemm.cu
+50
-3
corr_gemm.cu
theano/sandbox/cuda/corr_gemm.cu
+48
-2
dnn.py
theano/sandbox/cuda/dnn.py
+23
-14
dnn_fwd.c
theano/sandbox/cuda/dnn_fwd.c
+18
-7
dnn_gi.c
theano/sandbox/cuda/dnn_gi.c
+55
-5
dnn_gw.c
theano/sandbox/cuda/dnn_gw.c
+55
-5
test_abstractconv.py
theano/sandbox/cuda/tests/test_abstractconv.py
+95
-0
test_dnn.py
theano/sandbox/cuda/tests/test_dnn.py
+86
-78
abstract_conv.py
theano/tensor/nnet/abstract_conv.py
+0
-0
corr.py
theano/tensor/nnet/corr.py
+72
-22
corr3d.py
theano/tensor/nnet/corr3d.py
+0
-0
corr3d_gemm.c
theano/tensor/nnet/corr3d_gemm.c
+28
-4
corr_gemm.c
theano/tensor/nnet/corr_gemm.c
+26
-3
test_abstract_conv.py
theano/tensor/nnet/tests/test_abstract_conv.py
+0
-0
没有找到文件。
theano/gpuarray/blas.py
浏览文件 @
c072d669
差异被折叠。
点击展开。
theano/gpuarray/corr3d_gemm.c
浏览文件 @
c072d669
...
@@ -425,9 +425,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
...
@@ -425,9 +425,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
const
size_t
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
const
size_t
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
const
size_t
dil_kD
=
(
kD
-
1
)
*
dilD
+
1
;
const
size_t
dil_kD
=
(
kD
-
1
)
*
dilD
+
1
;
// top: (batchSize, nFilters, topHeight, topWidth, topDepth)
// top: (batchSize, nFilters, topHeight, topWidth, topDepth)
const
size_t
topHeight
=
(
bottomHeight
+
2
*
padH
-
dil_kH
)
/
dH
+
1
;
const
size_t
topHeightNoDH
=
(
bottomHeight
+
2
*
padH
-
dil_kH
);
const
size_t
topWidth
=
(
bottomWidth
+
2
*
padW
-
dil_kW
)
/
dW
+
1
;
const
size_t
topWidthNoDW
=
(
bottomWidth
+
2
*
padW
-
dil_kW
);
const
size_t
topDepth
=
(
bottomDepth
+
2
*
padD
-
dil_kD
)
/
dD
+
1
;
const
size_t
topDepthNoDD
=
(
bottomDepth
+
2
*
padD
-
dil_kD
);
// the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only
#define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) % y) == 0 ? 0 : 1)) : (x / y))
const
size_t
topHeight
=
_CONV_FLOORDIV_X
(
topHeightNoDH
,
dH
)
+
1
;
const
size_t
topWidth
=
_CONV_FLOORDIV_X
(
topWidthNoDW
,
dW
)
+
1
;
const
size_t
topDepth
=
_CONV_FLOORDIV_X
(
topDepthNoDD
,
dD
)
+
1
;
#undef _CONV_FLOORDIV
if
(
batchSize
!=
PyGpuArray_DIMS
(
top
)[
0
]
||
if
(
batchSize
!=
PyGpuArray_DIMS
(
top
)[
0
]
||
nFilters
!=
PyGpuArray_DIMS
(
top
)[
1
]
||
nFilters
!=
PyGpuArray_DIMS
(
top
)[
1
]
||
topHeight
!=
PyGpuArray_DIMS
(
top
)[
2
]
||
topHeight
!=
PyGpuArray_DIMS
(
top
)[
2
]
||
...
@@ -479,6 +487,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
...
@@ -479,6 +487,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
PyGpuArrayObject
*
output
;
PyGpuArrayObject
*
output
;
if
(
direction
==
0
)
{
// forward pass
if
(
direction
==
0
)
{
// forward pass
output
=
top
;
output
=
top
;
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
err
=
GpuArray_memset
(
&
output
->
ga
,
0
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuCorr3dMM could not fill the output with zeros: %d"
,
err
);
Py_DECREF
(
col
);
return
NULL
;
}
Py_DECREF
(
col
);
return
output
;
}
// valid correlation: im3d2col, then gemm
// valid correlation: im3d2col, then gemm
// Iterate over batch
// Iterate over batch
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
...
@@ -530,6 +549,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
...
@@ -530,6 +549,17 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
}
}
else
if
(
direction
==
1
)
{
// backprop wrt. weights
else
if
(
direction
==
1
)
{
// backprop wrt. weights
output
=
weight
;
output
=
weight
;
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
err
=
GpuArray_memset
(
&
output
->
ga
,
0
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuCorr3dMM grad wrt. weights could not fill the output with zeros: %d"
,
err
);
Py_DECREF
(
col
);
return
NULL
;
}
Py_DECREF
(
col
);
return
output
;
}
// valid convolution: im3col, then gemm
// valid convolution: im3col, then gemm
// Iterate over batch
// Iterate over batch
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
...
@@ -581,9 +611,29 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
...
@@ -581,9 +611,29 @@ PyGpuArrayObject* corr3dMM(PyGpuArrayObject *const bottom,
return
NULL
;
return
NULL
;
}
}
}
}
if
(
batchSize
==
0
)
{
err
=
GpuArray_memset
(
&
weight
->
ga
,
0
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuCorr3dMM grad weights could not fill the output with zeros: %d"
,
err
);
Py_DECREF
(
col
);
return
NULL
;
}
}
}
}
else
if
(
direction
==
2
)
{
// backprop wrt. inputs
else
if
(
direction
==
2
)
{
// backprop wrt. inputs
output
=
bottom
;
output
=
bottom
;
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
err
=
GpuArray_memset
(
&
output
->
ga
,
0
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuCorr3dMM grad wrt. inputs could not fill the output with zeros: %d"
,
err
);
Py_DECREF
(
col
);
return
NULL
;
}
Py_DECREF
(
col
);
return
output
;
}
// full convolution: gemm, then col2im3d
// full convolution: gemm, then col2im3d
// Iterate over batch
// Iterate over batch
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
...
...
theano/gpuarray/corr_gemm.c
浏览文件 @
c072d669
...
@@ -360,8 +360,15 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
...
@@ -360,8 +360,15 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
const
size_t
dil_kH
=
(
kH
-
1
)
*
dilH
+
1
;
const
size_t
dil_kH
=
(
kH
-
1
)
*
dilH
+
1
;
const
size_t
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
const
size_t
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
// top: (batchSize, nFilters, topHeight, topWidth)
// top: (batchSize, nFilters, topHeight, topWidth)
const
size_t
topHeight
=
(
bottomHeight
+
2
*
padH
-
dil_kH
)
/
dH
+
1
;
const
size_t
topHeightNoDH
=
(
bottomHeight
+
2
*
padH
-
dil_kH
);
const
size_t
topWidth
=
(
bottomWidth
+
2
*
padW
-
dil_kW
)
/
dW
+
1
;
const
size_t
topWidthNoDW
=
(
bottomWidth
+
2
*
padW
-
dil_kW
);
// the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only
#define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) % y) == 0 ? 0 : 1)) : (x / y))
const
size_t
topHeight
=
_CONV_FLOORDIV_X
(
topHeightNoDH
,
dH
)
+
1
;
const
size_t
topWidth
=
_CONV_FLOORDIV_X
(
topWidthNoDW
,
dW
)
+
1
;
#undef _CONV_FLOORDIV
if
(
batchSize
!=
PyGpuArray_DIMS
(
top
)[
0
]
||
if
(
batchSize
!=
PyGpuArray_DIMS
(
top
)[
0
]
||
nFilters
!=
PyGpuArray_DIMS
(
top
)[
1
]
||
nFilters
!=
PyGpuArray_DIMS
(
top
)[
1
]
||
topHeight
!=
PyGpuArray_DIMS
(
top
)[
2
]
||
topHeight
!=
PyGpuArray_DIMS
(
top
)[
2
]
||
...
@@ -411,6 +418,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
...
@@ -411,6 +418,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
PyGpuArrayObject
*
output
;
PyGpuArrayObject
*
output
;
if
(
direction
==
0
)
{
// forward pass
if
(
direction
==
0
)
{
// forward pass
output
=
top
;
output
=
top
;
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
err
=
GpuArray_memset
(
&
output
->
ga
,
0
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuCorrMM could not fill the output with zeros: %d"
,
err
);
Py_DECREF
(
col
);
return
NULL
;
}
Py_DECREF
(
col
);
return
output
;
}
// valid correlation: im2col, then gemm
// valid correlation: im2col, then gemm
// Iterate over batch
// Iterate over batch
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
...
@@ -462,6 +480,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
...
@@ -462,6 +480,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
}
}
else
if
(
direction
==
1
)
{
// backprop wrt. weights
else
if
(
direction
==
1
)
{
// backprop wrt. weights
output
=
weight
;
output
=
weight
;
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
err
=
GpuArray_memset
(
&
output
->
ga
,
0
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuCorrMM grad wrt. weights could not fill the output with zeros: %d"
,
err
);
Py_DECREF
(
col
);
return
NULL
;
}
Py_DECREF
(
col
);
return
output
;
}
// valid convolution: im2col, then gemm
// valid convolution: im2col, then gemm
// Iterate over batch
// Iterate over batch
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
...
@@ -516,6 +545,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
...
@@ -516,6 +545,17 @@ PyGpuArrayObject* corrMM(PyGpuArrayObject *const bottom,
}
}
else
if
(
direction
==
2
)
{
// backprop wrt. inputs
else
if
(
direction
==
2
)
{
// backprop wrt. inputs
output
=
bottom
;
output
=
bottom
;
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
err
=
GpuArray_memset
(
&
output
->
ga
,
0
);
if
(
err
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuCorrMM grad wrt. inputs could not fill the output with zeros: %d"
,
err
);
Py_DECREF
(
col
);
return
NULL
;
}
Py_DECREF
(
col
);
return
output
;
}
// full convolution: gemm, then col2im
// full convolution: gemm, then col2im
// Iterate over batch
// Iterate over batch
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
for
(
size_t
n
=
0
;
n
<
batchSize
;
n
++
)
{
...
...
theano/gpuarray/dnn.py
浏览文件 @
c072d669
...
@@ -24,7 +24,8 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
...
@@ -24,7 +24,8 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
AbstractConv3d
,
AbstractConv3d
,
AbstractConv3d_gradWeights
,
AbstractConv3d_gradWeights
,
AbstractConv3d_gradInputs
,
AbstractConv3d_gradInputs
,
get_conv_output_shape
)
get_conv_output_shape
,
assert_conv_shape
)
from
theano.tensor.signal.pool
import
(
from
theano.tensor.signal.pool
import
(
Pool
,
MaxPoolGrad
,
AveragePoolGrad
)
Pool
,
MaxPoolGrad
,
AveragePoolGrad
)
from
.
import
pygpu
from
.
import
pygpu
...
@@ -979,11 +980,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
...
@@ -979,11 +980,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
]
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
]
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
shape2
=
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
out_shp
=
(
shape_i
(
kerns
,
1
,
fgraph
),
shape3
=
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
shape_i
(
img
,
1
,
fgraph
),
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
,
shape_i
(
kerns
,
1
,
fgraph
),
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
)
shape_i
(
img
,
1
,
fgraph
),
shape2
,
shape3
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
conv_mode
=
'cross'
,
precision
=
precision
)(
out
.
shape
)
conv_mode
=
'cross'
,
precision
=
precision
)(
out
.
shape
)
conv
=
gpu_dnn_conv_gradW
()(
img
,
kerns
,
out
,
desc
)
conv
=
gpu_dnn_conv_gradW
()(
img
,
kerns
,
out
,
desc
)
...
@@ -997,11 +999,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
...
@@ -997,11 +999,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img
=
gpu_contiguous
(
img
)
# cudnn v2 rc3 need contiguous data
img
=
gpu_contiguous
(
img
)
# cudnn v2 rc3 need contiguous data
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
conv_mode
=
'cross'
if
conv_mode
==
'conv'
else
'conv'
conv_mode
=
'cross'
if
conv_mode
==
'conv'
else
'conv'
shape2
=
shape_i
(
img
,
2
,
fgraph
)
+
shape_i
(
kerns
,
2
,
fgraph
)
-
1
out_shp
=
(
shape_i
(
img
,
0
,
fgraph
),
shape3
=
shape_i
(
img
,
3
,
fgraph
)
+
shape_i
(
kerns
,
3
,
fgraph
)
-
1
shape_i
(
kerns
,
1
,
fgraph
),
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
shape_i
(
img
,
0
,
fgraph
),
shape_i
(
img
,
2
,
fgraph
)
+
shape_i
(
kerns
,
2
,
fgraph
)
-
1
,
shape_i
(
kerns
,
1
,
fgraph
),
shape_i
(
img
,
3
,
fgraph
)
+
shape_i
(
kerns
,
3
,
fgraph
)
-
1
)
shape2
,
shape3
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
conv_mode
=
conv_mode
,
precision
=
precision
)(
kerns
.
shape
)
conv_mode
=
conv_mode
,
precision
=
precision
)(
kerns
.
shape
)
return
gpu_dnn_conv_gradI
()(
kerns
,
img
,
out
,
desc
)
return
gpu_dnn_conv_gradI
()(
kerns
,
img
,
out
,
desc
)
...
@@ -1021,6 +1024,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
...
@@ -1021,6 +1024,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
out_shp
=
get_conv_output_shape
(
ishape
,
kshape
,
out_shp
=
get_conv_output_shape
(
ishape
,
kshape
,
desc_op
.
border_mode
,
desc_op
.
border_mode
,
desc_op
.
subsample
)
desc_op
.
subsample
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
return
gpu_dnn_conv
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
return
gpu_dnn_conv
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
...
@@ -1094,12 +1098,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
...
@@ -1094,12 +1098,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
]
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
]
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
,
4
))
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
,
4
))
shape2
=
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
out_shp
=
(
shape_i
(
kerns
,
1
,
fgraph
),
shape3
=
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
shape_i
(
img
,
1
,
fgraph
),
shape4
=
shape_i
(
img
,
4
,
fgraph
)
-
shape_i
(
kerns
,
4
,
fgraph
)
+
1
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
,
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
,
shape_i
(
kerns
,
1
,
fgraph
),
shape_i
(
img
,
4
,
fgraph
)
-
shape_i
(
kerns
,
4
,
fgraph
)
+
1
)
shape_i
(
img
,
1
,
fgraph
),
shape2
,
shape3
,
shape4
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
conv_mode
=
'cross'
,
precision
=
precision
)(
out
.
shape
)
conv_mode
=
'cross'
,
precision
=
precision
)(
out
.
shape
)
conv
=
gpu_dnn_conv_gradW
()(
img
,
kerns
,
out
,
desc
)
conv
=
gpu_dnn_conv_gradW
()(
img
,
kerns
,
out
,
desc
)
...
@@ -1113,12 +1118,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
...
@@ -1113,12 +1118,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
img
=
gpu_contiguous
(
img
)
# cudnn v2 rc3 need contiguous data
img
=
gpu_contiguous
(
img
)
# cudnn v2 rc3 need contiguous data
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
,
4
))
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
,
4
))
conv_mode
=
'cross'
if
conv_mode
==
'conv'
else
'conv'
conv_mode
=
'cross'
if
conv_mode
==
'conv'
else
'conv'
shape2
=
shape_i
(
img
,
2
,
fgraph
)
+
shape_i
(
kerns
,
2
,
fgraph
)
-
1
out_shp
=
(
shape_i
(
img
,
0
,
fgraph
),
shape3
=
shape_i
(
img
,
3
,
fgraph
)
+
shape_i
(
kerns
,
3
,
fgraph
)
-
1
shape_i
(
kerns
,
1
,
fgraph
),
shape4
=
shape_i
(
img
,
4
,
fgraph
)
+
shape_i
(
kerns
,
4
,
fgraph
)
-
1
shape_i
(
img
,
2
,
fgraph
)
+
shape_i
(
kerns
,
2
,
fgraph
)
-
1
,
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
shape_i
(
img
,
0
,
fgraph
),
shape_i
(
img
,
3
,
fgraph
)
+
shape_i
(
kerns
,
3
,
fgraph
)
-
1
,
shape_i
(
kerns
,
1
,
fgraph
),
shape_i
(
img
,
4
,
fgraph
)
+
shape_i
(
kerns
,
4
,
fgraph
)
-
1
)
shape2
,
shape3
,
shape4
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
conv_mode
=
conv_mode
,
precision
=
precision
)(
kerns
.
shape
)
conv_mode
=
conv_mode
,
precision
=
precision
)(
kerns
.
shape
)
return
gpu_dnn_conv_gradI
()(
kerns
,
img
,
out
,
desc
)
return
gpu_dnn_conv_gradI
()(
kerns
,
img
,
out
,
desc
)
...
@@ -1138,6 +1144,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
...
@@ -1138,6 +1144,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
out_shp
=
get_conv_output_shape
(
ishape
,
kshape
,
out_shp
=
get_conv_output_shape
(
ishape
,
kshape
,
desc_op
.
border_mode
,
desc_op
.
border_mode
,
desc_op
.
subsample
)
desc_op
.
subsample
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
out
=
gpu_alloc_empty
(
ctx_name
,
dtype
=
img
.
dtype
)(
*
out_shp
)
return
gpu_dnn_conv
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
return
gpu_dnn_conv
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
...
...
theano/gpuarray/dnn_fwd.c
浏览文件 @
c072d669
...
@@ -39,11 +39,6 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -39,11 +39,6 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
return
1
;
return
1
;
}
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_filter
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
switch
(
input
->
ga
.
typecode
)
{
switch
(
input
->
ga
.
typecode
)
{
case
GA_DOUBLE
:
case
GA_DOUBLE
:
alpha_p
=
(
void
*
)
&
alpha
;
alpha_p
=
(
void
*
)
&
alpha
;
...
@@ -71,6 +66,20 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
...
@@ -71,6 +66,20 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
return
1
;
return
1
;
#endif
#endif
if
(
PyGpuArray_DIMS
(
input
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
1
]
==
0
)
{
int
err2
=
GpuArray_memset
(
&
(
*
output
)
->
ga
,
0
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuDnnConv could not fill the output with zeros: %d"
,
err2
);
return
1
;
}
return
0
;
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_filter
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
*
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
if
(
c_set_tensorNd
(
*
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
return
1
;
...
...
theano/gpuarray/dnn_gi.c
浏览文件 @
c072d669
...
@@ -38,11 +38,6 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -38,11 +38,6 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
return
1
;
return
1
;
}
}
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
if
(
c_set_filter
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
switch
(
im
->
ga
.
typecode
)
{
switch
(
im
->
ga
.
typecode
)
{
case
GA_DOUBLE
:
case
GA_DOUBLE
:
alpha_p
=
(
void
*
)
&
alpha
;
alpha_p
=
(
void
*
)
&
alpha
;
...
@@ -70,6 +65,20 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -70,6 +65,20 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
return
1
;
return
1
;
#endif
#endif
if
(
PyGpuArray_DIMS
(
im
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
1
]
==
0
)
{
int
err2
=
GpuArray_memset
(
&
(
*
input
)
->
ga
,
0
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuDnnConv grad wrt. inputs could not fill the output with zeros: %d"
,
err2
);
return
1
;
}
return
0
;
}
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
if
(
c_set_filter
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
*
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
if
(
c_set_tensorNd
(
*
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
return
1
;
...
@@ -77,6 +86,48 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
...
@@ -77,6 +86,48 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
cuda_enter
(
c
->
ctx
);
cuda_enter
(
c
->
ctx
);
int
expected_output_dims
[
5
]
=
{
0
};
err
=
cudnnGetConvolutionNdForwardOutputDim
(
desc
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_NDIM
(
im
),
expected_output_dims
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error computing convolution output dim: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
if
(
PyGpuArray_NDIM
(
im
)
==
4
)
{
if
((
PyGpuArray_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
PyGpuArray_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
PyGpuArray_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
PyGpuArray_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ld"
" but received gradient with shape %ldx%ldx%ldx%ld"
,
expected_output_dims
[
0
],
expected_output_dims
[
1
],
expected_output_dims
[
2
],
expected_output_dims
[
3
],
PyGpuArray_DIMS
(
output
)[
0
],
PyGpuArray_DIMS
(
output
)[
1
],
PyGpuArray_DIMS
(
output
)[
2
],
PyGpuArray_DIMS
(
output
)[
3
]);
cuda_exit
(
c
->
ctx
);
return
1
;
}
}
else
if
(
PyGpuArray_NDIM
(
im
)
==
5
)
{
if
((
PyGpuArray_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
PyGpuArray_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
PyGpuArray_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
PyGpuArray_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
])
||
(
PyGpuArray_DIMS
(
output
)[
4
]
!=
expected_output_dims
[
4
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ldx%ld"
" but received gradient with shape %ldx%ldx%ldx%ldx%ld"
,
expected_output_dims
[
0
],
expected_output_dims
[
1
],
expected_output_dims
[
2
],
expected_output_dims
[
3
],
expected_output_dims
[
4
],
PyGpuArray_DIMS
(
output
)[
0
],
PyGpuArray_DIMS
(
output
)[
1
],
PyGpuArray_DIMS
(
output
)[
2
],
PyGpuArray_DIMS
(
output
)[
3
],
PyGpuArray_DIMS
(
output
)[
4
]);
cuda_exit
(
c
->
ctx
);
return
1
;
}
}
#ifdef CHOOSE_ALGO
#ifdef CHOOSE_ALGO
#ifndef CHOOSE_ONCE
#ifndef CHOOSE_ONCE
reuse_algo
=
1
;
reuse_algo
=
1
;
...
...
theano/gpuarray/dnn_gw.c
浏览文件 @
c072d669
...
@@ -38,11 +38,6 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -38,11 +38,6 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
return
1
;
return
1
;
}
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
switch
(
input
->
ga
.
typecode
)
{
switch
(
input
->
ga
.
typecode
)
{
case
GA_DOUBLE
:
case
GA_DOUBLE
:
alpha_p
=
(
void
*
)
&
alpha
;
alpha_p
=
(
void
*
)
&
alpha
;
...
@@ -70,6 +65,20 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -70,6 +65,20 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
return
1
;
return
1
;
#endif
#endif
if
(
PyGpuArray_DIMS
(
input
)[
0
]
==
0
||
PyGpuArray_DIMS
(
km
)[
0
]
==
0
||
PyGpuArray_DIMS
(
km
)[
1
]
==
0
)
{
int
err2
=
GpuArray_memset
(
&
(
*
kerns
)
->
ga
,
0
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuDnnConv grad wrt. weights could not fill the output with zeros: %d"
,
err2
);
return
1
;
}
return
0
;
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
if
(
c_set_filter
(
*
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
if
(
c_set_filter
(
*
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
return
1
;
...
@@ -77,6 +86,48 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
...
@@ -77,6 +86,48 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
cuda_enter
(
c
->
ctx
);
cuda_enter
(
c
->
ctx
);
int
expected_output_dims
[
5
]
=
{
0
};
err
=
cudnnGetConvolutionNdForwardOutputDim
(
desc
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_NDIM
(
input
),
expected_output_dims
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error computing convolution output dim: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
if
(
PyGpuArray_NDIM
(
input
)
==
4
)
{
if
((
PyGpuArray_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
PyGpuArray_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
PyGpuArray_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
PyGpuArray_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%dx%ld"
" but received gradient with shape %ldx%ldx%dx%ld"
,
expected_output_dims
[
0
],
expected_output_dims
[
1
],
expected_output_dims
[
2
],
expected_output_dims
[
3
],
PyGpuArray_DIMS
(
output
)[
0
],
PyGpuArray_DIMS
(
output
)[
1
],
PyGpuArray_DIMS
(
output
)[
2
],
PyGpuArray_DIMS
(
output
)[
3
]);
cuda_exit
(
c
->
ctx
);
return
1
;
}
}
else
if
(
PyGpuArray_NDIM
(
input
)
==
5
)
{
if
((
PyGpuArray_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
PyGpuArray_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
PyGpuArray_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
PyGpuArray_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
])
||
(
PyGpuArray_DIMS
(
output
)[
4
]
!=
expected_output_dims
[
4
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ldx%ld"
" but received gradient with shape %ldx%ldx%ldx%ldx%ld"
,
expected_output_dims
[
0
],
expected_output_dims
[
1
],
expected_output_dims
[
2
],
expected_output_dims
[
3
],
expected_output_dims
[
4
],
PyGpuArray_DIMS
(
output
)[
0
],
PyGpuArray_DIMS
(
output
)[
1
],
PyGpuArray_DIMS
(
output
)[
2
],
PyGpuArray_DIMS
(
output
)[
3
],
PyGpuArray_DIMS
(
output
)[
4
]);
cuda_exit
(
c
->
ctx
);
return
1
;
}
}
#ifdef CHOOSE_ALGO
#ifdef CHOOSE_ALGO
#ifndef CHOOSE_ONCE
#ifndef CHOOSE_ONCE
reuse_algo
=
1
;
reuse_algo
=
1
;
...
...
theano/gpuarray/tests/test_abstractconv.py
浏览文件 @
c072d669
from
__future__
import
absolute_import
,
print_function
,
division
from
__future__
import
absolute_import
,
print_function
,
division
from
nose.plugins.skip
import
SkipTest
from
nose.plugins.skip
import
SkipTest
from
nose.tools
import
assert_raises
import
numpy
import
numpy
...
@@ -49,6 +50,31 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
...
@@ -49,6 +50,31 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
provide_shape
=
provide_shape
,
border_mode
=
b
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
)
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
),
expect_error
=
False
):
if
not
dnn_available
(
test_ctx_name
):
raise
SkipTest
(
dnn_available
.
msg
)
if
fd
!=
(
1
,
1
):
raise
SkipTest
(
"Doesn't have CUDNN implementation"
)
mode
=
mode_with_gpu
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
filter_dilation
=
fd
)
else
:
assert_raises
((
RuntimeError
,
ValueError
),
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestDnnConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
class
TestDnnConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
@classmethod
@classmethod
...
@@ -82,6 +108,31 @@ class TestDnnConv3d(test_abstract_conv.BaseTestConv3d):
...
@@ -82,6 +108,31 @@ class TestDnnConv3d(test_abstract_conv.BaseTestConv3d):
provide_shape
=
provide_shape
,
border_mode
=
b
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
)
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
,
1
),
expect_error
=
False
):
if
not
dnn_available
(
test_ctx_name
):
raise
SkipTest
(
dnn_available
.
msg
)
if
fd
!=
(
1
,
1
,
1
):
raise
SkipTest
(
"Doesn't have CUDNN implementation"
)
mode
=
mode_with_gpu
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
filter_dilation
=
fd
)
else
:
assert_raises
((
RuntimeError
,
ValueError
),
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestCorrMMConv2d
(
test_abstract_conv
.
BaseTestConv2d
):
class
TestCorrMMConv2d
(
test_abstract_conv
.
BaseTestConv2d
):
@classmethod
@classmethod
...
@@ -115,6 +166,28 @@ class TestCorrMMConv2d(test_abstract_conv.BaseTestConv2d):
...
@@ -115,6 +166,28 @@ class TestCorrMMConv2d(test_abstract_conv.BaseTestConv2d):
target_op
=
GpuCorrMM_gradInputs
,
target_op
=
GpuCorrMM_gradInputs
,
filter_dilation
=
fd
)
filter_dilation
=
fd
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
),
expect_error
=
False
):
mode
=
self
.
mode
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorrMM_gradInputs
,
filter_dilation
=
fd
)
else
:
assert_raises
(
ValueError
,
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorrMM_gradInputs
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestCorrMMConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
class
TestCorrMMConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
@classmethod
@classmethod
...
@@ -148,6 +221,28 @@ class TestCorrMMConv3d(test_abstract_conv.BaseTestConv3d):
...
@@ -148,6 +221,28 @@ class TestCorrMMConv3d(test_abstract_conv.BaseTestConv3d):
target_op
=
GpuCorr3dMM_gradInputs
,
target_op
=
GpuCorr3dMM_gradInputs
,
filter_dilation
=
fd
)
filter_dilation
=
fd
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
,
1
),
expect_error
=
False
):
mode
=
self
.
mode
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorr3dMM_gradInputs
,
filter_dilation
=
fd
)
else
:
assert_raises
(
ValueError
,
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorr3dMM_gradInputs
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestDnnConvTypes
(
test_abstract_conv
.
TestConvTypes
):
class
TestDnnConvTypes
(
test_abstract_conv
.
TestConvTypes
):
def
setUp
(
self
):
def
setUp
(
self
):
...
...
theano/gpuarray/tests/test_dnn.py
浏览文件 @
c072d669
...
@@ -12,6 +12,7 @@ import theano.tensor as T
...
@@ -12,6 +12,7 @@ import theano.tensor as T
import
theano.tests.unittest_tools
as
utt
import
theano.tests.unittest_tools
as
utt
from
theano.tensor.signal.pool
import
pool_2d
,
pool_3d
from
theano.tensor.signal.pool
import
pool_2d
,
pool_3d
from
theano.tensor.signal.pool
import
Pool
,
MaxPoolGrad
,
AveragePoolGrad
from
theano.tensor.signal.pool
import
Pool
,
MaxPoolGrad
,
AveragePoolGrad
from
theano.tensor.nnet.abstract_conv
import
get_conv_output_shape
from
..
import
dnn
from
..
import
dnn
from
..basic_ops
import
GpuAllocEmpty
from
..basic_ops
import
GpuAllocEmpty
...
@@ -628,56 +629,50 @@ class TestDnnInferShapes(utt.InferShapeTester):
...
@@ -628,56 +629,50 @@ class TestDnnInferShapes(utt.InferShapeTester):
[(
1
,
1
,
1
),
(
2
,
2
,
2
)],
[(
1
,
1
,
1
),
(
2
,
2
,
2
)],
'none'
)
'none'
)
def
_test_conv_gradw
(
self
,
img
,
kerns
,
out
,
img_val
,
kern_vals
,
border_mode
,
conv_mode
,
subsample
):
def
_test_conv_gradw
(
self
,
img
,
topgrad
,
kerns
,
img_shape
,
kerns_shape
,
border_mode
,
conv_mode
,
subsample
):
if
not
dnn
.
dnn_available
(
test_ctx_name
):
if
not
dnn
.
dnn_available
(
test_ctx_name
):
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
topgrad_shape
=
get_conv_output_shape
(
img_shape
,
kerns_shape
,
border_mode
,
subsample
)
img_val
=
numpy
.
asarray
(
img_val
=
numpy
.
asarray
(
img_val
,
numpy
.
random
.
rand
(
*
img_shape
)
,
dtype
=
theano
.
config
.
floatX
dtype
=
theano
.
config
.
floatX
)
)
kern
_vals
=
numpy
.
asarray
(
topgrad
_vals
=
numpy
.
asarray
(
kern_vals
,
numpy
.
random
.
rand
(
*
topgrad_shape
)
,
dtype
=
theano
.
config
.
floatX
dtype
=
theano
.
config
.
floatX
)
)
temp_img
=
img
.
dimshuffle
(
1
,
0
,
2
,
3
)
kerns_vals
=
numpy
.
zeros
(
kerns_shape
,
dtype
=
theano
.
config
.
floatX
)
temp_kerns
=
kerns
kerns_shape
=
theano
.
shared
(
numpy
.
asarray
(
kerns_shape
))
if
conv_mode
==
'conv'
:
temp_kerns
=
temp_kerns
[:,
:,
::
-
1
,
::
-
1
]
temp_kerns
=
temp_kerns
.
dimshuffle
(
1
,
0
,
2
,
3
)
shape
=
(
kern_vals
.
shape
[
1
],
img_val
.
shape
[
1
],
img_val
.
shape
[
2
]
-
kern_vals
.
shape
[
2
]
+
1
,
img_val
.
shape
[
3
]
-
kern_vals
.
shape
[
3
]
+
1
)
out_vals
=
numpy
.
zeros
(
shape
,
dtype
=
theano
.
config
.
floatX
)
desc
=
dnn
.
GpuDnnConvDesc
(
desc
=
dnn
.
GpuDnnConvDesc
(
border_mode
=
border_mode
,
border_mode
=
border_mode
,
subsample
=
subsample
,
subsample
=
subsample
,
conv_mode
=
conv_mode
,
conv_mode
=
conv_mode
,
precision
=
set_precision
(
theano
.
config
.
floatX
)
precision
=
set_precision
(
theano
.
config
.
floatX
)
)(
out
.
shape
)
)(
kerns_
shape
)
conv_grad_w
=
dnn
.
GpuDnnConvGradW
()(
conv_grad_w
=
dnn
.
GpuDnnConvGradW
()(
temp_
img
,
img
,
t
emp_kerns
,
t
opgrad
,
out
,
kerns
,
desc
,
desc
,
)
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
temp_img
,
temp_kerns
,
out
],
[
img
,
topgrad
,
kerns
],
[
conv_grad_w
],
[
conv_grad_w
],
[
img_val
,
kern_vals
,
out
_vals
],
[
img_val
,
topgrad_vals
,
kerns
_vals
],
dnn
.
GpuDnnConvGradW
dnn
.
GpuDnnConvGradW
)
)
@parameterized.expand
(
product
(
border_modes
,
conv_modes
),
utt
.
custom_name_func
)
@parameterized.expand
(
product
(
border_modes
,
conv_modes
),
utt
.
custom_name_func
)
def
test_conv_gradw
(
self
,
border_mode
,
conv_mode
):
def
test_conv_gradw
(
self
,
border_mode
,
conv_mode
):
self
.
_test_conv_gradw
(
T
.
tensor4
(
'img'
),
self
.
_test_conv_gradw
(
T
.
tensor4
(
'img'
),
T
.
tensor4
(
'topgrad'
),
T
.
tensor4
(
'kerns'
),
T
.
tensor4
(
'kerns'
),
T
.
tensor4
(
'out'
),
(
5
,
2
,
6
,
13
),
numpy
.
random
.
rand
(
2
,
5
,
6
,
8
),
(
1
,
2
,
3
,
7
),
numpy
.
random
.
rand
(
2
,
1
,
5
,
6
),
border_mode
,
border_mode
,
conv_mode
,
conv_mode
,
(
1
,
1
))
(
1
,
1
))
...
...
theano/sandbox/cuda/blas.py
浏览文件 @
c072d669
差异被折叠。
点击展开。
theano/sandbox/cuda/corr3d_gemm.cu
浏览文件 @
c072d669
...
@@ -429,9 +429,17 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
...
@@ -429,9 +429,17 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
const int dil_kW = (kW - 1) * dilW + 1;
const int dil_kW = (kW - 1) * dilW + 1;
const int dil_kD = (kD - 1) * dilD + 1;
const int dil_kD = (kD - 1) * dilD + 1;
// top: (batchSize, nFilters, topHeight, topWidth, topDepth)
// top: (batchSize, nFilters, topHeight, topWidth, topDepth)
const int topHeight = int((bottomHeight + 2*padH - dil_kH) / dH) + 1;
const int topHeightNoDH = (bottomHeight + 2*padH - dil_kH);
const int topWidth = int((bottomWidth + 2*padW - dil_kW) / dW) + 1;
const int topWidthNoDW = (bottomWidth + 2*padW - dil_kW);
const int topDepth = int((bottomDepth + 2*padD - dil_kD) / dD) + 1;
const int topDepthNoDD = (bottomDepth + 2*padD - dil_kD);
// the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only
#define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) % y) == 0 ? 0 : 1)) : (x / y))
const int topHeight = _CONV_FLOORDIV_X(topHeightNoDH, dH) + 1;
const int topWidth = _CONV_FLOORDIV_X(topWidthNoDW, dW) + 1;
const int topDepth = _CONV_FLOORDIV_X(topDepthNoDD, dD) + 1;
#undef _CONV_FLOORDIV
if (batchSize != CudaNdarray_HOST_DIMS(top)[0] ||
if (batchSize != CudaNdarray_HOST_DIMS(top)[0] ||
nFilters != CudaNdarray_HOST_DIMS(top)[1] ||
nFilters != CudaNdarray_HOST_DIMS(top)[1] ||
topHeight != CudaNdarray_HOST_DIMS(top)[2] ||
topHeight != CudaNdarray_HOST_DIMS(top)[2] ||
...
@@ -478,6 +486,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
...
@@ -478,6 +486,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
if (direction == 0)
if (direction == 0)
{ // forward pass
{ // forward pass
output = top;
output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im2col, then gemm
// valid correlation: im2col, then gemm
// Iterate over batch
// Iterate over batch
for (int n = 0; n < batchSize; n++)
for (int n = 0; n < batchSize; n++)
...
@@ -527,6 +548,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
...
@@ -527,6 +548,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
{
{
// backprop wrt. weights
// backprop wrt. weights
output = weight;
output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. weights could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im2col, then gemm
// valid convolution: im2col, then gemm
// Iterate over batch
// Iterate over batch
for (int n = 0; n < batchSize; n++)
for (int n = 0; n < batchSize; n++)
...
@@ -578,6 +612,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
...
@@ -578,6 +612,19 @@ CudaNdarray* corr3dMM(CudaNdarray *const bottom,
{
{
// backprop wrt. inputs
// backprop wrt. inputs
output = bottom;
output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorr3dMM grad wrt. inputs could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im3d
// full convolution: gemm, then col2im3d
// Iterate over batch
// Iterate over batch
for (int n = 0; n < batchSize; n++)
for (int n = 0; n < batchSize; n++)
...
...
theano/sandbox/cuda/corr_gemm.cu
浏览文件 @
c072d669
...
@@ -333,8 +333,15 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -333,8 +333,15 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
const int dil_kH = (kH - 1) * dilH + 1;
const int dil_kH = (kH - 1) * dilH + 1;
const int dil_kW = (kW - 1) * dilW + 1;
const int dil_kW = (kW - 1) * dilW + 1;
// top: (batchSize, nFilters, topHeight, topWidth)
// top: (batchSize, nFilters, topHeight, topWidth)
const int topHeight = (bottomHeight + 2*padH - dil_kH) / dH + 1;
const int topHeightNoDH = (bottomHeight + 2*padH - dil_kH);
const int topWidth = (bottomWidth + 2*padW - dil_kW) / dW + 1;
const int topWidthNoDW = (bottomWidth + 2*padW - dil_kW);
// the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only
#define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) % y) == 0 ? 0 : 1)) : (x / y))
const int topHeight = _CONV_FLOORDIV_X(topHeightNoDH, dH) + 1;
const int topWidth = _CONV_FLOORDIV_X(topWidthNoDW, dW) + 1;
#undef _CONV_FLOORDIV
if (batchSize != CudaNdarray_HOST_DIMS(top)[0] ||
if (batchSize != CudaNdarray_HOST_DIMS(top)[0] ||
nFilters != CudaNdarray_HOST_DIMS(top)[1] ||
nFilters != CudaNdarray_HOST_DIMS(top)[1] ||
topHeight != CudaNdarray_HOST_DIMS(top)[2] ||
topHeight != CudaNdarray_HOST_DIMS(top)[2] ||
...
@@ -377,6 +384,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -377,6 +384,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
CudaNdarray *output;
CudaNdarray *output;
if (direction == 0) { // forward pass
if (direction == 0) { // forward pass
output = top;
output = top;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid correlation: im2col, then gemm
// valid correlation: im2col, then gemm
// Iterate over batch
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
for (int n = 0; n < batchSize; n++) {
...
@@ -445,6 +465,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -445,6 +465,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
}
}
else if (direction == 1) { // backprop wrt. weights
else if (direction == 1) { // backprop wrt. weights
output = weight;
output = weight;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. weights could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// valid convolution: im2col, then gemm
// valid convolution: im2col, then gemm
// Iterate over batch
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
for (int n = 0; n < batchSize; n++) {
...
@@ -513,6 +546,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
...
@@ -513,6 +546,19 @@ CudaNdarray* corrMM(CudaNdarray *const bottom,
}
}
else if (direction == 2) { // backprop wrt. inputs
else if (direction == 2) { // backprop wrt. inputs
output = bottom;
output = bottom;
if (batchSize == 0 || nChannels == 0 || nFilters == 0) {
cudaError_t err = cudaMemset(output->devdata, 0,
CudaNdarray_SIZE(output) * sizeof(real));
if (err != cudaSuccess) {
PyErr_Format(PyExc_RuntimeError,
"GpuCorrMM grad wrt. inputs could not fill the output with zeros: %s",
cudaGetErrorString(err));
Py_DECREF(col);
return NULL;
}
Py_DECREF(col);
return output;
}
// full convolution: gemm, then col2im
// full convolution: gemm, then col2im
// Iterate over batch
// Iterate over batch
for (int n = 0; n < batchSize; n++) {
for (int n = 0; n < batchSize; n++) {
...
...
theano/sandbox/cuda/dnn.py
浏览文件 @
c072d669
...
@@ -14,7 +14,8 @@ from theano.gof.type import CDataType
...
@@ -14,7 +14,8 @@ from theano.gof.type import CDataType
from
theano.compile
import
optdb
from
theano.compile
import
optdb
from
theano.compile.ops
import
shape_i
from
theano.compile.ops
import
shape_i
from
theano.tensor.nnet
import
LogSoftmax
,
SoftmaxGrad
from
theano.tensor.nnet
import
LogSoftmax
,
SoftmaxGrad
from
theano.tensor.nnet.abstract_conv
import
get_conv_output_shape
from
theano.tensor.nnet.abstract_conv
import
(
get_conv_output_shape
,
assert_conv_shape
)
from
theano.tensor.signal.pool
import
(
from
theano.tensor.signal.pool
import
(
Pool
,
MaxPoolGrad
,
AveragePoolGrad
)
Pool
,
MaxPoolGrad
,
AveragePoolGrad
)
from
theano.sandbox.cuda.type
import
CudaNdarrayType
from
theano.sandbox.cuda.type
import
CudaNdarrayType
...
@@ -1132,10 +1133,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
...
@@ -1132,10 +1133,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
]
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
]
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
shape2
=
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
out_shp
=
(
shape_i
(
kerns
,
1
,
fgraph
),
shape3
=
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
shape_i
(
img
,
1
,
fgraph
),
out
=
gpu_alloc_empty
(
shape_i
(
kerns
,
1
,
fgraph
),
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
,
shape_i
(
img
,
1
,
fgraph
),
shape2
,
shape3
)
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
*
out_shp
)
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
conv_mode
=
'cross'
,
precision
=
precision
)(
img
.
shape
,
conv_mode
=
'cross'
,
precision
=
precision
)(
img
.
shape
,
out
.
shape
)
out
.
shape
)
...
@@ -1149,10 +1152,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
...
@@ -1149,10 +1152,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
img
=
gpu_contiguous
(
img
)
img
=
gpu_contiguous
(
img
)
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
))
conv_mode
=
'cross'
if
conv_mode
==
'conv'
else
'conv'
conv_mode
=
'cross'
if
conv_mode
==
'conv'
else
'conv'
shape2
=
shape_i
(
img
,
2
,
fgraph
)
+
shape_i
(
kerns
,
2
,
fgraph
)
-
1
out_shp
=
(
shape_i
(
img
,
0
,
fgraph
),
shape3
=
shape_i
(
img
,
3
,
fgraph
)
+
shape_i
(
kerns
,
3
,
fgraph
)
-
1
shape_i
(
kerns
,
1
,
fgraph
),
out
=
gpu_alloc_empty
(
shape_i
(
img
,
0
,
fgraph
),
shape_i
(
img
,
2
,
fgraph
)
+
shape_i
(
kerns
,
2
,
fgraph
)
-
1
,
shape_i
(
kerns
,
1
,
fgraph
),
shape2
,
shape3
)
shape_i
(
img
,
3
,
fgraph
)
+
shape_i
(
kerns
,
3
,
fgraph
)
-
1
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
*
out_shp
)
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
conv_mode
=
conv_mode
,
precision
=
precision
)(
out
.
shape
,
conv_mode
=
conv_mode
,
precision
=
precision
)(
out
.
shape
,
kerns
.
shape
)
kerns
.
shape
)
...
@@ -1170,6 +1175,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
...
@@ -1170,6 +1175,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
out_shp
=
GpuDnnConv
.
get_out_shape
(
img
.
shape
,
kerns
.
shape
,
out_shp
=
GpuDnnConv
.
get_out_shape
(
img
.
shape
,
kerns
.
shape
,
desc_op
.
border_mode
,
desc_op
.
border_mode
,
desc_op
.
subsample
)
desc_op
.
subsample
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
*
out_shp
)
out
=
gpu_alloc_empty
(
*
out_shp
)
return
GpuDnnConv
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
return
GpuDnnConv
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
...
@@ -1248,11 +1254,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
...
@@ -1248,11 +1254,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
# that would be flipped by conv_mode='conv' in GpuDnnConvGradW.
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
kerns
=
kerns
[:,
:,
::
-
1
,
::
-
1
,
::
-
1
]
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
,
4
))
kerns
=
gpu_contiguous
(
kerns
.
dimshuffle
(
1
,
0
,
2
,
3
,
4
))
shape2
=
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
out_shp
=
(
shape_i
(
kerns
,
1
,
fgraph
),
shape3
=
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
shape_i
(
img
,
1
,
fgraph
),
shape4
=
shape_i
(
img
,
4
,
fgraph
)
-
shape_i
(
kerns
,
4
,
fgraph
)
+
1
shape_i
(
img
,
2
,
fgraph
)
-
shape_i
(
kerns
,
2
,
fgraph
)
+
1
,
out
=
gpu_alloc_empty
(
shape_i
(
kerns
,
1
,
fgraph
),
shape_i
(
img
,
3
,
fgraph
)
-
shape_i
(
kerns
,
3
,
fgraph
)
+
1
,
shape_i
(
img
,
1
,
fgraph
),
shape2
,
shape3
,
shape4
)
shape_i
(
img
,
4
,
fgraph
)
-
shape_i
(
kerns
,
4
,
fgraph
)
+
1
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
*
out_shp
)
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
desc
=
GpuDnnConvDesc
(
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
conv_mode
=
'cross'
,
precision
=
precision
)(
img
.
shape
,
conv_mode
=
'cross'
,
precision
=
precision
)(
img
.
shape
,
out
.
shape
)
out
.
shape
)
...
@@ -1271,6 +1279,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
...
@@ -1271,6 +1279,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
out_shp
=
GpuDnnConv3d
.
get_out_shape
(
img
.
shape
,
kerns
.
shape
,
out_shp
=
GpuDnnConv3d
.
get_out_shape
(
img
.
shape
,
kerns
.
shape
,
desc_op
.
border_mode
,
desc_op
.
border_mode
,
desc_op
.
subsample
)
desc_op
.
subsample
)
out_shp
=
assert_conv_shape
(
out_shp
)
out
=
gpu_alloc_empty
(
*
out_shp
)
out
=
gpu_alloc_empty
(
*
out_shp
)
return
GpuDnnConv3d
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
return
GpuDnnConv3d
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
...
...
theano/sandbox/cuda/dnn_fwd.c
浏览文件 @
c072d669
...
@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
...
@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
return
1
;
return
1
;
}
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_filterNd
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
int
nb_dim
=
CudaNdarray_NDIM
(
input
);
int
nb_dim
=
CudaNdarray_NDIM
(
input
);
#ifdef CONV_INPLACE
#ifdef CONV_INPLACE
...
@@ -30,8 +25,24 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
...
@@ -30,8 +25,24 @@ APPLY_SPECIFIC(conv_fwd)(CudaNdarray *input, CudaNdarray *kerns,
return
1
;
return
1
;
#endif
#endif
if
(
c_set_tensorNd
(
*
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
if
(
CudaNdarray_DIMS
(
input
)[
0
]
==
0
||
CudaNdarray_DIMS
(
kerns
)[
0
]
==
0
||
CudaNdarray_DIMS
(
kerns
)[
1
]
==
0
)
{
return
1
;
cudaError_t
err2
=
cudaMemset
((
*
output
)
->
devdata
,
0
,
CudaNdarray_SIZE
(
*
output
)
*
sizeof
(
real
));
if
(
err2
!=
cudaSuccess
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuDnnConv could not fill the output with zeros: %s"
,
cudaGetErrorString
(
err2
));
return
1
;
}
return
0
;
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_filterNd
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
*
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
{
{
size_t
worksize
;
size_t
worksize
;
...
...
theano/sandbox/cuda/dnn_gi.c
浏览文件 @
c072d669
...
@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
...
@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
return
1
;
return
1
;
}
}
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
if
(
c_set_filterNd
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
int
nb_dim
=
CudaNdarray_NDIM
(
output
);
int
nb_dim
=
CudaNdarray_NDIM
(
output
);
#ifdef CONV_INPLACE
#ifdef CONV_INPLACE
...
@@ -30,9 +25,64 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
...
@@ -30,9 +25,64 @@ APPLY_SPECIFIC(conv_gi)(CudaNdarray *kerns, CudaNdarray *output,
return
1
;
return
1
;
#endif
#endif
if
(
CudaNdarray_DIMS
(
im
)[
0
]
==
0
||
CudaNdarray_DIMS
(
kerns
)[
0
]
==
0
||
CudaNdarray_DIMS
(
kerns
)[
1
]
==
0
)
{
cudaError_t
err2
=
cudaMemset
((
*
input
)
->
devdata
,
0
,
CudaNdarray_SIZE
(
*
input
)
*
sizeof
(
real
));
if
(
err2
!=
cudaSuccess
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuDnnConv grad wrt. inputs could not fill the output with zeros: %s"
,
cudaGetErrorString
(
err2
));
return
1
;
}
return
0
;
}
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
if
(
c_set_filterNd
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
*
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
if
(
c_set_tensorNd
(
*
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
return
1
;
int
expected_output_dims
[
5
]
=
{
0
};
err
=
cudnnGetConvolutionNdForwardOutputDim
(
desc
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
nb_dim
,
expected_output_dims
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error computing convolution output dim: %s"
,
cudnnGetErrorString
(
err
));
return
1
;
}
if
(
nb_dim
==
4
)
{
if
((
CudaNdarray_HOST_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ld"
" but received gradient with shape %ldx%ldx%ldx%ld"
,
(
long
int
)
expected_output_dims
[
0
],
(
long
int
)
expected_output_dims
[
1
],
(
long
int
)
expected_output_dims
[
2
],
(
long
int
)
expected_output_dims
[
3
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
0
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
1
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
2
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
3
]);
return
1
;
}
}
else
if
(
nb_dim
==
5
)
{
if
((
CudaNdarray_HOST_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
4
]
!=
expected_output_dims
[
4
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ldx%ld"
" but received gradient with shape %ldx%ldx%ldx%ldx%ld"
,
(
long
int
)
expected_output_dims
[
0
],
(
long
int
)
expected_output_dims
[
1
],
(
long
int
)
expected_output_dims
[
2
],
(
long
int
)
expected_output_dims
[
3
],
(
long
int
)
expected_output_dims
[
4
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
0
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
1
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
2
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
3
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
4
]);
return
1
;
}
}
{
{
size_t
worksize
;
size_t
worksize
;
void
*
workspace
;
void
*
workspace
;
...
...
theano/sandbox/cuda/dnn_gw.c
浏览文件 @
c072d669
...
@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
...
@@ -12,11 +12,6 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
return
1
;
return
1
;
}
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
int
nb_dim
=
CudaNdarray_NDIM
(
output
);
int
nb_dim
=
CudaNdarray_NDIM
(
output
);
#ifdef CONV_INPLACE
#ifdef CONV_INPLACE
...
@@ -30,9 +25,64 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
...
@@ -30,9 +25,64 @@ APPLY_SPECIFIC(conv_gw)(CudaNdarray *input, CudaNdarray *output,
return
1
;
return
1
;
#endif
#endif
if
(
CudaNdarray_DIMS
(
input
)[
0
]
==
0
||
CudaNdarray_DIMS
(
km
)[
0
]
==
0
||
CudaNdarray_DIMS
(
km
)[
1
]
==
0
)
{
cudaError_t
err2
=
cudaMemset
((
*
kerns
)
->
devdata
,
0
,
CudaNdarray_SIZE
(
*
kerns
)
*
sizeof
(
real
));
if
(
err2
!=
cudaSuccess
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"GpuDnnConv grad wrt. weights could not fill the output with zeros: %s"
,
cudaGetErrorString
(
err2
));
return
1
;
}
return
0
;
}
if
(
c_set_tensorNd
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_tensorNd
(
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
if
(
c_set_filterNd
(
*
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
if
(
c_set_filterNd
(
*
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
return
1
;
int
expected_output_dims
[
5
]
=
{
0
};
err
=
cudnnGetConvolutionNdForwardOutputDim
(
desc
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
nb_dim
,
expected_output_dims
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error computing convolution output dim: %s"
,
cudnnGetErrorString
(
err
));
return
1
;
}
if
(
nb_dim
==
4
)
{
if
((
CudaNdarray_HOST_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%dx%ld"
" but received gradient with shape %ldx%ldx%dx%ld"
,
(
long
int
)
expected_output_dims
[
0
],
(
long
int
)
expected_output_dims
[
1
],
(
long
int
)
expected_output_dims
[
2
],
(
long
int
)
expected_output_dims
[
3
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
0
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
1
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
2
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
3
]);
return
1
;
}
}
else
if
(
nb_dim
==
5
)
{
if
((
CudaNdarray_HOST_DIMS
(
output
)[
0
]
!=
expected_output_dims
[
0
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
1
]
!=
expected_output_dims
[
1
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
2
]
!=
expected_output_dims
[
2
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
3
]
!=
expected_output_dims
[
3
])
||
(
CudaNdarray_HOST_DIMS
(
output
)[
4
]
!=
expected_output_dims
[
4
]))
{
PyErr_Format
(
PyExc_ValueError
,
"impossible convolution output dim: expected %ldx%ldx%ldx%ldx%ld"
" but received gradient with shape %ldx%ldx%ldx%ldx%ld"
,
(
long
int
)
expected_output_dims
[
0
],
(
long
int
)
expected_output_dims
[
1
],
(
long
int
)
expected_output_dims
[
2
],
(
long
int
)
expected_output_dims
[
3
],
(
long
int
)
expected_output_dims
[
4
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
0
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
1
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
2
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
3
],
(
long
int
)
CudaNdarray_HOST_DIMS
(
output
)[
4
]);
return
1
;
}
}
{
{
size_t
worksize
;
size_t
worksize
;
void
*
workspace
;
void
*
workspace
;
...
...
theano/sandbox/cuda/tests/test_abstractconv.py
浏览文件 @
c072d669
...
@@ -13,6 +13,7 @@ from theano.sandbox.cuda.blas import (
...
@@ -13,6 +13,7 @@ from theano.sandbox.cuda.blas import (
GpuCorrMM
,
GpuCorrMM_gradWeights
,
GpuCorrMM_gradInputs
,
GpuCorrMM
,
GpuCorrMM_gradWeights
,
GpuCorrMM_gradInputs
,
GpuCorr3dMM
,
GpuCorr3dMM_gradWeights
,
GpuCorr3dMM_gradInputs
)
GpuCorr3dMM
,
GpuCorr3dMM_gradWeights
,
GpuCorr3dMM_gradInputs
)
from
nose.plugins.skip
import
SkipTest
from
nose.plugins.skip
import
SkipTest
from
nose.tools
import
assert_raises
import
theano.sandbox.cuda
as
cuda
import
theano.sandbox.cuda
as
cuda
if
not
cuda
.
cuda_available
:
if
not
cuda
.
cuda_available
:
...
@@ -57,6 +58,31 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
...
@@ -57,6 +58,31 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
filter_dilation
=
fd
)
filter_dilation
=
fd
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
),
expect_error
=
False
):
if
fd
!=
(
1
,
1
):
raise
SkipTest
(
"No dilation implementation for cuDNN ConvOp."
)
if
not
dnn_available
():
raise
SkipTest
(
cuda
.
dnn
.
dnn_available
.
msg
)
mode
=
mode_with_gpu
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
filter_dilation
=
fd
)
else
:
assert_raises
((
RuntimeError
,
ValueError
),
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestDnnConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
class
TestDnnConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
@classmethod
@classmethod
...
@@ -91,6 +117,31 @@ class TestDnnConv3d(test_abstract_conv.BaseTestConv3d):
...
@@ -91,6 +117,31 @@ class TestDnnConv3d(test_abstract_conv.BaseTestConv3d):
filter_flip
=
flip
,
target_op
=
GpuDnnConv3dGradI
,
filter_flip
=
flip
,
target_op
=
GpuDnnConv3dGradI
,
filter_dilation
=
fd
)
filter_dilation
=
fd
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
,
1
),
expect_error
=
False
):
if
fd
!=
(
1
,
1
,
1
):
raise
SkipTest
(
"No dilation implementation for cuDNN ConvOp."
)
if
not
dnn_available
():
raise
SkipTest
(
cuda
.
dnn
.
dnn_available
.
msg
)
mode
=
mode_with_gpu
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
filter_dilation
=
fd
)
else
:
assert_raises
((
RuntimeError
,
ValueError
),
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuDnnConvGradI
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestCorrMMConv2d
(
test_abstract_conv
.
BaseTestConv2d
):
class
TestCorrMMConv2d
(
test_abstract_conv
.
BaseTestConv2d
):
@classmethod
@classmethod
...
@@ -124,6 +175,28 @@ class TestCorrMMConv2d(test_abstract_conv.BaseTestConv2d):
...
@@ -124,6 +175,28 @@ class TestCorrMMConv2d(test_abstract_conv.BaseTestConv2d):
target_op
=
GpuCorrMM_gradInputs
,
target_op
=
GpuCorrMM_gradInputs
,
filter_dilation
=
fd
)
filter_dilation
=
fd
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
),
expect_error
=
False
):
mode
=
self
.
mode
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorrMM_gradInputs
,
filter_dilation
=
fd
)
else
:
assert_raises
(
ValueError
,
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorrMM_gradInputs
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestCorrMMConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
class
TestCorrMMConv3d
(
test_abstract_conv
.
BaseTestConv3d
):
@classmethod
@classmethod
...
@@ -157,6 +230,28 @@ class TestCorrMMConv3d(test_abstract_conv.BaseTestConv3d):
...
@@ -157,6 +230,28 @@ class TestCorrMMConv3d(test_abstract_conv.BaseTestConv3d):
target_op
=
GpuCorr3dMM_gradInputs
,
target_op
=
GpuCorr3dMM_gradInputs
,
filter_dilation
=
fd
)
filter_dilation
=
fd
)
def
tcase_gi
(
self
,
i
,
f
,
o
,
s
,
b
,
flip
,
provide_shape
,
fd
=
(
1
,
1
,
1
),
expect_error
=
False
):
mode
=
self
.
mode
if
not
expect_error
:
self
.
run_gradinput
(
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
True
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorr3dMM_gradInputs
,
filter_dilation
=
fd
)
else
:
assert_raises
(
ValueError
,
self
.
run_gradinput
,
inputs_shape
=
i
,
filters_shape
=
f
,
output_shape
=
o
,
subsample
=
s
,
verify_grad
=
False
,
mode
=
mode
,
provide_shape
=
provide_shape
,
border_mode
=
b
,
filter_flip
=
flip
,
target_op
=
GpuCorr3dMM_gradInputs
,
ref
=
None
,
filter_dilation
=
fd
)
class
TestDnnConvTypes
(
test_abstract_conv
.
TestConvTypes
):
class
TestDnnConvTypes
(
test_abstract_conv
.
TestConvTypes
):
def
setUp
(
self
):
def
setUp
(
self
):
...
...
theano/sandbox/cuda/tests/test_dnn.py
浏览文件 @
c072d669
...
@@ -4,6 +4,7 @@ import os
...
@@ -4,6 +4,7 @@ import os
import
sys
import
sys
from
nose.plugins.skip
import
SkipTest
from
nose.plugins.skip
import
SkipTest
from
nose_parameterized
import
parameterized
from
itertools
import
chain
,
product
from
itertools
import
chain
,
product
import
six.moves.cPickle
as
pickle
import
six.moves.cPickle
as
pickle
from
six
import
StringIO
from
six
import
StringIO
...
@@ -16,6 +17,7 @@ import theano.tensor as T
...
@@ -16,6 +17,7 @@ import theano.tensor as T
import
theano.tests.unittest_tools
as
utt
import
theano.tests.unittest_tools
as
utt
from
theano.tensor.signal.pool
import
pool_2d
,
pool_3d
from
theano.tensor.signal.pool
import
pool_2d
,
pool_3d
from
theano.tensor.signal.pool
import
Pool
,
MaxPoolGrad
,
AveragePoolGrad
from
theano.tensor.signal.pool
import
Pool
,
MaxPoolGrad
,
AveragePoolGrad
from
theano.tensor.nnet.abstract_conv
import
get_conv_output_shape
import
theano.sandbox.cuda.dnn
as
dnn
import
theano.sandbox.cuda.dnn
as
dnn
from
theano.sandbox.cuda.basic_ops
import
GpuAllocEmpty
,
gpu_alloc_empty
from
theano.sandbox.cuda.basic_ops
import
GpuAllocEmpty
,
gpu_alloc_empty
from
theano.sandbox.cuda
import
float32_shared_constructor
as
shared
from
theano.sandbox.cuda
import
float32_shared_constructor
as
shared
...
@@ -979,98 +981,104 @@ class TestDnnInferShapes(utt.InferShapeTester):
...
@@ -979,98 +981,104 @@ class TestDnnInferShapes(utt.InferShapeTester):
dnn
.
GpuDnnConv3d
dnn
.
GpuDnnConv3d
)
)
def
test_conv_gradw
(
self
):
def
_test_conv_gradw
(
self
,
img
,
topgrad
,
kerns
,
img_shape
,
kerns_shape
,
border_mode
,
conv_mode
,
subsample
):
if
not
dnn
.
dnn_available
():
if
not
dnn
.
dnn_available
():
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
img
=
T
.
ftensor4
(
'img'
)
kerns
=
T
.
ftensor4
(
'kerns'
)
topgrad_shape
=
get_conv_output_shape
(
img_shape
,
kerns_shape
,
out
=
T
.
ftensor4
(
'out'
)
border_mode
,
subsample
)
img_val
=
numpy
.
asarray
(
img_val
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
2
,
5
,
6
,
8
),
numpy
.
random
.
rand
(
*
img_shape
),
dtype
=
'float32'
dtype
=
theano
.
config
.
floatX
)
)
kern
_vals
=
numpy
.
asarray
(
topgrad
_vals
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
2
,
1
,
5
,
6
),
numpy
.
random
.
rand
(
*
topgrad_shape
),
dtype
=
'float32'
dtype
=
theano
.
config
.
floatX
)
)
for
params
in
product
(
kerns_vals
=
numpy
.
zeros
(
kerns_shape
,
dtype
=
theano
.
config
.
floatX
)
[
'valid'
,
'full'
,
'half'
],
kerns_shape
=
theano
.
shared
(
numpy
.
asarray
(
kerns_shape
))
[(
1
,
1
)],
# strides besides (1, 1)
topgrad_shape
=
theano
.
shared
(
numpy
.
asarray
(
topgrad_shape
))
[
'conv'
,
'cross'
]
desc
=
dnn
.
GpuDnnConvDesc
(
):
border_mode
=
border_mode
,
temp_img
=
img
.
dimshuffle
(
1
,
0
,
2
,
3
)
subsample
=
subsample
,
temp_kerns
=
kerns
conv_mode
=
conv_mode
if
params
[
2
]
==
'conv'
:
)(
topgrad_shape
,
kerns_shape
)
temp_kerns
=
temp_kerns
[:,
:,
::
-
1
,
::
-
1
]
conv_grad_w
=
dnn
.
GpuDnnConvGradW
()(
temp_kerns
=
temp_kerns
.
dimshuffle
(
1
,
0
,
2
,
3
)
img
,
shape
=
(
topgrad
,
kern_vals
.
shape
[
1
],
img_val
.
shape
[
1
],
kerns
,
img_val
.
shape
[
2
]
-
kern_vals
.
shape
[
2
]
+
1
,
desc
,
img_val
.
shape
[
3
]
-
kern_vals
.
shape
[
3
]
+
1
)
)
self
.
_compile_and_check
(
out_vals
=
numpy
.
zeros
(
shape
,
dtype
=
'float32'
)
[
img
,
topgrad
,
kerns
],
desc
=
dnn
.
GpuDnnConvDesc
(
[
conv_grad_w
],
border_mode
=
params
[
0
],
[
img_val
,
topgrad_vals
,
kerns_vals
],
subsample
=
params
[
1
],
dnn
.
GpuDnnConvGradW
conv_mode
=
params
[
2
]
)
)(
temp_img
.
shape
,
out
.
shape
)
conv_grad_w
=
dnn
.
GpuDnnConvGradW
()(
border_modes
=
[
'valid'
,
'full'
,
'half'
]
temp_img
,
conv_modes
=
[
'conv'
,
'cross'
]
temp_kerns
,
out
,
desc
,
)
self
.
_compile_and_check
(
[
temp_img
,
temp_kerns
,
out
],
[
conv_grad_w
],
[
img_val
,
kern_vals
,
out_vals
],
dnn
.
GpuDnnConvGradW
)
def
test_conv3d_gradw
(
self
):
@parameterized.expand
(
product
(
border_modes
,
conv_modes
),
utt
.
custom_name_func
)
def
test_conv_gradw
(
self
,
border_mode
,
conv_mode
):
self
.
_test_conv_gradw
(
T
.
tensor4
(
'img'
),
T
.
tensor4
(
'topgrad'
),
T
.
tensor4
(
'kerns'
),
(
5
,
2
,
6
,
13
),
(
1
,
2
,
3
,
7
),
border_mode
,
conv_mode
,
(
1
,
1
))
def
_test_conv3d_gradw
(
self
,
img
,
topgrad
,
kerns
,
img_shape
,
kerns_shape
,
border_mode
,
conv_mode
,
subsample
):
if
not
(
cuda
.
dnn
.
dnn_available
()
and
dnn
.
version
()
>=
(
2000
,
2000
)):
if
not
(
cuda
.
dnn
.
dnn_available
()
and
dnn
.
version
()
>=
(
2000
,
2000
)):
raise
SkipTest
(
'"cuDNN 3D convolution requires cuDNN v2'
)
raise
SkipTest
(
'"cuDNN 3D convolution requires cuDNN v2'
)
img
=
T
.
ftensor5
(
'img'
)
kerns
=
T
.
ftensor5
(
'kerns'
)
topgrad_shape
=
get_conv_output_shape
(
img_shape
,
kerns_shape
,
out
=
T
.
ftensor5
(
'out'
)
border_mode
,
subsample
)
img_val
=
numpy
.
asarray
(
img_val
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
9
,
2
,
4
,
8
,
13
),
numpy
.
random
.
rand
(
*
img_shape
),
dtype
=
'float32'
dtype
=
theano
.
config
.
floatX
)
)
kern
_vals
=
numpy
.
asarray
(
topgrad
_vals
=
numpy
.
asarray
(
numpy
.
random
.
rand
(
11
,
2
,
3
,
1
,
4
),
numpy
.
random
.
rand
(
*
topgrad_shape
),
dtype
=
'float32'
dtype
=
theano
.
config
.
floatX
)
)
for
params
in
product
(
kerns_vals
=
numpy
.
zeros
(
kerns_shape
,
dtype
=
theano
.
config
.
floatX
)
[
'valid'
,
'full'
,
'half'
],
kerns_shape
=
theano
.
shared
(
numpy
.
asarray
(
kerns_shape
))
[(
1
,
1
,
1
),
(
2
,
2
,
2
)],
topgrad_shape
=
theano
.
shared
(
numpy
.
asarray
(
topgrad_shape
))
[
'conv'
,
'cross'
]
desc
=
dnn
.
GpuDnnConvDesc
(
):
border_mode
=
border_mode
,
out_vals
=
numpy
.
zeros
(
subsample
=
subsample
,
dnn
.
GpuDnnConv3d
.
get_out_shape
(
img_val
.
shape
,
kern_vals
.
shape
,
conv_mode
=
conv_mode
border_mode
=
params
[
0
],
)(
topgrad_shape
,
kerns_shape
)
subsample
=
params
[
1
]),
conv_grad_w
=
dnn
.
GpuDnnConv3dGradW
()(
dtype
=
'float32'
)
img
,
topgrad
,
kerns
,
desc
,
)
self
.
_compile_and_check
(
[
img
,
topgrad
,
kerns
],
[
conv_grad_w
],
[
img_val
,
topgrad_vals
,
kerns_vals
],
dnn
.
GpuDnnConv3dGradW
)
desc
=
dnn
.
GpuDnnConvDesc
(
@parameterized.expand
(
product
(
border_modes
,
conv_modes
),
utt
.
custom_name_func
)
border_mode
=
params
[
0
],
def
test_conv3d_gradw
(
self
,
border_mode
,
conv_mode
):
subsample
=
params
[
1
],
self
.
_test_conv3d_gradw
(
T
.
tensor5
(
'img'
),
conv_mode
=
params
[
2
]
T
.
tensor5
(
'topgrad'
),
)(
img
.
shape
,
out
.
shape
)
T
.
tensor5
(
'kerns'
),
conv_grad_w
=
dnn
.
GpuDnnConv3dGradW
()(
(
5
,
2
,
6
,
13
,
21
),
img
,
(
1
,
2
,
3
,
7
,
9
),
out
,
border_mode
,
kerns
,
conv_mode
,
desc
,
(
1
,
1
,
1
))
)
self
.
_compile_and_check
(
[
img
,
out
,
kerns
],
[
conv_grad_w
],
[
img_val
,
out_vals
,
kern_vals
],
dnn
.
GpuDnnConv3dGradW
)
def
test_conv_gradi
(
self
):
def
test_conv_gradi
(
self
):
if
not
dnn
.
dnn_available
():
if
not
dnn
.
dnn_available
():
...
...
theano/tensor/nnet/abstract_conv.py
浏览文件 @
c072d669
差异被折叠。
点击展开。
theano/tensor/nnet/corr.py
浏览文件 @
c072d669
...
@@ -123,7 +123,7 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -123,7 +123,7 @@ class BaseCorrMM(gof.OpenMPOp):
def
c_code_cache_version
(
self
):
def
c_code_cache_version
(
self
):
# raise this whenever modifying any of the support_code_files
# raise this whenever modifying any of the support_code_files
return
(
1
,
self
.
openmp
,
blas_header_version
())
return
(
5
,
self
.
openmp
,
blas_header_version
())
def
c_support_code_apply
(
self
,
node
,
nodename
):
def
c_support_code_apply
(
self
,
node
,
nodename
):
# REMEMBER TO RAISE c_code_cache_version when changing any of
# REMEMBER TO RAISE c_code_cache_version when changing any of
...
@@ -234,17 +234,17 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -234,17 +234,17 @@ class BaseCorrMM(gof.OpenMPOp):
# When subsampling, we cannot unambiguously infer the height and width
# When subsampling, we cannot unambiguously infer the height and width
# of bottom and weights from top, so we require them to be given.
# of bottom and weights from top, so we require them to be given.
# Similarly, when border_mode="half", we cannot infer the weight size.
# Similarly, when border_mode="half", we cannot infer the weight size.
if
((
direction
!=
0
)
and
(
dH
!=
1
))
or
((
direction
==
1
)
and
(
padH
==
-
1
)):
if
height
:
if
not
height
:
raise
ValueError
(
"height must be given for backprop with vertical sampling or border_mode='half'"
)
height
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
height
height
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
height
else
:
else
:
if
((
direction
!=
0
)
and
(
dH
!=
1
))
or
((
direction
==
1
)
and
(
padH
==
-
1
)):
raise
ValueError
(
"height must be given for backprop with vertical sampling or border_mode='half'"
)
height
=
'-1'
height
=
'-1'
if
((
direction
!=
0
)
and
(
dW
!=
1
))
or
((
direction
==
1
)
and
(
padW
==
-
1
)):
if
width
:
if
not
width
:
raise
ValueError
(
"width must be given for backprop with horizontal sampling or border_mode='half'"
)
width
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
width
width
=
'(*(npy_int64 *)(PyArray_DATA(
%
s)))'
%
width
else
:
else
:
if
((
direction
!=
0
)
and
(
dW
!=
1
))
or
((
direction
==
1
)
and
(
padW
==
-
1
)):
raise
ValueError
(
"width must be given for backprop with horizontal sampling or border_mode='half'"
)
width
=
'-1'
width
=
'-1'
sub
=
sub
.
copy
()
sub
=
sub
.
copy
()
sub
.
update
(
locals
())
sub
.
update
(
locals
())
...
@@ -268,15 +268,15 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -268,15 +268,15 @@ class BaseCorrMM(gof.OpenMPOp):
// Obtain or infer kernel width and height
// Obtain or infer kernel width and height
// (we need to know it early to be able to handle auto-padding)
// (we need to know it early to be able to handle auto-padding)
int kH, kW;
int kH, kW
, dil_kH, dil_kW
;
if (direction != 1) {
if (direction != 1) {
// weight is an input variable, we can just read its shape
// weight is an input variable, we can just read its shape
kH = PyArray_DIMS(weights)[2];
kH = PyArray_DIMS(weights)[2];
kW = PyArray_DIMS(weights)[3];
kW = PyArray_DIMS(weights)[3];
}
}
else {
else {
if (
(dH != 1) || (padH == -1)
) {
if (
%(height)
s != -1
) {
//
vertical subsampling or half padding, kernel height is specified
//
kernel height is specified (perhaps vertical subsampling or half padding)
kH =
%(height)
s;
kH =
%(height)
s;
}
}
else if (padH == -2) {
else if (padH == -2) {
...
@@ -287,7 +287,8 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -287,7 +287,8 @@ class BaseCorrMM(gof.OpenMPOp):
// explicit padding, we can infer the kernel height
// explicit padding, we can infer the kernel height
kH = (PyArray_DIMS(bottom)[2] + 2*padH - (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
kH = (PyArray_DIMS(bottom)[2] + 2*padH - (PyArray_DIMS(top)[2] - 1) * dH - 1) / dilH +1;
}
}
if ((dW != 1) || (padW == -1)) {
if (
%(width)
s != -1) {
// kernel width is specified (perhaps horizontal subsampling or half padding)
kW =
%(width)
s;
kW =
%(width)
s;
}
}
else if (padW == -2) {
else if (padW == -2) {
...
@@ -299,8 +300,8 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -299,8 +300,8 @@ class BaseCorrMM(gof.OpenMPOp):
}
}
// Implicit dilated kernel size
// Implicit dilated kernel size
int
dil_kH = (kH - 1) * dilH + 1;
dil_kH = (kH - 1) * dilH + 1;
int
dil_kW = (kW - 1) * dilW + 1;
dil_kW = (kW - 1) * dilW + 1;
// Auto-padding if requested
// Auto-padding if requested
if (padH == -1) { // vertical half padding
if (padH == -1) { // vertical half padding
...
@@ -334,6 +335,21 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -334,6 +335,21 @@ class BaseCorrMM(gof.OpenMPOp):
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[0];
out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] + 2*padH - ((PyArray_DIMS(weights)[2]-1)*dilH + 1)) / dH + 1);
out_dim[2] = (npy_intp)((PyArray_DIMS(bottom)[2] + 2*padH - ((PyArray_DIMS(weights)[2]-1)*dilH + 1)) / dH + 1);
out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] + 2*padW - ((PyArray_DIMS(weights)[3]-1)*dilW + 1)) / dW + 1);
out_dim[3] = (npy_intp)((PyArray_DIMS(bottom)[3] + 2*padW - ((PyArray_DIMS(weights)[3]-1)*dilW + 1)) / dW + 1);
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
{
PyErr_Format(PyExc_ValueError,
"CorrMM: impossible output shape
\\
n"
" bottom shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n"
" weights shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n"
" top shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n",
(long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
(long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
(long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
(long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
(long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
(long int)out_dim[3]);
%(fail)
s
}
break;
break;
case 1: // backprop wrt. weights
case 1: // backprop wrt. weights
// output is weights: (num_filters, num_channels, height, width)
// output is weights: (num_filters, num_channels, height, width)
...
@@ -342,14 +358,44 @@ class BaseCorrMM(gof.OpenMPOp):
...
@@ -342,14 +358,44 @@ class BaseCorrMM(gof.OpenMPOp):
out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(bottom)[1];
out_dim[2] = (npy_intp)kH; // already inferred further above
out_dim[2] = (npy_intp)kH; // already inferred further above
out_dim[3] = (npy_intp)kW; // how convenient
out_dim[3] = (npy_intp)kW; // how convenient
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
{
PyErr_Format(PyExc_ValueError,
"CorrMM backprop wrt. weights: impossible output shape
\\
n"
" bottom shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n"
" weights shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n"
" top shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n",
(long int)PyArray_DIMS(bottom)[0], (long int)PyArray_DIMS(bottom)[1],
(long int)PyArray_DIMS(bottom)[2], (long int)PyArray_DIMS(bottom)[3],
(long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
(long int)out_dim[3],
(long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
(long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
%(fail)
s
}
break;
break;
case 2: // backprop wrt. inputs
case 2: // backprop wrt. inputs
// output is bottom: (batchsize, num_channels, height, width)
// output is bottom: (batchsize, num_channels, height, width)
// height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
// height and width: bottom = (top - 1) * sample + (weights-1)*dil + 1 - 2*pad
out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
out_dim[0] = (npy_intp)PyArray_DIMS(top)[0];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1];
out_dim[1] = (npy_intp)PyArray_DIMS(weights)[1];
out_dim[2] = (npy_intp)((dH != 1) ?
%(height)
s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
out_dim[2] = (npy_intp)((
%(height)
s != -1) ?
%(height)
s : (PyArray_DIMS(top)[2] - 1) * dH + (PyArray_DIMS(weights)[2]-1)*dilH + 1 - 2*padH);
out_dim[3] = (npy_intp)((dW != 1) ?
%(width)
s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
out_dim[3] = (npy_intp)((
%(width)
s != -1) ?
%(width)
s : (PyArray_DIMS(top)[3] - 1) * dW + (PyArray_DIMS(weights)[3]-1)*dilW + 1 - 2*padW);
if (out_dim[0] < 0 || out_dim[1] < 0 || out_dim[2] <= 0 || out_dim[3] <= 0)
{
PyErr_Format(PyExc_ValueError,
"CorrMM backprop wrt. inputs: impossible output shape
\\
n"
" bottom shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n"
" weights shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n"
" top shape:
%%
ld x
%%
ld x
%%
ld x
%%
ld
\\
n",
(long int)out_dim[0], (long int)out_dim[1], (long int)out_dim[2],
(long int)out_dim[3],
(long int)PyArray_DIMS(weights)[0], (long int)PyArray_DIMS(weights)[1],
(long int)PyArray_DIMS(weights)[2], (long int)PyArray_DIMS(weights)[3],
(long int)PyArray_DIMS(top)[0], (long int)PyArray_DIMS(top)[1],
(long int)PyArray_DIMS(top)[2], (long int)PyArray_DIMS(top)[3]);
%(fail)
s
}
break;
break;
default:
default:
PyErr_SetString(PyExc_ValueError, "BaseCorrMM: direction must be 0, 1, or 2
\\
n");
PyErr_SetString(PyExc_ValueError, "BaseCorrMM: direction must be 0, 1, or 2
\\
n");
...
@@ -491,13 +537,13 @@ class CorrMM_gradWeights(BaseCorrMM):
...
@@ -491,13 +537,13 @@ class CorrMM_gradWeights(BaseCorrMM):
raise
TypeError
(
'img must be 4D tensor'
)
raise
TypeError
(
'img must be 4D tensor'
)
if
topgrad
.
type
.
ndim
!=
4
:
if
topgrad
.
type
.
ndim
!=
4
:
raise
TypeError
(
'topgrad must be 4D tensor'
)
raise
TypeError
(
'topgrad must be 4D tensor'
)
if
s
elf
.
subsample
!=
(
1
,
1
)
or
self
.
border_mode
==
"half"
:
if
s
hape
is
None
:
if
s
hape
is
None
:
if
s
elf
.
subsample
!=
(
1
,
1
)
or
self
.
border_mode
==
"half"
:
raise
ValueError
(
'shape must be given if subsample != (1, 1)'
raise
ValueError
(
'shape must be given if subsample != (1, 1)'
' or border_mode == "half"'
)
' or border_mode == "half"'
)
height_width
=
[
as_tensor_variable
(
shape
[
0
])
.
astype
(
'int64'
),
as_tensor_variable
(
shape
[
1
])
.
astype
(
'int64'
)]
else
:
height_width
=
[]
height_width
=
[]
else
:
height_width
=
[
as_tensor_variable
(
shape
[
0
])
.
astype
(
'int64'
),
as_tensor_variable
(
shape
[
1
])
.
astype
(
'int64'
)]
broadcastable
=
[
topgrad
.
type
.
broadcastable
[
1
],
img
.
type
.
broadcastable
[
1
],
broadcastable
=
[
topgrad
.
type
.
broadcastable
[
1
],
img
.
type
.
broadcastable
[
1
],
False
,
False
]
False
,
False
]
...
@@ -588,9 +634,13 @@ class CorrMM_gradInputs(BaseCorrMM):
...
@@ -588,9 +634,13 @@ class CorrMM_gradInputs(BaseCorrMM):
raise
TypeError
(
'kern must be 4D tensor'
)
raise
TypeError
(
'kern must be 4D tensor'
)
if
topgrad
.
type
.
ndim
!=
4
:
if
topgrad
.
type
.
ndim
!=
4
:
raise
TypeError
(
'topgrad must be 4D tensor'
)
raise
TypeError
(
'topgrad must be 4D tensor'
)
if
self
.
subsample
!=
(
1
,
1
)
and
shape
is
None
:
if
shape
is
None
:
raise
ValueError
(
'shape must be given if subsample != (1, 1)'
)
if
self
.
subsample
!=
(
1
,
1
):
height_width
=
[
as_tensor_variable
(
shape
[
0
])
.
astype
(
'int64'
),
as_tensor_variable
(
shape
[
1
])
.
astype
(
'int64'
)]
if
self
.
subsample
!=
(
1
,
1
)
else
[]
raise
ValueError
(
'shape must be given if subsample != (1, 1)'
)
height_width
=
[]
else
:
height_width
=
[
as_tensor_variable
(
shape
[
0
])
.
astype
(
'int64'
),
as_tensor_variable
(
shape
[
1
])
.
astype
(
'int64'
)]
broadcastable
=
[
topgrad
.
type
.
broadcastable
[
0
],
kern
.
type
.
broadcastable
[
1
],
broadcastable
=
[
topgrad
.
type
.
broadcastable
[
0
],
kern
.
type
.
broadcastable
[
1
],
False
,
False
]
False
,
False
]
...
...
theano/tensor/nnet/corr3d.py
浏览文件 @
c072d669
差异被折叠。
点击展开。
theano/tensor/nnet/corr3d_gemm.c
浏览文件 @
c072d669
...
@@ -188,9 +188,17 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
...
@@ -188,9 +188,17 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
const
int
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
const
int
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
const
int
dil_kD
=
(
kD
-
1
)
*
dilD
+
1
;
const
int
dil_kD
=
(
kD
-
1
)
*
dilD
+
1
;
// top: (batchSize, nFilters, topHeight, topWidth, topDepth)
// top: (batchSize, nFilters, topHeight, topWidth, topDepth)
const
int
topHeight
=
(
bottomHeight
+
2
*
padH
-
dil_kH
)
/
dH
+
1
;
const
int
topHeightNoDH
=
(
bottomHeight
+
2
*
padH
-
dil_kH
);
const
int
topWidth
=
(
bottomWidth
+
2
*
padW
-
dil_kW
)
/
dW
+
1
;
const
int
topWidthNoDW
=
(
bottomWidth
+
2
*
padW
-
dil_kW
);
const
int
topDepth
=
(
bottomDepth
+
2
*
padD
-
dil_kD
)
/
dD
+
1
;
const
int
topDepthNoDD
=
(
bottomDepth
+
2
*
padD
-
dil_kD
);
// the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only
#define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) %% y) == 0 ? 0 : 1)) : (x / y))
const
int
topHeight
=
_CONV_FLOORDIV_X
(
topHeightNoDH
,
dH
)
+
1
;
const
int
topWidth
=
_CONV_FLOORDIV_X
(
topWidthNoDW
,
dW
)
+
1
;
const
int
topDepth
=
_CONV_FLOORDIV_X
(
topDepthNoDD
,
dD
)
+
1
;
#undef _CONV_FLOORDIV
if
(
batchSize
!=
PyArray_DIMS
(
top
)[
0
]
||
if
(
batchSize
!=
PyArray_DIMS
(
top
)[
0
]
||
nFilters
!=
PyArray_DIMS
(
top
)[
1
]
||
nFilters
!=
PyArray_DIMS
(
top
)[
1
]
||
topHeight
!=
PyArray_DIMS
(
top
)[
2
]
||
topHeight
!=
PyArray_DIMS
(
top
)[
2
]
||
...
@@ -245,7 +253,23 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
...
@@ -245,7 +253,23 @@ PyArrayObject* corr3dMM(PyArrayObject* bottom,
char
Trans
=
'T'
;
char
Trans
=
'T'
;
PyArrayObject
*
output
;
PyArrayObject
*
output
;
if
(
direction
==
0
)
{
// forward pass
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
switch
(
direction
)
{
case
0
:
output
=
top
;
break
;
case
1
:
output
=
weight
;
break
;
case
2
:
output
=
bottom
;
break
;
default:
return
NULL
;
}
PyArray_FILLWBYTE
(
output
,
0
);
}
else
if
(
direction
==
0
)
{
// forward pass
output
=
top
;
output
=
top
;
// valid correlation: im3d2col, then gemm
// valid correlation: im3d2col, then gemm
// Iterate over batch
// Iterate over batch
...
...
theano/tensor/nnet/corr_gemm.c
浏览文件 @
c072d669
...
@@ -164,8 +164,15 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
...
@@ -164,8 +164,15 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
const
int
dil_kH
=
(
kH
-
1
)
*
dilH
+
1
;
const
int
dil_kH
=
(
kH
-
1
)
*
dilH
+
1
;
const
int
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
const
int
dil_kW
=
(
kW
-
1
)
*
dilW
+
1
;
// top: (batchSize, nFilters, topHeight, topWidth)
// top: (batchSize, nFilters, topHeight, topWidth)
const
int
topHeight
=
(
bottomHeight
+
2
*
padH
-
dil_kH
)
/
dH
+
1
;
const
int
topHeightNoDH
=
(
bottomHeight
+
2
*
padH
-
dil_kH
);
const
int
topWidth
=
(
bottomWidth
+
2
*
padW
-
dil_kW
)
/
dW
+
1
;
const
int
topWidthNoDW
=
(
bottomWidth
+
2
*
padW
-
dil_kW
);
// the above values might be negative so we need to use Python-like
// flooring integer division to be compatible with get_conv_output.
// note: this macro implements Python's // for negative x only
#define _CONV_FLOORDIV_X(x,y) ((x < 0) ? (- ((-x) / y) - (((-x) %% y) == 0 ? 0 : 1)) : (x / y))
const
int
topHeight
=
_CONV_FLOORDIV_X
(
topHeightNoDH
,
dH
)
+
1
;
const
int
topWidth
=
_CONV_FLOORDIV_X
(
topWidthNoDW
,
dW
)
+
1
;
#undef _CONV_FLOORDIV
if
(
batchSize
!=
PyArray_DIMS
(
top
)[
0
]
||
if
(
batchSize
!=
PyArray_DIMS
(
top
)[
0
]
||
nFilters
!=
PyArray_DIMS
(
top
)[
1
]
||
nFilters
!=
PyArray_DIMS
(
top
)[
1
]
||
topHeight
!=
PyArray_DIMS
(
top
)[
2
]
||
topHeight
!=
PyArray_DIMS
(
top
)[
2
]
||
...
@@ -219,7 +226,23 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
...
@@ -219,7 +226,23 @@ PyArrayObject* corrMM(PyArrayObject* bottom,
char
Trans
=
'T'
;
char
Trans
=
'T'
;
PyArrayObject
*
output
;
PyArrayObject
*
output
;
if
(
direction
==
0
)
{
// forward pass
if
(
batchSize
==
0
||
nChannels
==
0
||
nFilters
==
0
)
{
switch
(
direction
)
{
case
0
:
output
=
top
;
break
;
case
1
:
output
=
weight
;
break
;
case
2
:
output
=
bottom
;
break
;
default:
return
NULL
;
}
PyArray_FILLWBYTE
(
output
,
0
);
}
else
if
(
direction
==
0
)
{
// forward pass
output
=
top
;
output
=
top
;
// valid correlation: im2col, then gemm
// valid correlation: im2col, then gemm
// Iterate over batch
// Iterate over batch
...
...
theano/tensor/nnet/tests/test_abstract_conv.py
浏览文件 @
c072d669
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论