Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
46783773
提交
46783773
authored
5月 30, 2017
作者:
abergeron
提交者:
GitHub
5月 30, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5964 from notoraptor/cudnn-less-ccode-and-v6
Wrap Op params for many gpuarray DNN Ops and add cuDNN v6 integration.
上级
3baa162f
101fbd90
隐藏空白字符变更
内嵌
并排
正在显示
15 个修改的文件
包含
677 行增加
和
625 行删除
+677
-625
configdefaults.py
theano/configdefaults.py
+9
-11
op.py
theano/gof/op.py
+4
-1
type.py
theano/gof/type.py
+9
-1
conv_desc.c
theano/gpuarray/conv_desc.c
+22
-21
cudnn_defs.py
theano/gpuarray/cudnn_defs.py
+132
-0
dnn.py
theano/gpuarray/dnn.py
+141
-239
dnn_batchnorm_grad.c
theano/gpuarray/dnn_batchnorm_grad.c
+3
-3
dnn_batchnorm_inf.c
theano/gpuarray/dnn_batchnorm_inf.c
+11
-11
dnn_fwd.c
theano/gpuarray/dnn_fwd.c
+98
-106
dnn_gi.c
theano/gpuarray/dnn_gi.c
+108
-96
dnn_gw.c
theano/gpuarray/dnn_gw.c
+95
-100
dnn_pool.c
theano/gpuarray/dnn_pool.c
+3
-3
dnn_pool_grad.c
theano/gpuarray/dnn_pool_grad.c
+3
-3
test_dnn.py
theano/gpuarray/tests/test_dnn.py
+33
-30
pool.py
theano/tensor/signal/pool.py
+6
-0
没有找到文件。
theano/configdefaults.py
浏览文件 @
46783773
...
...
@@ -268,19 +268,18 @@ def safe_no_dnn_algo_bwd(algo):
'`dnn.conv.algo_bwd_filter` and `dnn.conv.algo_bwd_data` instead.'
)
return
True
# Those are the options provided by Theano to choose algorithms at runtime.
SUPPORTED_DNN_CONV_ALGO_RUNTIME
=
(
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
)
# These are the algorithms supported by Theano.
# The tests will reference those lists.
SUPPORTED_DNN_CONV_ALGO_FWD
=
(
'small'
,
'none'
,
'large'
,
'fft'
,
'fft_tiling'
,
'winograd'
,
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
)
SUPPORTED_DNN_CONV_ALGO_FWD
=
(
'small'
,
'none'
,
'large'
,
'fft'
,
'fft_tiling'
,
'winograd'
)
+
SUPPORTED_DNN_CONV_ALGO_RUNTIME
SUPPORTED_DNN_CONV_ALGO_BWD_DATA
=
(
'none'
,
'deterministic'
,
'fft'
,
'fft_tiling'
,
'winograd'
)
+
SUPPORTED_DNN_CONV_ALGO_RUNTIME
SUPPORTED_DNN_CONV_ALGO_BWD_DATA
=
(
'none'
,
'deterministic'
,
'fft'
,
'fft_tiling'
,
'winograd'
,
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
)
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
=
(
'none'
,
'deterministic'
,
'fft'
,
'small'
)
+
SUPPORTED_DNN_CONV_ALGO_RUNTIME
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
=
(
'none'
,
'deterministic'
,
'fft'
,
'small'
,
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
)
SUPPORTED_DNN_CONV_PRECISION
=
(
'as_input_f32'
,
'as_input'
,
'float16'
,
'float32'
,
'float64'
)
AddConfigVar
(
'dnn.conv.algo_bwd'
,
"This flag is deprecated; use dnn.conv.algo_bwd_data and "
...
...
@@ -311,8 +310,7 @@ AddConfigVar('dnn.conv.precision',
"Default data precision to use for the computation in cuDNN "
"convolutions (defaults to the same dtype as the inputs of the "
"convolutions, or float32 if inputs are float16)."
,
EnumStr
(
'as_input_f32'
,
'as_input'
,
'float16'
,
'float32'
,
'float64'
),
EnumStr
(
*
SUPPORTED_DNN_CONV_PRECISION
),
in_c_key
=
False
)
...
...
theano/gof/op.py
浏览文件 @
46783773
...
...
@@ -1413,7 +1413,10 @@ class COp(Op):
return
[]
def
c_code_cache_version
(
self
):
return
hash
(
tuple
(
self
.
func_codes
))
version
=
(
hash
(
tuple
(
self
.
func_codes
)),
)
if
hasattr
(
self
,
'params_type'
):
version
+=
(
self
.
params_type
.
c_code_cache_version
(),
)
return
version
def
c_init_code
(
self
):
"""
...
...
theano/gof/type.py
浏览文件 @
46783773
...
...
@@ -963,6 +963,12 @@ class EnumType(Type, dict):
"""
return
alias
in
self
.
aliases
def get_aliases(self):
    """
    Return the list of all aliases in this enumeration.

    A new list is built on every call, so the result is a real ``list``
    on both Python 2 and Python 3 (where ``dict.keys()`` alone would
    return a live view) and is not affected by later changes to
    ``self.aliases``.
    """
    return list(self.aliases.keys())
def
__repr__
(
self
):
names_to_aliases
=
{
constant_name
:
''
for
constant_name
in
self
}
for
alias
in
self
.
aliases
:
...
...
@@ -1184,4 +1190,6 @@ class CEnumType(EnumList):
fail
=
sub
[
'fail'
])
def
c_code_cache_version
(
self
):
return
(
1
,
super
(
CEnumType
,
self
)
.
c_code_cache_version
())
# C code depends on (C constant name, Python value) associations (given by `self.items()`),
# so we should better take them into account in C code version.
return
(
1
,
tuple
(
self
.
items
()),
super
(
CEnumType
,
self
)
.
c_code_cache_version
())
theano/gpuarray/conv_desc.c
浏览文件 @
46783773
#section support_code_apply
int
APPLY_SPECIFIC
(
conv_desc
)(
PyArrayObject
*
filt_shp
,
cudnnConvolutionDescriptor_t
*
desc
)
{
cudnnConvolutionDescriptor_t
*
desc
,
PARAMS_TYPE
*
params
)
{
cudnnStatus_t
err
;
int
pad
[
3
]
=
{
PAD_0
,
PAD_1
,
PAD_
2
};
int
strides
[
3
]
=
{
SUB_0
,
SUB_1
,
SUB_
2
};
int
dilation
[
3
]
=
{
DIL_0
,
DIL_1
,
DIL_
2
};
int
pad
[
3
]
=
{
params
->
pad0
,
params
->
pad1
,
params
->
pad
2
};
int
strides
[
3
]
=
{
params
->
sub0
,
params
->
sub1
,
params
->
sub
2
};
int
dilation
[
3
]
=
{
params
->
dil0
,
params
->
dil1
,
params
->
dil
2
};
#if BORDER_MODE == 0
pad
[
0
]
=
(
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
2
)
-
1
)
*
DIL_0
;
pad
[
1
]
=
(
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
3
)
-
1
)
*
DIL_1
;
#if NB_DIMS > 2
pad
[
2
]
=
(
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
4
)
-
1
)
*
DIL_2
;
#endif
#elif BORDER_MODE == 2
pad
[
0
]
=
((
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
2
)
-
1
)
*
DIL_0
+
1
)
/
2
;
pad
[
1
]
=
((
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
3
)
-
1
)
*
DIL_1
+
1
)
/
2
;
#if NB_DIMS > 2
pad
[
2
]
=
((
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
4
)
-
1
)
*
DIL_2
+
1
)
/
2
;
#endif
#endif
if
(
params
->
bmode
==
BORDER_MODE_FULL
)
{
pad
[
0
]
=
(
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
2
)
-
1
)
*
dilation
[
0
]
;
pad
[
1
]
=
(
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
3
)
-
1
)
*
dilation
[
1
]
;
if
(
params
->
nb_dims
>
2
)
{
pad
[
2
]
=
(
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
4
)
-
1
)
*
dilation
[
2
]
;
}
}
else
if
(
params
->
bmode
==
BORDER_MODE_HALF
)
{
pad
[
0
]
=
((
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
2
)
-
1
)
*
dilation
[
0
]
+
1
)
/
2
;
pad
[
1
]
=
((
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
3
)
-
1
)
*
dilation
[
1
]
+
1
)
/
2
;
if
(
params
->
nb_dims
>
2
)
{
pad
[
2
]
=
((
*
(
npy_int64
*
)
PyArray_GETPTR1
(
filt_shp
,
4
)
-
1
)
*
dilation
[
2
]
+
1
)
/
2
;
}
}
if
(
PyArray_DIM
(
filt_shp
,
0
)
-
2
!=
NB_DIMS
)
{
if
(
PyArray_DIM
(
filt_shp
,
0
)
-
2
!=
params
->
nb_dims
)
{
PyErr_Format
(
PyExc_ValueError
,
"Filter shape has too many dimensions: "
"expected %d, got %lld."
,
NB_DIMS
,
"expected %d, got %lld."
,
params
->
nb_dims
,
(
long
long
)
PyArray_DIM
(
filt_shp
,
0
));
return
-
1
;
}
...
...
@@ -35,8 +36,8 @@ int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp,
return
-
1
;
}
err
=
cudnnSetConvolutionNdDescriptor
(
*
desc
,
NB_DIMS
,
pad
,
strides
,
dilation
,
CONV_MODE
,
PRECISION
);
err
=
cudnnSetConvolutionNdDescriptor
(
*
desc
,
params
->
nb_dims
,
pad
,
strides
,
dilation
,
params
->
conv_mode
,
params
->
precision
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_MemoryError
,
"could not set convolution "
"descriptor: %s"
,
cudnnGetErrorString
(
err
));
...
...
theano/gpuarray/cudnn_defs.py
0 → 100644
浏览文件 @
46783773
"""
Declarations of cuDNN types and constants used in Theano gpuarray DNN module.
For every cuDNN API supported by Theano, this module defines a class that
provides the set of cuDNN definitions to be used in Theano Ops.
Use :func:`get_definitions` to get the right cuDNN definitions
for a given cuDNN version.
Currently supported cuDNN APIs:
- v5.1
- v6.0
"""
from
__future__
import
absolute_import
,
print_function
,
division
from
theano.gof
import
CEnumType
# NB: Some cuDNN algorithms are listed in cuDNN enums but not implemented.
# We still register them here because we try to exactly copy cuDNN enums
# in Python side, but they will have no aliases associated, to help
# exclude them from lists of supported algorithms.
class CuDNNV51(object):
    """
    cuDNN types and constants used by Theano for the cuDNN v5.1 API.

    Every ``cudnn*_t`` attribute is a :class:`CEnumType` declared from
    entries of the form ``(C constant name, alias)``, with ``ctype`` set to
    the name of the corresponding cuDNN C enum type. An entry written as a
    bare constant name has no alias: it is registered but cannot be
    selected through an alias.

    The ``conv3d_*_algorithms`` tuples list the algorithm aliases that the
    convolution Ops accept for 3D convolutions.

    """

    # Major cuDNN version targeted by these definitions.
    version = 5

    # Convolution modes.
    cudnnConvolutionMode_t = CEnumType(
        ('CUDNN_CONVOLUTION', 'conv'),
        ('CUDNN_CROSS_CORRELATION', 'cross'),
        ctype='cudnnConvolutionMode_t')

    # Data types; aliases are dtype names.
    cudnnDataType_t = CEnumType(
        ('CUDNN_DATA_FLOAT', 'float32'),
        ('CUDNN_DATA_DOUBLE', 'float64'),
        ('CUDNN_DATA_HALF', 'float16'),
        # CUDNN_DATA_INT8  # new in v6
        # CUDNN_DATA_INT32  # new in v6
        # CUDNN_DATA_INT8x4  # new in v6
        ctype='cudnnDataType_t')

    # Forward convolution algorithms.
    cudnnConvolutionFwdAlgo_t = CEnumType(
        ('CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM', 'none'),
        ('CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM', 'small'),
        ('CUDNN_CONVOLUTION_FWD_ALGO_GEMM', 'large'),
        # not implemented:
        ('CUDNN_CONVOLUTION_FWD_ALGO_DIRECT'),
        ('CUDNN_CONVOLUTION_FWD_ALGO_FFT', 'fft'),
        ('CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING', 'fft_tiling'),
        ('CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD', 'winograd'),
        # TODO: Not yet tested/documented:
        ('CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
        ctype='cudnnConvolutionFwdAlgo_t')

    # Forward algorithm aliases accepted for 3D convolutions.
    conv3d_fwd_algorithms = ('none', 'small', 'fft_tiling')

    # Algorithms for the gradient with respect to the filters.
    cudnnConvolutionBwdFilterAlgo_t = CEnumType(
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0', 'none'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3', 'small'),
        # not implemented:
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD'),
        # TODO: not yet tested/documented:
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
        ctype='cudnnConvolutionBwdFilterAlgo_t')

    # Backward-filter algorithm aliases accepted for 3D convolutions.
    conv3d_bwd_filter_algorithms = ('none', 'small')

    # Algorithms for the gradient with respect to the input data.
    cudnnConvolutionBwdDataAlgo_t = CEnumType(
        ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_0', 'none'),
        ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_1', 'deterministic'),
        ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT', 'fft'),
        ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING', 'fft_tiling'),
        ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD', 'winograd'),
        # TODO: not yet tested/documented:
        ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
        ctype='cudnnConvolutionBwdDataAlgo_t')

    # Backward-data algorithm aliases accepted for 3D convolutions.
    conv3d_bwd_data_algorithms = ('none', 'deterministic', 'fft_tiling')

    # Pooling modes.
    cudnnPoolingMode_t = CEnumType(
        ('CUDNN_POOLING_MAX', 'max'),
        ('CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING', 'average_inc_pad'),
        ('CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING', 'average_exc_pad'),
        ctype='cudnnPoolingMode_t')

    # Softmax algorithms.
    cudnnSoftmaxAlgorithm_t = CEnumType(
        ('CUDNN_SOFTMAX_FAST', 'fast'),
        ('CUDNN_SOFTMAX_ACCURATE', 'accurate'),
        ('CUDNN_SOFTMAX_LOG', 'log'),
        ctype='cudnnSoftmaxAlgorithm_t')

    # Softmax modes.
    cudnnSoftmaxMode_t = CEnumType(
        ('CUDNN_SOFTMAX_MODE_INSTANCE', 'instance'),
        ('CUDNN_SOFTMAX_MODE_CHANNEL', 'channel'),
        ctype='cudnnSoftmaxMode_t')

    # Batch normalization modes.
    cudnnBatchNormMode_t = CEnumType(
        ('CUDNN_BATCHNORM_PER_ACTIVATION', 'per-activation'),
        ('CUDNN_BATCHNORM_SPATIAL', 'spatial'),
        ctype='cudnnBatchNormMode_t')
