Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
59dcaf9c
提交
59dcaf9c
authored
8月 27, 2015
作者:
Arnaud Bergeron
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Convert GpuDnnConv to v3 (and use the new facilities).
上级
88ed910a
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
288 行增加
和
80 行删除
+288
-80
dnn.py
theano/sandbox/gpuarray/dnn.py
+113
-51
dnn_base.c
theano/sandbox/gpuarray/dnn_base.c
+45
-18
dnn_conv_base.c
theano/sandbox/gpuarray/dnn_conv_base.c
+2
-2
dnn_fwd.c
theano/sandbox/gpuarray/dnn_fwd.c
+128
-9
没有找到文件。
theano/sandbox/gpuarray/dnn.py
浏览文件 @
59dcaf9c
...
...
@@ -348,55 +348,97 @@ class GpuDnnConv(DnnBase, COp):
kernel
descr
The convolution descriptor.
workmem
Either 'none', 'small' or 'large'. Default is the value of
:attr:`config.dnn.conv.workmem`.
algo : {'small', 'none', 'large', 'fft', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_fwd`.
"""
__props__
=
(
'
workmem
'
,
'inplace'
)
__props__
=
(
'
algo
'
,
'inplace'
)
def
__init__
(
self
,
workmem
=
None
,
inplace
=
False
):
def
__init__
(
self
,
algo
=
None
,
inplace
=
False
):
COp
.
__init__
(
self
,
[
"dnn_base.c"
,
"dnn_conv_base.c"
,
"dnn_fwd.c"
],
"APPLY_SPECIFIC(conv_fwd)"
)
if
workmem
is
None
:
workmem
=
config
.
dnn
.
conv
.
workmem
self
.
workmem
=
workmem
if
algo
is
None
:
algo
=
config
.
dnn
.
conv
.
algo_fwd
self
.
algo
=
algo
self
.
inplace
=
inplace
if
self
.
inplace
:
self
.
destroy_map
=
{
0
:
[
2
]}
assert
self
.
workmem
in
[
'none'
,
'small'
,
'large'
]
if
version
()
<
3000
:
if
self
.
algo
==
'fft'
:
raise
RuntimeError
(
"CuDNN FFT convolution requires CuDNN v3"
)
elif
self
.
algo
in
[
'guess_once'
,
'guess_on_shape_change'
]:
raise
RuntimeError
(
"CuDNN selection of convolution "
"implementation based on heuristics "
"requires CuDNN v3"
)
elif
self
.
algo
in
[
'time_once'
,
'time_on_shape_change'
]:
raise
RuntimeError
(
"CuDNN convolution timing requires CuDNN v3"
)
assert
self
.
algo
in
[
'none'
,
'small'
,
'large'
,
'fft'
,
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
]
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
if
not
hasattr
(
self
,
'algo'
):
if
hasattr
(
self
,
'workmem'
):
self
.
algo
=
self
.
workmem
else
:
self
.
algo
=
config
.
dnn
.
conv
.
algo_fwd
if
not
hasattr
(
self
,
'inplace'
):
self
.
inplace
=
False
def
get_op_params
(
self
):
defs
=
[]
if
self
.
inplace
:
inpl_def
=
[(
'CONV_INPLACE'
,
'1'
)]
else
:
inpl_def
=
[]
if
version
()
==
-
1
:
alg_def
=
(
'CONV_ALGO'
,
"0"
)
else
:
if
self
.
workmem
==
'none'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'
elif
self
.
workmem
==
'small'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
elif
self
.
workmem
==
'large'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'
alg_def
=
(
'CONV_ALGO'
,
alg
)
return
[
alg_def
]
+
inpl_def
defs
.
append
((
'CONV_INPLACE'
,
'1'
))
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
if
self
.
algo
==
'none'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'
elif
self
.
algo
==
'small'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
elif
self
.
algo
==
'large'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'
elif
self
.
algo
==
'fft'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_FFT'
defs
.
append
((
'CONV_ALGO'
,
alg
))
if
self
.
algo
in
[
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_ALGO'
,
''
))
if
self
.
algo
in
[
'guess_once'
,
'time_once'
]:
defs
.
append
((
'CHOOSE_ONCE'
,
''
))
if
self
.
algo
in
[
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_TIME'
,
''
))
return
defs
def
make_node
(
self
,
img
,
kern
,
output
,
desc
,
alpha
=
None
,
beta
=
None
):
img
=
as_gpuarray_variable
(
img
)
kern
=
as_gpuarray_variable
(
kern
)
output
=
as_gpuarray_variable
(
output
)
if
img
.
type
.
ndim
!=
4
:
raise
TypeError
(
'img must be 4D tensor'
)
if
kern
.
type
.
ndim
!=
4
:
raise
TypeError
(
'kern must be 4D tensor'
)
if
output
.
type
.
ndim
!=
4
:
raise
TypeError
(
'output must be a 4D tensor'
)
if
not
isinstance
(
desc
.
type
,
CDataType
)
\
or
desc
.
type
.
ctype
!=
'cudnnConvolutionDescriptor_t'
:
if
img
.
type
.
ndim
not
in
(
4
,
5
):
raise
TypeError
(
'img must be 4D or 5D tensor'
)
if
kern
.
type
.
ndim
not
in
(
4
,
5
):
raise
TypeError
(
'kern must be 4D or 5D tensor'
)
if
output
.
type
.
ndim
not
in
(
4
,
5
):
raise
TypeError
(
'output must be a 4D or 5D tensor'
)
if
(
img
.
type
.
ndim
!=
kern
.
type
.
ndim
or
img
.
type
.
ndim
!=
output
.
type
.
ndim
):
raise
TypeError
(
"The number of dimensions of "
"img, kern and output must match"
)
if
img
.
type
.
ndim
==
5
and
self
.
algo
==
'fft'
:
raise
ValueError
(
"convolution algo fft can't be used for "
"3d convolutions"
)
if
(
not
isinstance
(
desc
.
type
,
CDataType
)
or
desc
.
type
.
ctype
!=
'cudnnConvolutionDescriptor_t'
):
raise
TypeError
(
'desc must be cudnnConvolutionDescriptor_t'
)
alpha
=
ensure_dt
(
alpha
,
_one
,
'alpha'
,
img
.
dtype
)
...
...
@@ -438,22 +480,41 @@ class GpuDnnConv(DnnBase, COp):
kh
=
kshape
[
2
]
# Height of each filter
kw
=
kshape
[
3
]
# Width of each filter
sh
,
sw
=
subsample
nd
=
len
(
subsample
)
if
nd
>
2
:
d
=
ishape
[
4
]
kd
=
ishape
[
4
]
sh
=
subsample
[
0
]
sw
=
subsample
[
1
]
if
nd
>
2
:
sd
=
subsample
[
2
]
if
border_mode
==
'full'
:
padh
=
kh
-
1
padw
=
kw
-
1
if
nd
>
4
:
padd
=
kd
-
1
elif
isinstance
(
border_mode
,
tuple
):
padh
,
padw
=
border_mode
padh
=
border_mode
[
0
]
padw
=
border_mode
[
1
]
if
nd
>
2
:
padd
=
border_mode
[
2
]
else
:
assert
border_mode
==
'valid'
padh
=
0
padw
=
0
padd
=
0
res
=
[
b
,
nb
,
(
h
+
2
*
padh
-
kh
)
//
sh
+
1
,
(
w
+
2
*
padw
-
kw
)
//
sw
+
1
]
if
nd
>
2
:
res
.
append
(
d
+
2
*
padd
-
kd
//
sd
+
1
)
return
(
b
,
nb
,
(
h
+
2
*
padh
-
kh
)
//
sh
+
1
,
(
w
+
2
*
padw
-
kw
)
//
sw
+
1
)
return
res
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
2
]]
...
...
@@ -607,7 +668,8 @@ class GpuDnnConvGradI(DnnBase):
def
dnn_conv
(
img
,
kerns
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
),
conv_mode
=
'conv'
,
direction_hint
=
None
,
workmem
=
None
):
conv_mode
=
'conv'
,
direction_hint
=
None
,
workmem
=
None
,
algo
=
None
):
"""
GPU convolution using cuDNN from NVIDIA.
...
...
@@ -631,19 +693,19 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
direction_hint
Used by graph optimizers to change algorithm choice.
By default, GpuDnnConv will be used to carry out the convolution.
If border_mode is 'valid', subsample is (1,1) and direction_hint is
If border_mode is 'valid', subsample is (1,
1) and direction_hint is
'bprop weights', it will use GpuDnnConvGradW.
If border_mode is 'full', subsample is (1,1) and direction_hint is
If border_mode is 'full', subsample is (1,
1) and direction_hint is
*not* 'forward!', it will use GpuDnnConvGradI.
This parameter is used internally by graph optimizers and may be
removed at any time without a deprecation period. You have been warned.
workmem
Specify the amount of working memory allowed. More memory is usuall
y
faster. One of 'none', 'small' or 'large' (default is None which take
s
its value from :attr:`config.dnn.conv.workmem`)
.
algo : {'none', 'small', 'large', 'fft', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Convolution implementation to use. Some of its values ma
y
require certain versions of CuDNN to be installed. Default i
s
the value of :attr:`config.dnn.conv.algo_fwd`
.
.. warning:: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higer.
This means that older GPU
will not
.. warning:: The cuDNN library only works with GPU
s
that have a compute
capability of 3.0 or higer.
This means that older GPUs
will not
work with this Op.
"""
...
...
@@ -696,7 +758,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
desc_op
.
border_mode
,
desc_op
.
subsample
)
out
=
GpuAllocEmpty
(
img
.
dtype
)(
*
out_shp
)
return
GpuDnnConv
(
workmem
=
workmem
)(
img
,
kerns
,
out
,
desc
)
return
GpuDnnConv
(
algo
=
algo
)(
img
,
kerns
,
out
,
desc
)
class
GpuDnnPoolDesc
(
Op
):
...
...
@@ -1563,7 +1625,7 @@ def local_dnn_conv_inplace(node):
isinstance
(
dest
.
owner
.
op
,
GpuAllocEmpty
)
and
len
(
dest
.
clients
)
>
1
):
inputs
[
2
]
=
GpuAllocEmpty
(
dest
.
owner
.
op
.
dtype
)(
*
dest
.
owner
.
inputs
)
return
[
GpuDnnConv
(
workmem
=
node
.
op
.
workmem
,
inplace
=
True
)(
*
inputs
)]
return
[
GpuDnnConv
(
algo
=
node
.
op
.
algo
,
inplace
=
True
)(
*
inputs
)]
@local_optimizer
([
GpuDnnConvGradW
],
inplace
=
True
)
...
...
@@ -1604,7 +1666,7 @@ optdb.register('local_dnna_conv_inplace',
def
local_dnn_conv_alpha_merge
(
node
,
*
inputs
):
if
not
dnn_available
()
or
version
()
==
-
1
:
return
None
return
[
GpuDnnConv
(
workmem
=
node
.
op
.
workmem
)(
*
inputs
)]
return
[
GpuDnnConv
(
algo
=
node
.
op
.
algo
)(
*
inputs
)]
@register_opt
(
'cudnn'
)
...
...
theano/sandbox/gpuarray/dnn_base.c
浏览文件 @
59dcaf9c
...
...
@@ -2,7 +2,7 @@
static
cudnnHandle_t
_handle
=
NULL
;
static
int
c_set_tensor
4
d
(
PyGpuArrayObject
*
var
,
cudnnTensorDescriptor_t
desc
)
{
c_set_tensor
N
d
(
PyGpuArrayObject
*
var
,
cudnnTensorDescriptor_t
desc
)
{
cudnnDataType_t
dt
;
size_t
ds
;
switch
(
var
->
ga
.
typecode
)
{
...
...
@@ -12,26 +12,37 @@ c_set_tensor4d(PyGpuArrayObject *var, cudnnTensorDescriptor_t desc) {
case
GA_DOUBLE
:
dt
=
CUDNN_DATA_DOUBLE
;
break
;
#ifdef CUDNN_VERSION > 3000
case
GA_HALF
:
dt
=
CUDNN_DATA_HALF
;
break
;
#endif
default:
PyErr_SetString
(
PyExc_TypeError
,
"Non-float datatype in c_set_tensor
4
d"
);
PyErr_SetString
(
PyExc_TypeError
,
"Non-float datatype in c_set_tensor
N
d"
);
return
-
1
;
}
ds
=
gpuarray_get_elsize
(
var
->
ga
.
typecode
);
int
str0
,
str1
,
str2
,
str3
;
// cudnn do not like 0s in strides
str3
=
PyGpuArray_STRIDES
(
var
)[
3
]
?
PyGpuArray_STRIDES
(
var
)[
3
]
/
ds
:
1
;
str2
=
PyGpuArray_STRIDES
(
var
)[
2
]
?
PyGpuArray_STRIDES
(
var
)[
2
]
/
ds
:
PyGpuArray_DIMS
(
var
)[
3
];
str1
=
PyGpuArray_STRIDES
(
var
)[
1
]
?
PyGpuArray_STRIDES
(
var
)[
1
]
/
ds
:
PyGpuArray_DIMS
(
var
)[
2
]
*
PyGpuArray_DIMS
(
var
)[
3
];
str0
=
PyGpuArray_STRIDES
(
var
)[
0
]
?
PyGpuArray_STRIDES
(
var
)[
0
]
/
ds
:
PyGpuArray_DIMS
(
var
)[
2
]
*
PyGpuArray_DIMS
(
var
)[
3
]
*
PyGpuArray_DIMS
(
var
)[
1
];
cudnnStatus_t
err
=
cudnnSetTensor4dDescriptorEx
(
desc
,
dt
,
PyGpuArray_DIM
(
var
,
0
),
PyGpuArray_DIM
(
var
,
1
),
PyGpuArray_DIM
(
var
,
2
),
PyGpuArray_DIM
(
var
,
3
),
str0
,
str1
,
str2
,
str3
);
int
strs
[
5
],
dims
[
5
],
default_stride
=
1
;
unsigned
int
nd
=
PyGpuArray_NDIM
(
var
);
if
(
nd
>
5
)
{
PyErr_SetString
(
PyExc_TypeError
,
"Tensor of more than 5d"
);
return
-
1
;
}
for
(
unsigned
int
_i
=
nd
;
_i
>
0
;
_i
--
)
{
unsigned
int
i
=
_i
-
1
;
strs
[
i
]
=
PyGpuArray_STRIDE
(
var
,
i
)
?
PyGpuArray_STRIDE
(
var
,
i
)
/
ds
:
default_stride
;
default_stride
*=
PyGpuArray_DIM
(
var
,
i
);
dims
[
i
]
=
PyGpuArray_DIM
(
var
,
i
);
}
cudnnStatus_t
err
=
cudnnSetTensorNdDescriptor
(
desc
,
dt
,
nd
,
dims
,
strs
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Could not set tensor
4
d descriptor: %s"
,
"Could not set tensor
N
d descriptor: %s"
,
cudnnGetErrorString
(
err
));
return
-
1
;
}
...
...
@@ -53,14 +64,30 @@ c_set_filter(PyGpuArrayObject *var, cudnnFilterDescriptor_t desc) {
case
GA_DOUBLE
:
dt
=
CUDNN_DATA_DOUBLE
;
break
;
#ifdef CUDNN_VERSION > 3000
case
GA_HALF
:
dt
=
CUDNN_DATA_HALF
;
break
;
#endif
default:
PyErr_SetString
(
PyExc_TypeError
,
"Non-float datatype in c_set_filter"
);
return
-
1
;
}
cudnnStatus_t
err
=
cudnnSetFilter4dDescriptor
(
desc
,
dt
,
PyGpuArray_DIMS
(
var
)[
0
],
PyGpuArray_DIMS
(
var
)[
1
],
PyGpuArray_DIMS
(
var
)[
2
],
PyGpuArray_DIMS
(
var
)[
3
]);
int
dims
[
5
];
unsigned
int
nd
=
PyGpuArray_NDIM
(
var
);
if
(
nd
>
5
)
{
PyErr_SetString
(
PyExc_TypeError
,
"Tensor of more than 5d"
);
return
-
1
;
}
for
(
unsigned
int
_i
=
nd
;
_i
>
0
;
_i
--
)
{
unsigned
int
i
=
_i
-
1
;
dims
[
i
]
=
PyGpuArray_DIM
(
var
,
i
);
}
cudnnStatus_t
err
=
cudnnSetFilterNdDescriptor
(
desc
,
dt
,
nd
,
dims
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Could not set filter descriptor: %s."
,
...
...
theano/sandbox/gpuarray/dnn_conv_base.c
浏览文件 @
59dcaf9c
...
...
@@ -10,12 +10,12 @@ APPLY_SPECIFIC(input) = NULL;
APPLY_SPECIFIC
(
output
)
=
NULL
;
APPLY_SPECIFIC
(
kerns
)
=
NULL
;
if
((
APPLY_SPECIFIC
(
err
)
=
cudnnCreateTensorDescriptor
(
&
APPLY_SPECIFIC
(
input
)))
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor
4d
descriptor "
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor descriptor "
"(inp): %s"
,
cudnnGetErrorString
(
APPLY_SPECIFIC
(
err
)));
FAIL
;
}
if
((
APPLY_SPECIFIC
(
err
)
=
cudnnCreateTensorDescriptor
(
&
APPLY_SPECIFIC
(
output
)))
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor
4d
descriptor "
PyErr_Format
(
PyExc_MemoryError
,
"could not allocate tensor descriptor "
"(out): %s"
,
cudnnGetErrorString
(
APPLY_SPECIFIC
(
err
)));
FAIL
;
}
...
...
theano/sandbox/gpuarray/dnn_fwd.c
浏览文件 @
59dcaf9c
...
...
@@ -13,11 +13,11 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
if
(
PyGpuArray_DIMS
(
input
)[
1
]
!=
PyGpuArray_DIMS
(
kerns
)[
1
])
{
PyErr_SetString
(
PyExc_ValueError
,
"
GpuDnnConv
images and kernel must have the same stack size"
);
"images and kernel must have the same stack size"
);
return
1
;
}
if
(
c_set_tensor
4
d
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
if
(
c_set_tensor
N
d
(
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
if
(
c_set_filter
(
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
...
...
@@ -28,6 +28,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
beta_p
=
(
void
*
)
&
beta
;
break
;
case
GA_FLOAT
:
case
GA_HALF
:
alpha_p
=
(
void
*
)
&
af
;
beta_p
=
(
void
*
)
&
bf
;
break
;
...
...
@@ -49,29 +50,147 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
return
1
;
#endif
if
(
c_set_tensor
4
d
(
*
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
if
(
c_set_tensor
N
d
(
*
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
cudnnConvolutionFwdAlgo_t
algo
=
CONV_ALGO
;
#ifdef CHOOSE_ALGO
/* Static variables are only initialized once so this will not
* reset the previous algo every time */
static
int
reuse_algo
=
0
;
static
cudnnConvolutionFwdAlgo_t
prev_algo
=
CONV_ALGO
;
#ifndef CHOOSE_ONCE
static
size_t
prev_img_dims
[
5
]
=
{
0
};
static
size_t
prev_kern_dims
[
5
]
=
{
0
};
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
input
,
i
)
==
prev_img_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
kerns
,
i
)
==
prev_kern_dims
[
i
]);
}
#endif
if
(
!
reuse_algo
)
{
#ifdef CHOOSE_TIME
int
count
;
cudnnConvolutionFwdAlgoPerf_t
choice
;
err
=
cudnnFindConvolutionForwardAlgorithm
(
_handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
desc
,
APPLY_SPECIFIC
(
output
),
1
,
&
count
,
&
choice
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
return
1
;
}
algo
=
choice
.
algo
;
#else
size_t
free
=
0
,
total
=
0
;
cudaError_t
err2
=
cudaMemGetInfo
(
&
free
,
&
total
);
if
(
err2
!=
cudaSuccess
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU: %s
\n
"
,
cudaGetErrorString
(
err2
));
return
1
;
}
err
=
cudnnGetConvolutionForwardAlgorithm
(
_handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
desc
,
APPLY_SPECIFIC
(
output
),
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
free
,
&
algo
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
return
1
;
}
#endif
prev_algo
=
algo
;
}
else
{
algo
=
prev_algo
;
}
#ifdef CHOOSE_ONCE
reuse_algo
=
1
;
#else
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
prev_img_dims
[
i
]
=
PyGpuArray_DIM
(
input
,
i
);
prev_kern_dims
[
i
]
=
PyGpuArray_DIM
(
kerns
,
i
);
}
#endif
#endif
/* These two algos are not supported for 3d conv */
if
(
PyGpuArray_NDIM
(
input
)
==
5
&&
(
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
||
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_GEMM
))
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
#if CUDNN_VERSION > 3000
if
(
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_FFT
)
{
int
nd
;
int
pad
[
2
];
int
stride
[
2
];
int
upscale
[
2
];
cudnnConvolutionMode_t
mode
;
err
=
cudnnGetConvolutionNdDescriptor
(
desc
,
2
,
&
nd
,
pad
,
stride
,
upscale
,
&
mode
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error getting convolution properties: %s"
,
cudnnGetErrorString
(
err
));
return
1
;
}
if
(
stride
[
0
]
!=
1
||
stride
[
1
]
!=
1
||
PyGpuArray_DIM
(
input
,
0
)
>
1024
||
PyGpuArray_DIM
(
input
,
1
)
>
1024
||
(
PyGpuArray_DIM
(
kerns
,
0
)
==
1
&&
PyGpuArray_DIM
(
kerns
,
1
)
==
1
))
{
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
;
}
}
#endif
#if CUDNN_VERSION < 3000
/* cuDNN before v3 does not support kernels larger than input even
* if appropriate padding is selected. */
for
(
unsigned
int
i
=
2
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
if
(
PyGpuArray_DIM
(
kerns
,
i
)
>
PyGpuArray_DIM
(
input
,
i
))
{
PyErr_SetString
(
PyExc_RuntimeError
,
"the current version "
"of CuDNN does not support kernels larger than the "
"inputs in any spatial dimension, even if the inputs "
"are padded such that the padded inputs are larger "
"than the kernels. Update your installation of CuDNN "
"to V3 or more recent to solve the issue."
);
return
1
;
}
}
#endif
{
size_t
worksize
;
gpudata
*
workspace
;
PyGpuContextObject
*
c
;
err
=
cudnnGetConvolutionForwardWorkspaceSize
(
_handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
desc
,
APPLY_SPECIFIC
(
output
),
CONV_ALGO
,
algo
,
&
worksize
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"
GpuDnnConv:
error getting worksize: %s"
,
"error getting worksize: %s"
,
cudnnGetErrorString
(
err
));
return
1
;
}
/*
/*
* This is less than ideal since we need to free it after (which
* introduces a synchronization point. But we don't have a module
* to place a nice get_work_mem() function in.
...
...
@@ -91,7 +210,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
alpha_p
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
input
),
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
kerns
),
desc
,
CONV_ALGO
,
desc
,
algo
,
worksize
==
0
?
NULL
:
*
(
void
**
)
workspace
,
worksize
,
beta_p
,
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
*
output
));
...
...
@@ -101,7 +220,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
}
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"
GpuDnnConv:
error doing operation: %s"
,
PyErr_Format
(
PyExc_RuntimeError
,
"error doing operation: %s"
,
cudnnGetErrorString
(
err
));
return
1
;
}
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论