class CuDNNV6(CuDNNV51):
    """
    cuDNN types and constants used by Theano for the cuDNN v6.0 API.

    Inherits every definition from :class:`CuDNNV51` and redefines only the
    enumerations that gained constants in v6: the pooling modes and the
    backward-filter convolution algorithms.

    """

    # Major cuDNN version targeted by these definitions.
    version = 6

    # Pooling modes: the v5.1 list plus the deterministic max pooling mode.
    cudnnPoolingMode_t = CEnumType(
        ('CUDNN_POOLING_MAX', 'max'),
        ('CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING', 'average_inc_pad'),
        ('CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING', 'average_exc_pad'),
        # new in v6:
        ('CUDNN_POOLING_MAX_DETERMINISTIC', 'max_deterministic'),
        ctype='cudnnPoolingMode_t')

    # Backward-filter algorithms: the v5.1 list plus the FFT tiling algorithm.
    # The WINOGRAD entry is still declared without an alias.
    cudnnConvolutionBwdFilterAlgo_t = CEnumType(
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0', 'none'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3', 'small'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD'),
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
        # TODO: not yet tested/documented:
        # new in v6:
        ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING', 'fft_tiling'),
        ctype='cudnnConvolutionBwdFilterAlgo_t')
def get_definitions(cudnn_version=None):
    """
    Pick the cuDNN definitions object Theano should use for a cuDNN version.

    ``cudnn_version`` is either None or an integer version number
    (typically the value returned by :func:`theano.gpuarray.dnn.version`).

    A version whose thousands part is 5 selects the v5.1 definitions.
    Anything else, including None, selects the definitions for the most
    recent supported cuDNN version.

    """
    if cudnn_version is None:
        major = None
    else:
        major = cudnn_version // 1000

    if major == 5:
        return CuDNNV51()

    # Fall back to the definitions for the latest supported cuDNN version.
    return CuDNNV6()
theano/gpuarray/dnn.py
浏览文件 @
46783773
...
...
@@ -9,10 +9,10 @@ from six import integer_types
import
theano
from
theano
import
Op
,
Apply
,
tensor
,
config
,
Variable
from
theano.scalar
import
as_scalar
,
constant
,
Log
,
get_scalar_type
from
theano.scalar
import
as_scalar
,
constant
,
Log
,
get_scalar_type
,
int32
as
int_t
,
bool
as
bool_t
from
theano.tensor
import
as_tensor_variable
from
theano.gradient
import
DisconnectedType
,
grad_not_implemented
from
theano.gof
import
Optimizer
,
local_optimizer
,
COp
,
ParamsType
,
CEnumType
from
theano.gof
import
Optimizer
,
local_optimizer
,
COp
,
ParamsType
,
EnumList
from
theano.gof.cmodule
import
GCC_compiler
from
theano.gof.type
import
CDataType
,
Generic
from
theano.compile
import
optdb
...
...
@@ -28,7 +28,7 @@ from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
assert_conv_shape
)
from
theano.tensor.signal.pool
import
(
Pool
,
MaxPoolGrad
,
AveragePoolGrad
)
from
.
import
pygpu
from
.
import
pygpu
,
cudnn_defs
from
.type
import
(
get_context
,
gpu_context_type
,
list_contexts
,
GpuArraySharedVariable
)
from
.basic_ops
import
(
as_gpuarray_variable
,
infer_context_name
,
...
...
@@ -44,7 +44,10 @@ from .opt import (gpu_seqopt, register_opt, pool_db, pool_db2,
from
.opt_util
import
alpha_merge
,
output_merge
,
inplace_allocempty
,
pad_dims
,
unpad_dims
from
theano.configdefaults
import
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
from
theano.configdefaults
import
SUPPORTED_DNN_CONV_ALGO_RUNTIME
DNN_CONV_ALGO_CHOOSE_ONCE
=
[
'guess_once'
,
'time_once'
]
DNN_CONV_ALGO_CHOOSE_TIME
=
[
'time_once'
,
'time_on_shape_change'
]
try
:
from
pygpu
import
gpuarray
...
...
@@ -59,12 +62,12 @@ def _dnn_lib():
lib_name
=
ctypes
.
util
.
find_library
(
'cudnn'
)
if
lib_name
is
None
and
sys
.
platform
==
'win32'
:
# Update these names when new versions of cudnn are supported.
for
name
in
[
'cudnn64_5.dll'
]:
for
name
in
[
'cudnn64_
6.dll'
,
'cudnn64_
5.dll'
]:
lib_name
=
ctypes
.
util
.
find_library
(
name
)
if
lib_name
:
break
if
lib_name
is
None
:
raise
RuntimeError
(
'Could not find cudnn library (looked for v5
[.1]
)'
)
raise
RuntimeError
(
'Could not find cudnn library (looked for v5
* or v6*
)'
)
_dnn_lib
.
handle
=
ctypes
.
cdll
.
LoadLibrary
(
lib_name
)
cudnn
=
_dnn_lib
.
handle
cudnn
.
cudnnCreate
.
argtypes
=
[
ctypes
.
POINTER
(
ctypes
.
c_void_p
)]
...
...
@@ -116,10 +119,14 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
# default gpu, not the one selected by the user. If mixed
# GPU are installed or if the GPUs are configured in
# exclusive mode, this cause bad detection.
avail
,
out
,
err
=
GCC_compiler
.
try_flags
(
# NB: GCC_compiler.try_flags() may return just a boolean instead of a tuple (avail, out, here).
compiler_res
=
GCC_compiler
.
try_flags
(
params
,
preambule
=
preambule
,
body
=
body
,
try_run
=
False
,
output
=
True
)
avail
,
out
,
err
=
compiler_res
if
isinstance
(
compiler_res
,
tuple
)
else
(
compiler_res
,
None
,
None
)
if
not
avail
:
return
False
,
(
"cannot compile with cuDNN. "
"We got this error:
\n
"
+
str
(
err
))
...
...
@@ -129,13 +136,12 @@ if ((err = cudnnCreate(&_handle)) != CUDNN_STATUS_SUCCESS) {
def
_dnn_check_version
():
v
=
version
()
if
v
<
5000
:
return
False
,
"cuDNN version is too old. Update to v5, was
%
d."
%
v
# 5200 should not print warning with cudnn 5.1 final.
return
False
,
"cuDNN version is too old. Update to v5* or higher, was
%
d."
%
v
if
v
>=
6100
:
warnings
.
warn
(
"Your cuDNN version is more recent than "
"Theano. If you encounter problems, try "
"updating Theano or downgrading cuDNN to "
"
version 6.0
."
)
"
a version >= v5 and <= v6
."
)
return
True
,
None
...
...
@@ -281,6 +287,9 @@ handle_type = CDataType('cudnnHandle_t', 'cudnnDestroy',
lib_dirs
=
[
config
.
dnn
.
library_path
],
version
=
version
(
raises
=
False
))
# Get cuDNN definitions to be used.
cudnn
=
cudnn_defs
.
get_definitions
(
version
(
raises
=
False
))
def
get_precision
(
precision
,
inputs
):
if
precision
is
None
:
...
...
@@ -367,6 +376,15 @@ class GpuDnnConvDesc(COp):
"""
__props__
=
(
'border_mode'
,
'subsample'
,
'dilation'
,
'conv_mode'
,
'precision'
)
params_type
=
ParamsType
(
pad0
=
int_t
,
pad1
=
int_t
,
pad2
=
int_t
,
sub0
=
int_t
,
sub1
=
int_t
,
sub2
=
int_t
,
dil0
=
int_t
,
dil1
=
int_t
,
dil2
=
int_t
,
nb_dims
=
int_t
,
bmode
=
EnumList
((
'BORDER_MODE_FULL'
,
'full'
),
(
'BORDER_MODE_VALID'
,
'valid'
),
(
'BORDER_MODE_HALF'
,
'half'
)),
conv_mode
=
cudnn
.
cudnnConvolutionMode_t
,
precision
=
cudnn
.
cudnnDataType_t
)
def
c_headers
(
self
):
return
[
'cudnn.h'
,
'cudnn_helper.h'
]
...
...
@@ -404,13 +422,13 @@ class GpuDnnConvDesc(COp):
self
.
border_mode
=
border_mode
assert
len
(
subsample
)
in
(
2
,
3
)
self
.
subsample
=
subsample
assert
c
onv_mode
in
(
'conv'
,
'cross'
)
assert
c
udnn
.
cudnnConvolutionMode_t
.
has_alias
(
conv_mode
)
self
.
conv_mode
=
conv_mode
assert
len
(
dilation
)
==
len
(
subsample
)
self
.
dilation
=
dilation
assert
precision
in
[
'float16'
,
'float32'
,
'float64'
]
assert
cudnn
.
cudnnDataType_t
.
has_alias
(
precision
)
self
.
precision
=
precision
def
make_node
(
self
,
kern_shape
):
...
...
@@ -430,59 +448,18 @@ class GpuDnnConvDesc(COp):
out
.
tag
.
values_eq_approx
=
tensor
.
type
.
values_eq_approx_always_true
return
node
def
get_op_params
(
self
):
pad0
=
'0'
pad1
=
'0'
pad2
=
'0'
if
isinstance
(
self
.
border_mode
,
tuple
):
pad0
=
str
(
self
.
border_mode
[
0
])
pad1
=
str
(
self
.
border_mode
[
1
])
if
len
(
self
.
border_mode
)
>
2
:
pad2
=
str
(
self
.
border_mode
[
2
])
bmode
=
'1'
elif
self
.
border_mode
==
"valid"
:
bmode
=
'1'
elif
self
.
border_mode
==
"half"
:
bmode
=
'2'
elif
self
.
border_mode
==
"full"
:
bmode
=
'0'
else
:
raise
ValueError
(
"Invalid value for border_mode"
)
if
self
.
conv_mode
==
'conv'
:
conv_flag
=
'CUDNN_CONVOLUTION'
else
:
conv_flag
=
'CUDNN_CROSS_CORRELATION'
sub0
=
str
(
self
.
subsample
[
0
])
sub1
=
str
(
self
.
subsample
[
1
])
if
len
(
self
.
subsample
)
>
2
:
sub2
=
str
(
self
.
subsample
[
2
])
else
:
sub2
=
'0'
dil0
=
str
(
self
.
dilation
[
0
])
dil1
=
str
(
self
.
dilation
[
1
])
if
len
(
self
.
dilation
)
>
2
:
dil2
=
str
(
self
.
dilation
[
2
])
else
:
dil2
=
'0'
if
self
.
precision
==
'float16'
:
precision
=
'CUDNN_DATA_HALF'
elif
self
.
precision
==
'float32'
:
precision
=
'CUDNN_DATA_FLOAT'
else
:
assert
self
.
precision
==
'float64'
precision
=
'CUDNN_DATA_DOUBLE'
return
[(
'NB_DIMS'
,
str
(
len
(
self
.
subsample
))),
(
'BORDER_MODE'
,
bmode
),
(
'PAD_0'
,
pad0
),
(
'PAD_1'
,
pad1
),
(
'PAD_2'
,
pad2
),
(
'DIL_0'
,
dil0
),
(
'DIL_1'
,
dil1
),
(
'DIL_2'
,
dil2
),
(
'CONV_MODE'
,
conv_flag
),
(
'SUB_0'
,
sub0
),
(
'SUB_1'
,
sub1
),
(
'SUB_2'
,
sub2
),
(
'PRECISION'
,
precision
)]
bmode
=
property
(
lambda
self
:
'valid'
if
isinstance
(
self
.
border_mode
,
tuple
)
else
self
.
border_mode
)
pad0
=
property
(
lambda
self
:
self
.
border_mode
[
0
]
if
isinstance
(
self
.
border_mode
,
tuple
)
else
0
)
pad1
=
property
(
lambda
self
:
self
.
border_mode
[
1
]
if
isinstance
(
self
.
border_mode
,
tuple
)
else
0
)
pad2
=
property
(
lambda
self
:
self
.
border_mode
[
2
]
if
(
isinstance
(
self
.
border_mode
,
tuple
)
and
len
(
self
.
border_mode
)
>
2
)
else
0
)
sub0
=
property
(
lambda
self
:
self
.
subsample
[
0
])
sub1
=
property
(
lambda
self
:
self
.
subsample
[
1
])
sub2
=
property
(
lambda
self
:
self
.
subsample
[
2
]
if
len
(
self
.
subsample
)
>
2
else
0
)
dil0
=
property
(
lambda
self
:
self
.
dilation
[
0
])
dil1
=
property
(
lambda
self
:
self
.
dilation
[
1
])
dil2
=
property
(
lambda
self
:
self
.
dilation
[
2
]
if
len
(
self
.
dilation
)
>
2
else
0
)
nb_dims
=
property
(
lambda
self
:
len
(
self
.
subsample
))
def
c_code_cache_version
(
self
):
return
(
super
(
GpuDnnConvDesc
,
self
)
.
c_code_cache_version
(),
version
())
...
...
@@ -533,6 +510,12 @@ class GpuDnnConv(DnnBase):
_f16_ok
=
True
__props__
=
(
'algo'
,
'inplace'
)
check_input
=
False
params_type
=
ParamsType
(
conv_algo
=
cudnn
.
cudnnConvolutionFwdAlgo_t
,
choose_algo
=
bool_t
,
choose_once
=
bool_t
,
choose_time
=
bool_t
,
inplace
=
bool_t
,
handle
=
handle_type
)
def
__init__
(
self
,
algo
=
None
,
inplace
=
False
):
DnnBase
.
__init__
(
self
,
[
"dnn_conv_base.c"
,
"dnn_fwd.c"
],
"APPLY_SPECIFIC(conv_fwd)"
)
...
...
@@ -541,13 +524,18 @@ class GpuDnnConv(DnnBase):
algo
=
config
.
dnn
.
conv
.
algo_fwd
self
.
algo
=
algo
self
.
inplace
=
inplace
self
.
inplace
=
bool
(
inplace
)
if
self
.
inplace
:
self
.
destroy_map
=
{
0
:
[
2
]}
assert
self
.
algo
in
[
'none'
,
'small'
,
'large'
,
'fft'
,
'fft_tiling'
,
'winograd'
,
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
]
assert
cudnn
.
cudnnConvolutionFwdAlgo_t
.
has_alias
(
self
.
algo
)
or
self
.
algo
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
self
.
conv_algo
=
cudnn
.
cudnnConvolutionFwdAlgo_t
.
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
if
self
.
algo
not
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
:
self
.
conv_algo
=
self
.
algo
self
.
choose_algo
=
self
.
algo
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
self
.
choose_once
=
self
.
algo
in
DNN_CONV_ALGO_CHOOSE_ONCE
self
.
choose_time
=
self
.
algo
in
DNN_CONV_ALGO_CHOOSE_TIME
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
...
...
@@ -559,38 +547,6 @@ class GpuDnnConv(DnnBase):
if
not
hasattr
(
self
,
'inplace'
):
self
.
inplace
=
False
def
get_op_params
(
self
):
defs
=
[]
if
self
.
inplace
:
defs
.
append
((
'CONV_INPLACE'
,
'1'
))
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
if
self
.
algo
==
'none'
:
# 3d
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'
elif
self
.
algo
==
'small'
:
# 3d
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
elif
self
.
algo
==
'large'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'
elif
self
.
algo
==
'direct'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_DIRECT'
elif
self
.
algo
==
'fft'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_FFT'
elif
self
.
algo
==
'fft_tiling'
:
# 3d
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING'
elif
self
.
algo
==
'winograd'
:
alg
=
'CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD'
defs
.
append
((
'CONV_ALGO'
,
alg
))
if
self
.
algo
in
[
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_ALGO'
,
''
))
if
self
.
algo
in
[
'guess_once'
,
'time_once'
]:
defs
.
append
((
'CHOOSE_ONCE'
,
''
))
if
self
.
algo
in
[
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_TIME'
,
''
))
return
defs
def
make_node
(
self
,
img
,
kern
,
output
,
desc
,
alpha
=
None
,
beta
=
None
):
ctx_name
=
infer_context_name
(
img
,
kern
,
output
)
img
=
as_gpuarray_variable
(
img
,
ctx_name
)
...
...
@@ -609,7 +565,7 @@ class GpuDnnConv(DnnBase):
raise
TypeError
(
"The number of dimensions of "
"img, kern and output must match"
)
if
img
.
type
.
ndim
==
5
and
self
.
algo
in
[
'large'
,
'fft'
]
:
if
img
.
type
.
ndim
==
5
and
self
.
algo
not
in
cudnn
.
conv3d_fwd_algorithms
:
raise
ValueError
(
"convolution algo
%
s can't be used for "
"3d convolutions"
,
(
self
.
algo
,))
...
...
@@ -687,17 +643,30 @@ class GpuDnnConvGradW(DnnBase):
_f16_ok
=
True
__props__
=
(
'algo'
,
'inplace'
)
check_input
=
False
params_type
=
ParamsType
(
conv_algo
=
cudnn
.
cudnnConvolutionBwdFilterAlgo_t
,
choose_algo
=
bool_t
,
choose_once
=
bool_t
,
choose_time
=
bool_t
,
inplace
=
bool_t
,
handle
=
handle_type
)
def
__init__
(
self
,
inplace
=
False
,
algo
=
None
):
DnnBase
.
__init__
(
self
,
[
"dnn_conv_base.c"
,
"dnn_gw.c"
],
"APPLY_SPECIFIC(conv_gw)"
)
self
.
inplace
=
inplace
self
.
inplace
=
bool
(
inplace
)
if
self
.
inplace
:
self
.
destroy_map
=
{
0
:
[
2
]}
if
algo
is
None
:
algo
=
config
.
dnn
.
conv
.
algo_bwd_filter
self
.
algo
=
algo
assert
self
.
algo
in
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
assert
cudnn
.
cudnnConvolutionBwdFilterAlgo_t
.
has_alias
(
self
.
algo
)
or
self
.
algo
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
self
.
conv_algo
=
cudnn
.
cudnnConvolutionBwdFilterAlgo_t
.
CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0
if
self
.
algo
not
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
:
self
.
conv_algo
=
self
.
algo
self
.
choose_algo
=
self
.
algo
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
self
.
choose_once
=
self
.
algo
in
DNN_CONV_ALGO_CHOOSE_ONCE
self
.
choose_time
=
self
.
algo
in
DNN_CONV_ALGO_CHOOSE_TIME
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
...
...
@@ -724,33 +693,6 @@ class GpuDnnConvGradW(DnnBase):
# not connected to desc
return
[[
1
],
[
1
],
[
1
],
[
0
],
[
1
],
[
1
]]
def
get_op_params
(
self
):
defs
=
[]
if
self
.
inplace
:
defs
.
append
((
'CONV_INPLACE'
,
'1'
))
alg
=
'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0'
if
self
.
algo
==
'none'
:
# 3d
alg
=
'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0'
if
self
.
algo
==
'deterministic'
:
alg
=
'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
if
self
.
algo
==
'fft'
:
alg
=
'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT'
if
self
.
algo
==
'small'
:
# 3d
# non-deterministic, small workspace
alg
=
'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3'
if
self
.
algo
in
[
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_ALGO'
,
''
))
if
self
.
algo
in
[
'guess_once'
,
'time_once'
]:
defs
.
append
((
'CHOOSE_ONCE'
,
''
))
if
self
.
algo
in
[
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_TIME'
,
''
))
defs
.
append
((
'CONV_ALGO'
,
alg
))
return
defs
def
op_may_fail_with_subsample
(
self
,
img
,
desc
):
return
(
version
()
<
6000
and
img
.
type
.
dtype
==
'float32'
and
...
...
@@ -793,8 +735,7 @@ class GpuDnnConvGradW(DnnBase):
raise
TypeError
(
"The number of dimensions of "
"img, topgrad and output must match"
)
if
(
img
.
type
.
ndim
==
5
and
self
.
algo
in
[
'fft'
,
'deterministic'
]):
if
img
.
type
.
ndim
==
5
and
self
.
algo
not
in
cudnn
.
conv3d_bwd_filter_algorithms
:
raise
ValueError
(
"convolution algo
%
s can't be used for "
"3d convolutions"
,
(
self
.
algo
,))
...
...
@@ -830,19 +771,30 @@ class GpuDnnConvGradI(DnnBase):
_f16_ok
=
True
__props__
=
(
'algo'
,
'inplace'
,)
check_input
=
False
params_type
=
ParamsType
(
conv_algo
=
cudnn
.
cudnnConvolutionBwdDataAlgo_t
,
choose_algo
=
bool_t
,
choose_once
=
bool_t
,
choose_time
=
bool_t
,
inplace
=
bool_t
,
handle
=
handle_type
)
def
__init__
(
self
,
inplace
=
False
,
algo
=
None
):
DnnBase
.
__init__
(
self
,
[
"dnn_conv_base.c"
,
"dnn_gi.c"
],
"APPLY_SPECIFIC(conv_gi)"
)
self
.
inplace
=
inplace
self
.
inplace
=
bool
(
inplace
)
if
self
.
inplace
:
self
.
destroy_map
=
{
0
:
[
2
]}
if
algo
is
None
:
algo
=
config
.
dnn
.
conv
.
algo_bwd_data
self
.
algo
=
algo
assert
self
.
algo
in
[
'none'
,
'deterministic'
,
'fft'
,
'fft_tiling'
,
'winograd'
,
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
]
assert
cudnn
.
cudnnConvolutionBwdDataAlgo_t
.
has_alias
(
self
.
algo
)
or
self
.
algo
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
self
.
conv_algo
=
cudnn
.
cudnnConvolutionBwdDataAlgo_t
.
CUDNN_CONVOLUTION_BWD_DATA_ALGO_0
if
self
.
algo
not
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
:
self
.
conv_algo
=
self
.
algo
self
.
choose_algo
=
self
.
algo
in
SUPPORTED_DNN_CONV_ALGO_RUNTIME
self
.
choose_once
=
self
.
algo
in
DNN_CONV_ALGO_CHOOSE_ONCE
self
.
choose_time
=
self
.
algo
in
DNN_CONV_ALGO_CHOOSE_TIME
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
...
...
@@ -869,36 +821,6 @@ class GpuDnnConvGradI(DnnBase):
# not connected to desc
return
[[
1
],
[
1
],
[
1
],
[
0
],
[
1
],
[
1
]]
def
get_op_params
(
self
):
defs
=
[]
if
self
.
inplace
:
defs
.
append
((
'CONV_INPLACE'
,
'1'
))
alg
=
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0'
if
self
.
algo
==
'none'
:
# 3d
alg
=
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0'
elif
self
.
algo
==
'deterministic'
:
# 3d
alg
=
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_1'
elif
self
.
algo
==
'fft'
:
alg
=
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT'
elif
self
.
algo
==
'fft_tiling'
:
# 3d
# big workspace but less than fft
alg
=
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING'
elif
self
.
algo
==
'winograd'
:
alg
=
'CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD'
if
self
.
algo
in
[
'guess_once'
,
'guess_on_shape_change'
,
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_ALGO'
,
''
))
if
self
.
algo
in
[
'guess_once'
,
'time_once'
]:
defs
.
append
((
'CHOOSE_ONCE'
,
''
))
if
self
.
algo
in
[
'time_once'
,
'time_on_shape_change'
]:
defs
.
append
((
'CHOOSE_TIME'
,
''
))
defs
.
append
((
'CONV_ALGO'
,
alg
))
return
defs
def
make_node
(
self
,
kern
,
topgrad
,
output
,
desc
,
alpha
=
None
,
beta
=
None
):
ctx_name
=
infer_context_name
(
kern
,
topgrad
,
output
)
kern
=
as_gpuarray_variable
(
kern
,
ctx_name
)
...
...
@@ -916,7 +838,7 @@ class GpuDnnConvGradI(DnnBase):
raise
TypeError
(
"The number of dimensions of "
"kern, topgrad and output must match"
)
if
kern
.
type
.
ndim
==
5
and
self
.
algo
in
[
'fft'
]
:
if
kern
.
type
.
ndim
==
5
and
self
.
algo
not
in
cudnn
.
conv3d_bwd_data_algorithms
:
raise
ValueError
(
"convolution algo
%
s can't be used for "
"3d convolutions"
,
(
self
.
algo
,))
...
...
@@ -1059,7 +981,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), dilation=(1, 1),
def
dnn_conv3d
(
img
,
kerns
,
border_mode
=
'valid'
,
subsample
=
(
1
,
1
,
1
),
dilation
=
(
1
,
1
,
1
),
conv_mode
=
'conv'
,
direction_hint
=
None
,
algo
=
'none'
,
precision
=
None
):
algo
=
None
,
precision
=
None
):
"""
GPU convolution using cuDNN from NVIDIA.
...
...
@@ -1349,7 +1271,33 @@ class GpuDnnPoolDesc(Op):
return
(
4
,
version
())
class
GpuDnnPool
(
DnnBase
):
class GpuDnnPoolBase(DnnBase):
    """
    Abstract base class for GpuDnnPool and GpuDnnPoolGrad.

    Sub-classes only have to set ``c_file`` and ``c_function``; the
    constructor forwards them to ``DnnBase.__init__`` and validates and
    stores the pooling mode.

    Parameters
    ----------
    mode : str
        Alias of a pooling mode declared in ``cudnn.cudnnPoolingMode_t``.
        ``'average'`` is accepted as a synonym for ``'average_inc_pad'``.

    """

    # c_file and c_function must be defined in sub-classes.
    c_file = None
    c_function = None

    _f16_ok = True
    __props__ = ('mode',)
    check_input = False
    # The pooling mode and the cuDNN handle are the Op params.
    params_type = ParamsType(mode=cudnn.cudnnPoolingMode_t,
                             handle=handle_type)

    def __init__(self, mode='max'):
        DnnBase.__init__(self, [self.c_file], self.c_function)
        # 'average' is shorthand for the padding-inclusive average mode.
        if mode == 'average':
            mode = 'average_inc_pad'
        # Supported modes depend on runtime cuDNN version.
        # NOTE(review): this check is an assert, so it is skipped when
        # Python runs with -O.
        assert cudnn.cudnnPoolingMode_t.has_alias(mode)
        self.mode = mode
class
GpuDnnPool
(
GpuDnnPoolBase
):
"""
Parameters
...
...
@@ -1366,25 +1314,8 @@ class GpuDnnPool(DnnBase):
(padX, padY) or (padX, padY, padZ)
"""
_f16_ok
=
True
__props__
=
(
'mode'
,)
def
__init__
(
self
,
mode
=
'max'
):
DnnBase
.
__init__
(
self
,
[
"dnn_pool.c"
],
"APPLY_SPECIFIC(dnn_pool)"
)
if
mode
==
'average'
:
mode
=
'average_inc_pad'
assert
mode
in
(
'max'
,
'average_inc_pad'
,
'average_exc_pad'
)
self
.
mode
=
mode
def
get_op_params
(
self
):
if
self
.
mode
==
'max'
:
mode_flag
=
'CUDNN_POOLING_MAX'
elif
self
.
mode
==
"average_inc_pad"
:
mode_flag
=
'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
elif
self
.
mode
==
"average_exc_pad"
:
mode_flag
=
'CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING'
return
[(
'MODE_FLAG'
,
mode_flag
)]
c_file
=
"dnn_pool.c"
c_function
=
"APPLY_SPECIFIC(dnn_pool)"
def
make_node
(
self
,
img
,
ws
,
stride
,
pad
):
ctx_name
=
infer_context_name
(
img
)
...
...
@@ -1428,7 +1359,7 @@ class GpuDnnPool(DnnBase):
return
[[
1
],
[
0
],
[
0
],
[
0
]]
class
GpuDnnPoolGrad
(
Dnn
Base
):
class
GpuDnnPoolGrad
(
GpuDnnPool
Base
):
"""
The pooling gradient.
...
...
@@ -1451,26 +1382,8 @@ class GpuDnnPoolGrad(DnnBase):
(padX, padY) or (padX, padY, padZ)
"""
_f16_ok
=
True
__props__
=
(
'mode'
,)
def
__init__
(
self
,
mode
=
'max'
):
DnnBase
.
__init__
(
self
,
[
"dnn_pool_grad.c"
],
"APPLY_SPECIFIC(dnn_pool_grad)"
)
if
mode
==
'average'
:
mode
=
'average_inc_pad'
assert
mode
in
(
'max'
,
'average_inc_pad'
,
'average_exc_pad'
)
self
.
mode
=
mode
def
get_op_params
(
self
):
if
self
.
mode
==
'max'
:
mode_flag
=
'CUDNN_POOLING_MAX'
elif
self
.
mode
==
"average_inc_pad"
:
mode_flag
=
'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
elif
self
.
mode
==
"average_exc_pad"
:
mode_flag
=
'CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING'
return
[(
'MODE_FLAG'
,
mode_flag
)]
c_file
=
"dnn_pool_grad.c"
c_function
=
"APPLY_SPECIFIC(dnn_pool_grad)"
def
make_node
(
self
,
inp
,
out
,
out_grad
,
ws
,
stride
,
pad
):
ctx_name
=
infer_context_name
(
inp
,
out
,
out_grad
)
...
...
@@ -1513,7 +1426,8 @@ def dnn_pool(img, ws, stride=None, mode='max', pad=None):
Subsampling window size. Should have 2 or 3 elements.
stride : tuple
Subsampling stride (default: (1, 1) or (1, 1, 1)).
mode : {'max', 'average_inc_pad', 'average_exc_pad', 'sum'}
mode : {'max', 'average_inc_pad', 'average_exc_pad', 'sum', 'max_deterministic'}
**NB**: 'max_deterministic' is supported since cuDNN v6.
pad : tuple
(padX, padY) or (padX, padY, padZ)
default: (0, 0) or (0, 0, 0)
...
...
@@ -1562,22 +1476,17 @@ class GpuDnnSoftmaxBase(DnnBase):
# neither in dnn_base.c nor in dnn_softmax*.c,
# so we can disable input checking.
check_input
=
False
params_type
=
ParamsType
(
algo
=
CEnumType
((
'CUDNN_SOFTMAX_FAST'
,
'fast'
),
(
'CUDNN_SOFTMAX_LOG'
,
'log'
),
(
'CUDNN_SOFTMAX_ACCURATE'
,
'accurate'
),
ctype
=
'cudnnSoftmaxAlgorithm_t'
),
mode
=
CEnumType
((
'CUDNN_SOFTMAX_MODE_INSTANCE'
,
'instance'
),
(
'CUDNN_SOFTMAX_MODE_CHANNEL'
,
'channel'
),
ctype
=
'cudnnSoftmaxMode_t'
),
params_type
=
ParamsType
(
algo
=
cudnn
.
cudnnSoftmaxAlgorithm_t
,
mode
=
cudnn
.
cudnnSoftmaxMode_t
,
handle
=
handle_type
)
def
__init__
(
self
,
algo
,
mode
):
DnnBase
.
__init__
(
self
,
[
self
.
file
],
self
.
c_func
)
assert
(
algo
in
(
'fast'
,
'accurate'
,
'log'
)
)
assert
cudnn
.
cudnnSoftmaxAlgorithm_t
.
has_alias
(
algo
)
self
.
algo
=
algo
assert
(
mode
in
(
'instance'
,
'channel'
)
)
assert
cudnn
.
cudnnSoftmaxMode_t
.
has_alias
(
mode
)
self
.
mode
=
mode
def
infer_shape
(
self
,
node
,
shape
):
...
...
@@ -1810,13 +1719,18 @@ class GpuDnnBatchNormInference(DnnBase):
__props__
=
(
'mode'
,
'inplace'
)
check_input
=
False
params_type
=
ParamsType
(
mode
=
cudnn
.
cudnnBatchNormMode_t
,
inplace
=
bool_t
,
handle
=
handle_type
)
def
__init__
(
self
,
mode
=
'per-activation'
,
inplace
=
False
):
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm_inf.c'
],
'dnn_batchnorm_op'
)
assert
(
mode
in
(
'per-activation'
,
'spatial'
)
)
assert
cudnn
.
cudnnBatchNormMode_t
.
has_alias
(
mode
)
self
.
mode
=
mode
self
.
inplace
=
inplace
self
.
inplace
=
bool
(
inplace
)
if
self
.
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
...
...
@@ -1825,15 +1739,6 @@ class GpuDnnBatchNormInference(DnnBase):
if
not
hasattr
(
self
,
'inplace'
):
self
.
inplace
=
False
def
get_op_params
(
self
):
params
=
[]
if
self
.
inplace
:
params
.
append
((
'INPLACE_OUTPUT'
,
'1'
))
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
if
self
.
mode
==
"spatial"
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
return
params
def
infer_shape
(
self
,
node
,
shape
):
return
[
shape
[
0
]]
...
...
@@ -1882,20 +1787,17 @@ class GpuDnnBatchNormInference(DnnBase):
class
GpuDnnBatchNormGrad
(
DnnBase
):
__props__
=
(
'mode'
,)
check_input
=
False
params_type
=
ParamsType
(
mode
=
cudnn
.
cudnnBatchNormMode_t
,
handle
=
handle_type
)
def
__init__
(
self
,
mode
=
'per-activation'
):
DnnBase
.
__init__
(
self
,
[
'dnn_batchnorm_base.c'
,
'dnn_batchnorm_grad.c'
],
'dnn_batchnorm_grad'
)
assert
(
mode
in
(
'per-activation'
,
'spatial'
)
)
assert
cudnn
.
cudnnBatchNormMode_t
.
has_alias
(
mode
)
self
.
mode
=
mode
def
get_op_params
(
self
):
params
=
[]
params
.
append
((
'MODE'
,
(
"CUDNN_BATCHNORM_SPATIAL"
if
self
.
mode
==
"spatial"
else
"CUDNN_BATCHNORM_PER_ACTIVATION"
)))
return
params
def
make_node
(
self
,
x
,
dy
,
scale
,
x_mean
,
x_invstd
,
epsilon
=
1e-4
):
ctx_name
=
infer_context_name
(
x
,
dy
,
scale
,
x_mean
,
x_invstd
)
x
=
as_gpuarray_variable
(
x
,
ctx_name
)
...
...
theano/gpuarray/dnn_batchnorm_grad.c
浏览文件 @
46783773
...
...
@@ -24,7 +24,7 @@ int dnn_batchnorm_grad(PyGpuArrayObject *inp, PyGpuArrayObject *doutp,
PyGpuArrayObject
*
scale
,
PyGpuArrayObject
*
x_mean
,
PyGpuArrayObject
*
x_invstd
,
npy_float64
epsilon
,
PyGpuArrayObject
**
dinp
,
PyGpuArrayObject
**
dscale
,
PyGpuArrayObject
**
dbias
,
cudnnHandle_t
_handle
)
{
PyGpuArrayObject
**
dbias
,
PARAMS_TYPE
*
params
)
{
PyGpuContextObject
*
c
=
inp
->
context
;
if
(
c_set_tensorNd
(
inp
,
bn_input
)
!=
0
)
...
...
@@ -70,8 +70,8 @@ int dnn_batchnorm_grad(PyGpuArrayObject *inp, PyGpuArrayObject *doutp,
betaParam
=
(
void
*
)
&
fbeta
;
}
cudnnStatus_t
err
=
cudnnBatchNormalizationBackward
(
_
handle
,
MODE
,
params
->
handle
,
params
->
mode
,
alphaData
,
betaData
,
alphaParam
,
...
...
theano/gpuarray/dnn_batchnorm_inf.c
浏览文件 @
46783773
...
...
@@ -3,7 +3,7 @@
int
dnn_batchnorm_op
(
PyGpuArrayObject
*
inp
,
PyGpuArrayObject
*
scale
,
PyGpuArrayObject
*
bias
,
PyGpuArrayObject
*
est_mean
,
PyGpuArrayObject
*
est_var
,
npy_float64
epsilon
,
PyGpuArrayObject
**
outp
,
cudnnHandle_t
_handle
)
{
PyGpuArrayObject
**
outp
,
PARAMS_TYPE
*
params
)
{
PyGpuContextObject
*
c
=
inp
->
context
;
if
(
c_set_tensorNd
(
inp
,
bn_input
)
!=
0
)
...
...
@@ -16,14 +16,14 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
return
1
;
}
#ifdef INPLACE_OUTPUT
Py_XDECREF
(
*
outp
);
*
outp
=
inp
;
Py_INCREF
(
*
outp
);
#else
if
(
theano_prep_output
(
outp
,
inp
->
ga
.
nd
,
inp
->
ga
.
dimensions
,
inp
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
#endif
if
(
params
->
inplace
)
{
Py_XDECREF
(
*
outp
);
*
outp
=
inp
;
Py_INCREF
(
*
outp
);
}
else
{
if
(
theano_prep_output
(
outp
,
inp
->
ga
.
nd
,
inp
->
ga
.
dimensions
,
inp
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
}
if
(
c_set_tensorNd
(
*
outp
,
bn_output
)
!=
0
)
return
1
;
...
...
@@ -43,8 +43,8 @@ int dnn_batchnorm_op(PyGpuArrayObject *inp, PyGpuArrayObject *scale,
beta
=
(
void
*
)
&
fbeta
;
}
cudnnStatus_t
err
=
cudnnBatchNormalizationForwardInference
(
_
handle
,
MODE
,
params
->
handle
,
params
->
mode
,
alpha
,
beta
,
bn_input
,
...
...
theano/gpuarray/dnn_fwd.c
浏览文件 @
46783773
#section init_code_struct
#ifdef CHOOSE_ALGO
reuse_algo
=
0
;
prev_algo
=
CONV_ALGO
;
#ifndef CHOOSE_ONCE
memset
(
prev_img_dims
,
0
,
sizeof
(
prev_img_dims
));
memset
(
prev_kern_dims
,
0
,
sizeof
(
prev_kern_dims
));
#endif
#endif
if
(
PARAMS
->
choose_algo
)
{
reuse_algo
=
0
;
prev_algo
=
PARAMS
->
conv_algo
;
if
(
!
PARAMS
->
choose_once
)
{
memset
(
prev_img_dims
,
0
,
sizeof
(
prev_img_dims
));
memset
(
prev_kern_dims
,
0
,
sizeof
(
prev_kern_dims
));
}
}
#section support_code_struct
#ifdef CHOOSE_ALGO
int
reuse_algo
;
cudnnConvolutionFwdAlgo_t
prev_algo
;
#ifndef CHOOSE_ONCE
size_t
prev_img_dims
[
5
];
size_t
prev_kern_dims
[
5
];
#endif
#endif
int
APPLY_SPECIFIC
(
conv_fwd
)(
PyGpuArrayObject
*
input
,
PyGpuArrayObject
*
kerns
,
...
...
@@ -26,7 +22,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
cudnnConvolutionDescriptor_t
desc
,
double
alpha
,
double
beta
,
PyGpuArrayObject
**
output
,
cudnnHandle_t
_handle
)
{
PARAMS_TYPE
*
params
)
{
PyGpuContextObject
*
c
=
input
->
context
;
void
*
alpha_p
;
void
*
beta_p
;
...
...
@@ -54,17 +50,17 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
return
1
;
}
#ifdef CONV_INPLACE
Py_XDECREF
(
*
output
);
*
output
=
om
;
Py_INCREF
(
*
output
);
#else
if
(
theano_prep_output
(
output
,
PyGpuArray_NDIM
(
om
),
PyGpuArray_DIMS
(
om
),
om
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
beta
!=
0
.
0
&&
pygpu_move
(
*
output
,
om
))
return
1
;
#endif
if
(
params
->
inplace
)
{
Py_XDECREF
(
*
output
);
*
output
=
om
;
Py_INCREF
(
*
output
);
}
else
{
if
(
theano_prep_output
(
output
,
PyGpuArray_NDIM
(
om
),
PyGpuArray_DIMS
(
om
),
om
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
beta
!=
0
.
0
&&
pygpu_move
(
*
output
,
om
))
return
1
;
}
if
(
PyGpuArray_DIMS
(
input
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
1
]
==
0
)
{
int
err2
=
GpuArray_memset
(
&
(
*
output
)
->
ga
,
0
);
...
...
@@ -83,90 +79,90 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
if
(
c_set_tensorNd
(
*
output
,
APPLY_SPECIFIC
(
output
))
==
-
1
)
return
1
;
cudnnConvolutionFwdAlgo_t
algo
=
CONV_ALGO
;
cudnnConvolutionFwdAlgo_t
algo
=
params
->
conv_algo
;
cuda_enter
(
c
->
ctx
);
#ifdef CHOOSE_ALGO
#ifndef CHOOSE_ONCE
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
input
,
i
)
==
prev_img_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
kerns
,
i
)
==
prev_kern_dims
[
i
]);
}
#endif
if
(
!
reuse_algo
)
{
size_t
free
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
if
(
params
->
choose_algo
)
{
/* Compare the current shapes against the cached ones only in the
   "*_on_shape_change" modes (choose_once == 0).  prev_img_dims and
   prev_kern_dims are zeroed in init_code_struct and refreshed after each
   selection only in that case.  In the "*_once" modes reuse_algo is set
   to 1 after the first selection and must not be recomputed from those
   never-initialised arrays. */
if (!params->choose_once) {
  reuse_algo = 1;
  for (unsigned int i = 0; i < PyGpuArray_NDIM(input); i++) {
    reuse_algo = (reuse_algo &&
                  PyGpuArray_DIM(input, i) == prev_img_dims[i]);
    reuse_algo = (reuse_algo &&
                  PyGpuArray_DIM(kerns, i) == prev_kern_dims[i]);
  }
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
if
(
!
reuse_algo
)
{
size_t
free
;
#ifdef CHOOSE_TIME
int
count
;
cudnnConvolutionFwdAlgoPerf_t
choice
;
gpudata
*
tmpmem
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
}
tmpmem
=
gpudata_alloc
(
c
->
ctx
,
free
,
NULL
,
0
,
NULL
);
if
(
tmpmem
==
NULL
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Could not allocate working GPU memory"
);
return
-
1
;
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
if
(
params
->
choose_time
)
{
int
count
;
cudnnConvolutionFwdAlgoPerf_t
choice
;
gpudata
*
tmpmem
;
tmpmem
=
gpudata_alloc
(
c
->
ctx
,
free
,
NULL
,
0
,
NULL
);
if (tmpmem == NULL) {
  PyErr_SetString(PyExc_MemoryError,
                  "Could not allocate working GPU memory");
  /* Balance the earlier cuda_enter(c->ctx), as the other error exits of
     this function do, before reporting the failure. */
  cuda_exit(c->ctx);
  return -1;
}
// We don't sync the buffer as we don't care about the values.
err
=
cudnnFindConvolutionForwardAlgorithmEx
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
input
),
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
kerns
),
desc
,
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
*
output
),
1
,
&
count
,
&
choice
,
*
(
void
**
)
tmpmem
,
free
);
gpudata_release
(
tmpmem
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
algo
=
choice
.
algo
;
}
else
{
err
=
cudnnGetConvolutionForwardAlgorithm
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
desc
,
APPLY_SPECIFIC
(
output
),
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
free
,
&
algo
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
}
prev_algo
=
algo
;
}
else
{
algo
=
prev_algo
;
}
// We don't sync the buffer as we don't care about the values.
err
=
cudnnFindConvolutionForwardAlgorithmEx
(
_handle
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
input
),
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
kerns
),
desc
,
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
*
output
),
1
,
&
count
,
&
choice
,
*
(
void
**
)
tmpmem
,
free
);
gpudata_release
(
tmpmem
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
algo
=
choice
.
algo
;
#else
err
=
cudnnGetConvolutionForwardAlgorithm
(
_handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
desc
,
APPLY_SPECIFIC
(
output
),
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
,
free
,
&
algo
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
if
(
params
->
choose_once
)
{
reuse_algo
=
1
;
}
else
{
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
prev_img_dims
[
i
]
=
PyGpuArray_DIM
(
input
,
i
);
prev_kern_dims
[
i
]
=
PyGpuArray_DIM
(
kerns
,
i
);
}
}
#endif
prev_algo
=
algo
;
}
else
{
algo
=
prev_algo
;
}
#ifdef CHOOSE_ONCE
reuse_algo
=
1
;
#else
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
prev_img_dims
[
i
]
=
PyGpuArray_DIM
(
input
,
i
);
prev_kern_dims
[
i
]
=
PyGpuArray_DIM
(
kerns
,
i
);
}
#endif
#endif
/* These two algos are not supported for 3d conv */
if
(
PyGpuArray_NDIM
(
input
)
==
5
&&
(
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM
||
...
...
@@ -201,20 +197,16 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
return
1
;
}
if
(
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_FFT
)
{
if
(
algo
==
CUDNN_CONVOLUTION_FWD_ALGO_FFT
)
{
if
(
stride
[
0
]
!=
1
||
stride
[
1
]
!=
1
||
PyGpuArray_DIM
(
input
,
2
)
>
1024
||
PyGpuArray_DIM
(
input
,
3
)
>
1024
||
(
PyGpuArray_DIM
(
kerns
,
2
)
==
1
&&
PyGpuArray_DIM
(
kerns
,
3
)
==
1
))
{
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
}
}
else
{
}
else
{
// algo == CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING
if
(
stride
[
0
]
!=
1
||
stride
[
1
]
!=
1
)
{
if
(
stride
[
0
]
!=
1
||
stride
[
1
]
!=
1
)
{
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
}
}
...
...
@@ -223,7 +215,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
{
size_t
worksize
;
gpudata
*
workspace
;
err
=
cudnnGetConvolutionForwardWorkspaceSize
(
_
handle
,
err
=
cudnnGetConvolutionForwardWorkspaceSize
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
desc
,
...
...
@@ -236,7 +228,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
// TODO: Print a warning
algo
=
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM
;
err
=
cudnnGetConvolutionForwardWorkspaceSize
(
_
handle
,
err
=
cudnnGetConvolutionForwardWorkspaceSize
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
kerns
),
desc
,
...
...
@@ -273,7 +265,7 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,
cuda_wait
((
*
output
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
err
=
cudnnConvolutionForward
(
_
handle
,
params
->
handle
,
alpha_p
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
input
),
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
kerns
),
...
...
theano/gpuarray/dnn_gi.c
浏览文件 @
46783773
#section init_code_struct
#ifdef CHOOSE_ALGO
reuse_algo
=
0
;
prev_algo
=
CONV_ALGO
;
#ifndef CHOOSE_ONCE
memset
(
prev_kern_dims
,
0
,
sizeof
(
prev_kern_dims
));
memset
(
prev_top_dims
,
0
,
sizeof
(
prev_top_dims
));
#endif
#endif
// #ifdef CHOOSE_ALGO
if
(
PARAMS
->
choose_algo
)
{
reuse_algo
=
0
;
prev_algo
=
PARAMS
->
conv_algo
;
// #ifndef CHOOSE_ONCE
if
(
!
PARAMS
->
choose_once
)
{
memset
(
prev_kern_dims
,
0
,
sizeof
(
prev_kern_dims
));
memset
(
prev_top_dims
,
0
,
sizeof
(
prev_top_dims
));
}
// #endif
}
// #endif
#section support_code_struct
#ifdef CHOOSE_ALGO
int
reuse_algo
=
0
;
cudnnConvolutionBwdDataAlgo_t
prev_algo
=
CONV_ALGO
;
#ifndef CHOOSE_ONCE
int
reuse_algo
;
cudnnConvolutionBwdDataAlgo_t
prev_algo
;
size_t
prev_kern_dims
[
5
]
=
{
0
};
size_t
prev_top_dims
[
5
]
=
{
0
};
#endif
#endif
int
APPLY_SPECIFIC
(
conv_gi
)(
PyGpuArrayObject
*
kerns
,
PyGpuArrayObject
*
output
,
PyGpuArrayObject
*
im
,
cudnnConvolutionDescriptor_t
desc
,
double
alpha
,
double
beta
,
PyGpuArrayObject
**
input
,
cudnnHandle_t
_handle
)
{
PARAMS_TYPE
*
params
)
{
PyGpuContextObject
*
c
=
kerns
->
context
;
void
*
alpha_p
;
void
*
beta_p
;
...
...
@@ -53,17 +53,20 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
return
1
;
}
#ifdef CONV_INPLACE
Py_XDECREF
(
*
input
);
*
input
=
im
;
Py_INCREF
(
*
input
);
#else
if
(
theano_prep_output
(
input
,
PyGpuArray_NDIM
(
im
),
PyGpuArray_DIMS
(
im
),
im
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
beta
!=
0
.
0
&&
pygpu_move
(
*
input
,
im
))
return
1
;
#endif
// #ifdef CONV_INPLACE
if
(
params
->
inplace
)
{
Py_XDECREF
(
*
input
);
*
input
=
im
;
Py_INCREF
(
*
input
);
// #else
}
else
{
if
(
theano_prep_output
(
input
,
PyGpuArray_NDIM
(
im
),
PyGpuArray_DIMS
(
im
),
im
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
beta
!=
0
.
0
&&
pygpu_move
(
*
input
,
im
))
return
1
;
}
// #endif
if
(
PyGpuArray_DIMS
(
im
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
0
]
==
0
||
PyGpuArray_DIMS
(
kerns
)[
1
]
==
0
)
{
int
err2
=
GpuArray_memset
(
&
(
*
input
)
->
ga
,
0
);
...
...
@@ -82,7 +85,7 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
if
(
c_set_tensorNd
(
*
input
,
APPLY_SPECIFIC
(
input
))
==
-
1
)
return
1
;
cudnnConvolutionBwdDataAlgo_t
algo
=
CONV_ALGO
;
cudnnConvolutionBwdDataAlgo_t
algo
=
params
->
conv_algo
;
cuda_enter
(
c
->
ctx
);
...
...
@@ -128,84 +131,93 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
}
}
#ifdef CHOOSE_ALGO
#ifndef CHOOSE_ONCE
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
kerns
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
kerns
,
i
)
==
prev_kern_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
output
,
i
)
==
prev_top_dims
[
i
]);
}
#endif
// #ifdef CHOOSE_ALGO
if
(
params
->
choose_algo
)
{
// #ifndef CHOOSE_ONCE
if
(
!
params
->
choose_once
)
{
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
kerns
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
kerns
,
i
)
==
prev_kern_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
output
,
i
)
==
prev_top_dims
[
i
]);
}
}
// #endif
if
(
!
reuse_algo
)
{
size_t
free
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
!
reuse_algo
)
{
size_t
free
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
}
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
#ifdef CHOOSE_TIME
int
count
;
cudnnConvolutionBwdDataAlgoPerf_t
choice
;
gpudata
*
tmpmem
;
// #ifdef CHOOSE_TIME
if
(
params
->
choose_time
)
{
int
count
;
cudnnConvolutionBwdDataAlgoPerf_t
choice
;
gpudata
*
tmpmem
;
tmpmem
=
gpudata_alloc
(
c
->
ctx
,
free
,
NULL
,
0
,
NULL
);
if
(
tmpmem
==
NULL
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Could not allocate working GPU memory"
);
return
-
1
;
}
tmpmem
=
gpudata_alloc
(
c
->
ctx
,
free
,
NULL
,
0
,
NULL
);
if
(
tmpmem
==
NULL
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Could not allocate working GPU memory"
);
return
-
1
;
}
err
=
cudnnFindConvolutionBackwardDataAlgorithmEx
(
_handle
,
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
kerns
),
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
output
),
desc
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
*
input
),
1
,
&
count
,
&
choice
,
*
(
void
**
)
tmpmem
,
free
);
gpudata_release
(
tmpmem
);
err
=
cudnnFindConvolutionBackwardDataAlgorithmEx
(
params
->
handle
,
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
kerns
),
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
output
),
desc
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
*
input
),
1
,
&
count
,
&
choice
,
*
(
void
**
)
tmpmem
,
free
);
gpudata_release
(
tmpmem
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
algo
=
choice
.
algo
;
// #else
}
else
{
err
=
cudnnGetConvolutionBackwardDataAlgorithm
(
params
->
handle
,
APPLY_SPECIFIC
(
kerns
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
input
),
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
,
free
,
&
algo
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
}
algo
=
choice
.
algo
;
#else
err
=
cudnnGetConvolutionBackwardDataAlgorithm
(
_handle
,
APPLY_SPECIFIC
(
kerns
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
input
),
CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
,
free
,
&
algo
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
// #endif
prev_algo
=
algo
;
}
else
{
algo
=
prev_algo
;
}
#endif
prev_algo
=
algo
;
}
else
{
algo
=
prev_algo
;
}
#ifdef CHOOSE_ONCE
reuse_algo
=
1
;
#else
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
kerns
);
i
++
)
{
prev_kern_dims
[
i
]
=
PyGpuArray_DIM
(
kerns
,
i
);
prev_top_dims
[
i
]
=
PyGpuArray_DIM
(
output
,
i
);
// #ifdef CHOOSE_ONCE
if
(
params
->
choose_once
)
{
reuse_algo
=
1
;
// #else
}
else
{
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
kerns
);
i
++
)
{
prev_kern_dims
[
i
]
=
PyGpuArray_DIM
(
kerns
,
i
);
prev_top_dims
[
i
]
=
PyGpuArray_DIM
(
output
,
i
);
}
}
// #endif
}
#endif
#endif
// #endif
// The FFT implementation does not support strides, 1x1 filters or inputs
// with a spatial dimension larger than 1024. The tiled-FFT implementation
...
...
@@ -258,7 +270,7 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
gpudata
*
workspace
;
err
=
cudnnGetConvolutionBackwardDataWorkspaceSize
(
_
handle
,
APPLY_SPECIFIC
(
kerns
),
APPLY_SPECIFIC
(
output
),
desc
,
params
->
handle
,
APPLY_SPECIFIC
(
kerns
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
input
),
algo
,
&
worksize
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
...
...
@@ -283,7 +295,7 @@ APPLY_SPECIFIC(conv_gi)(PyGpuArrayObject *kerns, PyGpuArrayObject *output,
cuda_wait
((
*
input
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
err
=
cudnnConvolutionBackwardData
(
_
handle
,
params
->
handle
,
alpha_p
,
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
kerns
),
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
output
),
...
...
theano/gpuarray/dnn_gw.c
浏览文件 @
46783773
#section init_code_struct
#ifdef CHOOSE_ALGO
reuse_algo
=
0
;
prev_algo
=
CONV_ALGO
;
#ifndef CHOOSE_ONCE
memset
(
prev_img_dims
,
0
,
sizeof
(
prev_img_dims
));
memset
(
prev_top_dims
,
0
,
sizeof
(
prev_top_dims
));
#endif
#endif
if
(
PARAMS
->
choose_algo
)
{
reuse_algo
=
0
;
prev_algo
=
PARAMS
->
conv_algo
;
if
(
!
PARAMS
->
choose_once
)
{
memset
(
prev_img_dims
,
0
,
sizeof
(
prev_img_dims
));
memset
(
prev_top_dims
,
0
,
sizeof
(
prev_top_dims
));
}
}
#section support_code_struct
#ifdef CHOOSE_ALGO
int
reuse_algo
;
cudnnConvolutionBwdFilterAlgo_t
prev_algo
;
#ifndef CHOOSE_ONCE
size_t
prev_img_dims
[
5
];
size_t
prev_top_dims
[
5
];
#endif
#endif
int
APPLY_SPECIFIC
(
conv_gw
)(
PyGpuArrayObject
*
input
,
PyGpuArrayObject
*
output
,
PyGpuArrayObject
*
km
,
cudnnConvolutionDescriptor_t
desc
,
double
alpha
,
double
beta
,
PyGpuArrayObject
**
kerns
,
cudnnHandle_t
_handle
)
{
PARAMS_TYPE
*
params
)
{
PyGpuContextObject
*
c
=
input
->
context
;
void
*
alpha_p
;
void
*
beta_p
;
...
...
@@ -53,17 +49,17 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
return
1
;
}
#ifdef CONV_INPLACE
Py_XDECREF
(
*
kerns
);
*
kerns
=
km
;
Py_INCREF
(
*
kerns
);
#else
if
(
theano_prep_output
(
kerns
,
PyGpuArray_NDIM
(
km
),
PyGpuArray_DIMS
(
km
),
km
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
beta
!=
0
.
0
&&
pygpu_move
(
*
kerns
,
km
))
return
1
;
#endif
if
(
params
->
inplace
)
{
Py_XDECREF
(
*
kerns
);
*
kerns
=
km
;
Py_INCREF
(
*
kerns
);
}
else
{
if
(
theano_prep_output
(
kerns
,
PyGpuArray_NDIM
(
km
),
PyGpuArray_DIMS
(
km
),
km
->
ga
.
typecode
,
GA_C_ORDER
,
c
)
!=
0
)
return
1
;
if
(
beta
!=
0
.
0
&&
pygpu_move
(
*
kerns
,
km
))
return
1
;
}
if
(
PyGpuArray_DIMS
(
input
)[
0
]
==
0
||
PyGpuArray_DIMS
(
km
)[
0
]
==
0
||
PyGpuArray_DIMS
(
km
)[
1
]
==
0
)
{
int
err2
=
GpuArray_memset
(
&
(
*
kerns
)
->
ga
,
0
);
...
...
@@ -82,7 +78,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
if
(
c_set_filter
(
*
kerns
,
APPLY_SPECIFIC
(
kerns
))
==
-
1
)
return
1
;
cudnnConvolutionBwdFilterAlgo_t
algo
=
CONV_ALGO
;
cudnnConvolutionBwdFilterAlgo_t
algo
=
params
->
conv_algo
;
cuda_enter
(
c
->
ctx
);
...
...
@@ -128,86 +124,85 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
}
}
#ifdef CHOOSE_ALGO
#ifndef CHOOSE_ONCE
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
input
,
i
)
==
prev_img_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
output
,
i
)
==
prev_top_dims
[
i
]);
}
#endif
if
(
!
reuse_algo
)
{
size_t
free
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
if
(
params
->
choose_algo
)
{
if
(
!
params
->
choose_once
)
{
reuse_algo
=
1
;
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
input
,
i
)
==
prev_img_dims
[
i
]);
reuse_algo
=
(
reuse_algo
&&
PyGpuArray_DIM
(
output
,
i
)
==
prev_top_dims
[
i
]);
}
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
#ifdef CHOOSE_TIME
int
count
;
cudnnConvolutionBwdFilterAlgoPerf_t
choice
;
gpudata
*
tmpmem
;
tmpmem
=
gpudata_alloc
(
c
->
ctx
,
free
,
NULL
,
0
,
NULL
);
if
(
tmpmem
==
NULL
)
{
PyErr_SetString
(
PyExc_MemoryError
,
"Could not allocate working GPU memory"
);
return
-
1
;
}
if
(
!
reuse_algo
)
{
size_t
free
;
err
=
cudnnFindConvolutionBackwardFilterAlgorithmEx
(
_handle
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
input
),
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
output
),
desc
,
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
*
kerns
),
1
,
&
count
,
&
choice
,
*
(
void
**
)
tmpmem
,
free
);
gpudata_release
(
tmpmem
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
int
err2
=
gpucontext_property
(
c
->
ctx
,
GA_CTX_PROP_LARGEST_MEMBLOCK
,
&
free
);
if
(
err2
!=
GA_NO_ERROR
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"Error when trying to find the "
"memory information on the GPU"
);
cuda_exit
(
c
->
ctx
);
return
1
;
}
// Guess 4Mb if the info is not available
if
(
free
==
0
)
free
=
4
*
1024
*
1024
;
if
(
params
->
choose_time
)
{
int
count
;
cudnnConvolutionBwdFilterAlgoPerf_t
choice
;
gpudata
*
tmpmem
;
tmpmem
=
gpudata_alloc
(
c
->
ctx
,
free
,
NULL
,
0
,
NULL
);
if (tmpmem == NULL) {
  PyErr_SetString(PyExc_MemoryError,
                  "Could not allocate working GPU memory");
  /* Balance the earlier cuda_enter(c->ctx), as the other error exits of
     this function do, before reporting the failure. */
  cuda_exit(c->ctx);
  return -1;
}
err
=
cudnnFindConvolutionBackwardFilterAlgorithmEx
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
input
),
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
output
),
desc
,
APPLY_SPECIFIC
(
kerns
),
PyGpuArray_DEV_DATA
(
*
kerns
),
1
,
&
count
,
&
choice
,
*
(
void
**
)
tmpmem
,
free
);
gpudata_release
(
tmpmem
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
algo
=
choice
.
algo
;
}
else
{
err
=
cudnnGetConvolutionBackwardFilterAlgorithm
(
params
->
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
kerns
),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
,
free
,
&
algo
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
}
}
prev_algo
=
algo
;
}
else
{
algo
=
prev_algo
;
}
algo
=
choice
.
algo
;
#else
err
=
cudnnGetConvolutionBackwardFilterAlgorithm
(
_handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
kerns
),
CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
,
free
,
&
algo
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"error selecting convolution algo: %s"
,
cudnnGetErrorString
(
err
));
cuda_exit
(
c
->
ctx
);
return
1
;
if
(
params
->
choose_once
)
{
reuse_algo
=
1
;
}
else
{
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
prev_img_dims
[
i
]
=
PyGpuArray_DIM
(
input
,
i
);
prev_top_dims
[
i
]
=
PyGpuArray_DIM
(
output
,
i
);
}
}
#endif
prev_algo
=
algo
;
}
else
{
algo
=
prev_algo
;
}
#ifdef CHOOSE_ONCE
reuse_algo
=
1
;
#else
for
(
unsigned
int
i
=
0
;
i
<
PyGpuArray_NDIM
(
input
);
i
++
)
{
prev_img_dims
[
i
]
=
PyGpuArray_DIM
(
input
,
i
);
prev_top_dims
[
i
]
=
PyGpuArray_DIM
(
output
,
i
);
}
#endif
#endif
// The FFT implementation does not support strides, 1x1 filters or inputs
// with a spatial dimension larger than 1024.
...
...
@@ -246,7 +241,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
gpudata
*
workspace
;
err
=
cudnnGetConvolutionBackwardFilterWorkspaceSize
(
_
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
output
),
desc
,
params
->
handle
,
APPLY_SPECIFIC
(
input
),
APPLY_SPECIFIC
(
output
),
desc
,
APPLY_SPECIFIC
(
kerns
),
algo
,
&
worksize
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
...
...
@@ -270,7 +265,7 @@ APPLY_SPECIFIC(conv_gw)(PyGpuArrayObject *input, PyGpuArrayObject *output,
cuda_wait
((
*
kerns
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
err
=
cudnnConvolutionBackwardFilter
(
_
handle
,
params
->
handle
,
alpha_p
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
input
),
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
output
),
...
...
theano/gpuarray/dnn_pool.c
浏览文件 @
46783773
...
...
@@ -42,7 +42,7 @@ int APPLY_SPECIFIC(dnn_pool)(PyGpuArrayObject *img,
PyArrayObject
*
stride
,
PyArrayObject
*
pad
,
PyGpuArrayObject
**
out
,
cudnnHandle_t
_handle
)
{
PARAMS_TYPE
*
params
)
{
PyGpuContextObject
*
c
=
img
->
context
;
size_t
dims
[
5
];
cudnnStatus_t
err
;
...
...
@@ -90,7 +90,7 @@ int APPLY_SPECIFIC(dnn_pool)(PyGpuArrayObject *img,
if
(
c_set_tensorNd
(
*
out
,
APPLY_SPECIFIC
(
output
))
!=
0
)
return
1
;
err
=
cudnnSetPoolingNdDescriptor
(
APPLY_SPECIFIC
(
pool
),
MODE_FLAG
,
CUDNN_PROPAGATE_NAN
,
ndims
,
w
,
p
,
s
);
err
=
cudnnSetPoolingNdDescriptor
(
APPLY_SPECIFIC
(
pool
),
params
->
mode
,
CUDNN_PROPAGATE_NAN
,
ndims
,
w
,
p
,
s
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"could not set op descriptor %s"
,
cudnnGetErrorString
(
err
));
...
...
@@ -124,7 +124,7 @@ int APPLY_SPECIFIC(dnn_pool)(PyGpuArrayObject *img,
cuda_wait
((
*
out
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
err
=
cudnnPoolingForward
(
_
handle
,
APPLY_SPECIFIC
(
pool
),
params
->
handle
,
APPLY_SPECIFIC
(
pool
),
alpha
,
APPLY_SPECIFIC
(
input
),
PyGpuArray_DEV_DATA
(
img
),
beta
,
...
...
theano/gpuarray/dnn_pool_grad.c
浏览文件 @
46783773
...
...
@@ -64,7 +64,7 @@ int APPLY_SPECIFIC(dnn_pool_grad)(PyGpuArrayObject *inp,
PyArrayObject
*
stride
,
PyArrayObject
*
pad
,
PyGpuArrayObject
**
inp_grad
,
cudnnHandle_t
_handle
)
{
PARAMS_TYPE
*
params
)
{
PyGpuContextObject
*
c
=
inp
->
context
;
cudnnStatus_t
err
;
...
...
@@ -116,7 +116,7 @@ int APPLY_SPECIFIC(dnn_pool_grad)(PyGpuArrayObject *inp,
s
[
i
]
=
*
((
npy_intp
*
)
PyArray_GETPTR1
(
stride
,
i
));
}
err
=
cudnnSetPoolingNdDescriptor
(
APPLY_SPECIFIC
(
pool
),
MODE_FLAG
,
CUDNN_PROPAGATE_NAN
,
ndims
,
w
,
p
,
s
);
err
=
cudnnSetPoolingNdDescriptor
(
APPLY_SPECIFIC
(
pool
),
params
->
mode
,
CUDNN_PROPAGATE_NAN
,
ndims
,
w
,
p
,
s
);
if
(
err
!=
CUDNN_STATUS_SUCCESS
)
{
PyErr_Format
(
PyExc_RuntimeError
,
"could not set op descriptor %s"
,
cudnnGetErrorString
(
err
));
...
...
@@ -155,7 +155,7 @@ int APPLY_SPECIFIC(dnn_pool_grad)(PyGpuArrayObject *inp,
cuda_wait
((
*
inp_grad
)
->
ga
.
data
,
GPUARRAY_CUDA_WAIT_WRITE
);
err
=
cudnnPoolingBackward
(
_
handle
,
APPLY_SPECIFIC
(
pool
),
params
->
handle
,
APPLY_SPECIFIC
(
pool
),
alpha
,
APPLY_SPECIFIC
(
output
),
PyGpuArray_DEV_DATA
(
out
),
APPLY_SPECIFIC
(
output_grad
),
PyGpuArray_DEV_DATA
(
out_grad
),
...
...
theano/gpuarray/tests/test_dnn.py
浏览文件 @
46783773
...
...
@@ -31,6 +31,20 @@ mode_with_gpu = mode_with_gpu.including()
mode_with_gpu
.
check_py_code
=
False
# This variable will store the list of pooling modes available with the current runtime cuDNN version.
# Don't use this variable directly, always call `get_dnn_pool_modes()` instead.
dnn_pool_modes
=
None
def get_dnn_pool_modes():
    """Return the pooling modes available with the runtime cuDNN version.

    The result is computed on first call and cached in the module-level
    ``dnn_pool_modes`` variable; later calls return the cached value.
    This helper is meant to be called by the pooling tests only.
    """
    global dnn_pool_modes
    if dnn_pool_modes is not None:
        # Already initialized by a previous call.
        return dnn_pool_modes
    # Imported lazily, only when the modes are first requested.
    from .. import cudnn_defs
    definitions = cudnn_defs.get_definitions(dnn.version(raises=False))
    dnn_pool_modes = definitions.cudnnPoolingMode_t.get_aliases()
    return dnn_pool_modes
# If using float16, set CUDNN precision to float32
def
set_precision
(
floatX
):
if
floatX
==
"float16"
:
...
...
@@ -155,11 +169,7 @@ def test_pooling():
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
utt
.
seed_rng
()
# 'average_exc_pad' is disabled for versions < 4004
if
dnn
.
version
(
raises
=
False
)
<
4004
:
modes
=
(
'max'
,
'average_inc_pad'
)
else
:
modes
=
(
'max'
,
'average_inc_pad'
,
'average_exc_pad'
)
modes
=
get_dnn_pool_modes
()
x
=
T
.
tensor4
()
for
mode
,
pad
in
product
(
modes
,
...
...
@@ -242,7 +252,9 @@ def test_pooling():
for
node
in
fg
.
maker
.
fgraph
.
toposort
()])
def
test_pooling_with_tensor_vars
():
# This test will be run with different values of 'mode'
# (see next test below).
def
run_pooling_with_tensor_vars
(
mode
):
if
not
dnn
.
dnn_available
(
test_ctx_name
):
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
utt
.
seed_rng
()
...
...
@@ -251,7 +263,6 @@ def test_pooling_with_tensor_vars():
ws
=
theano
.
shared
(
np
.
array
([
2
,
2
],
dtype
=
'int32'
))
stride
=
theano
.
shared
(
np
.
array
([
1
,
1
],
dtype
=
'int32'
))
pad
=
theano
.
shared
(
np
.
array
([
0
,
0
],
dtype
=
'int32'
))
mode
=
'max'
def
fn
(
x
):
dnn_op
=
dnn
.
dnn_pool
(
...
...
@@ -297,6 +308,12 @@ def test_pooling_with_tensor_vars():
i
+=
1
def test_pooling_with_tensor_vars():
    """Yield one ``run_pooling_with_tensor_vars`` case per max-pooling mode.

    Covers mode 'max' and also 'max_deterministic' when it is among the
    modes reported by ``get_dnn_pool_modes()``.
    """
    wanted = ('max', 'max_deterministic')
    # Build the selection up front, before the first case is yielded.
    selected = [m for m in get_dnn_pool_modes() if m in wanted]
    for mode in selected:
        yield (run_pooling_with_tensor_vars, mode)
def
test_pooling3d
():
# 3d pooling requires version 3 or newer.
if
not
dnn
.
dnn_available
(
test_ctx_name
)
or
dnn
.
version
(
raises
=
False
)
<
3000
:
...
...
@@ -307,11 +324,7 @@ def test_pooling3d():
mode_without_gpu_ref
=
theano
.
compile
.
mode
.
get_mode
(
'FAST_RUN'
)
.
excluding
(
'gpuarray'
)
# 'average_exc_pad' is disabled for versions < 4004
if
dnn
.
version
(
raises
=
False
)
<
4004
:
modes
=
(
'max'
,
'average_inc_pad'
)
else
:
modes
=
(
'max'
,
'average_inc_pad'
,
'average_exc_pad'
)
modes
=
get_dnn_pool_modes
()
x
=
T
.
tensor5
()
for
mode
,
pad
in
product
(
modes
,
...
...
@@ -467,11 +480,7 @@ def test_pooling_opt_arbitrary_dimensions():
raise
SkipTest
(
dnn
.
dnn_available
.
msg
)
utt
.
seed_rng
()
# 'average_exc_pad' is disabled for versions < 4004
if
dnn
.
version
(
raises
=
False
)
<
4004
:
modes
=
(
'max'
,
'average_inc_pad'
)
else
:
modes
=
(
'max'
,
'average_inc_pad'
,
'average_exc_pad'
)
modes
=
get_dnn_pool_modes
()
for
n_non_pool_dims
in
(
0
,
1
,
2
,
3
):
for
ws
in
((
2
,
2
),
(
3
,
3
,
3
)):
...
...
@@ -498,7 +507,7 @@ def test_pooling_opt_arbitrary_dimensions():
fc
=
theano
.
function
([],
out
,
mode
=
mode_without_gpu
)
assert
any
([
isinstance
(
node
.
op
,
Pool
)
for
node
in
fc
.
maker
.
fgraph
.
toposort
()])
if
mode
==
'max'
:
if
mode
in
(
'max'
,
'max_deterministic'
)
:
assert
any
([
isinstance
(
node
.
op
,
MaxPoolGrad
)
for
node
in
fc
.
maker
.
fgraph
.
toposort
()])
else
:
...
...
@@ -780,11 +789,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
dtype
=
theano
.
config
.
floatX
)
# 'average_exc_pad' is disabled for versions < 4004
if
dnn
.
version
(
raises
=
False
)
<
4004
:
modes
=
[
'max'
,
'average_inc_pad'
]
else
:
modes
=
[
'max'
,
'average_inc_pad'
,
'average_exc_pad'
]
modes
=
get_dnn_pool_modes
()
for
params
in
product
(
[(
1
,
1
),
(
2
,
2
),
(
3
,
3
)],
...
...
@@ -807,11 +812,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
dtype
=
theano
.
config
.
floatX
)
# 'average_exc_pad' is disabled for versions < 4004
if
dnn
.
version
(
raises
=
False
)
<
4004
:
modes
=
[
'max'
,
'average_inc_pad'
]
else
:
modes
=
[
'max'
,
'average_inc_pad'
,
'average_exc_pad'
]
modes
=
get_dnn_pool_modes
()
for
params
in
product
(
[(
1
,
1
,
1
),
(
2
,
2
,
2
),
(
3
,
3
,
3
)],
...
...
@@ -847,7 +848,8 @@ class TestDnnInferShapes(utt.InferShapeTester):
for
params
in
product
(
[(
1
,
1
),
(
2
,
2
),
(
3
,
3
)],
[(
1
,
1
),
(
2
,
2
),
(
3
,
3
)],
[
'max'
,
'average_inc_pad'
]
# modes without `average_exc_pad`
[
m
for
m
in
get_dnn_pool_modes
()
if
m
!=
'average_exc_pad'
]
):
pool_grad
=
dnn
.
GpuDnnPoolGrad
(
mode
=
params
[
2
])(
img
,
...
...
@@ -886,7 +888,8 @@ class TestDnnInferShapes(utt.InferShapeTester):
for
params
in
product
(
[(
1
,
1
,
1
),
(
2
,
2
,
2
),
(
3
,
3
,
3
)],
[(
1
,
1
,
1
),
(
2
,
2
,
2
),
(
3
,
3
,
3
)],
[
'max'
,
'average_inc_pad'
]
# modes without `average_exc_pad`
[
m
for
m
in
get_dnn_pool_modes
()
if
m
!=
'average_exc_pad'
]
):
pool_grad
=
dnn
.
GpuDnnPoolGrad
(
mode
=
params
[
2
])(
img
,
...
...
theano/tensor/signal/pool.py
浏览文件 @
46783773
...
...
@@ -434,6 +434,9 @@ class Pool(OpenMPOp):
super
(
Pool
,
self
)
.
__init__
(
openmp
=
openmp
)
self
.
ndim
=
ndim
self
.
ignore_border
=
ignore_border
if
mode
==
'max_deterministic'
:
# The max pooling algorithm already appears to be deterministic on CPU.
mode
=
'max'
if
mode
not
in
[
'max'
,
'average_inc_pad'
,
'average_exc_pad'
,
'sum'
]:
raise
ValueError
(
"Pool mode parameter only support 'max', 'sum',"
...
...
@@ -1041,6 +1044,9 @@ class PoolGrad(OpenMPOp):
def
__init__
(
self
,
ignore_border
,
mode
=
'max'
,
ndim
=
2
,
openmp
=
None
):
self
.
ndim
=
ndim
self
.
ignore_border
=
ignore_border
if
mode
==
'max_deterministic'
:
# The max pooling gradient algorithm already appears to be deterministic on CPU.
mode
=
'max'
if
mode
not
in
[
'max'
,
'sum'
,
'average_inc_pad'
,
'average_exc_pad'
]:
raise
ValueError
(
"Pool mode parameter only support 'max', 'sum',"
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论