Commit 0bd65807, authored by notoraptor

Rewrite choice of cuDNN definitions.

Fix cuDNN V7 integration into dnn_rnn_desc. Fix typo. Fix flake8.
Parent commit: 00786143
...@@ -29,11 +29,17 @@ int dnn_rnn_desc(int hidden_size, int num_layers, ...@@ -29,11 +29,17 @@ int dnn_rnn_desc(int hidden_size, int num_layers,
PyErr_SetString(PyExc_RuntimeError, "Can't create RNN descriptor"); PyErr_SetString(PyExc_RuntimeError, "Can't create RNN descriptor");
return -1; return -1;
} }
#if CUDNN_MAJOR < 7
err = cudnnSetRNNDescriptor(desc, hidden_size, num_layers, ddesc,
(cudnnRNNInputMode_t)input_mode,
(cudnnDirectionMode_t)direction_mode,
(cudnnRNNMode_t)rnn_mode, data_type);
#else
err = cudnnSetRNNDescriptor(_handle, desc, hidden_size, num_layers, ddesc, err = cudnnSetRNNDescriptor(_handle, desc, hidden_size, num_layers, ddesc,
(cudnnRNNInputMode_t)input_mode, (cudnnRNNInputMode_t)input_mode,
(cudnnDirectionMode_t)direction_mode, (cudnnDirectionMode_t)direction_mode,
(cudnnRNNMode_t)rnn_mode, CUDNN_RNN_ALGO_STANDARD, data_type); (cudnnRNNMode_t)rnn_mode, CUDNN_RNN_ALGO_STANDARD, data_type);
#endif
if (err != CUDNN_STATUS_SUCCESS) { if (err != CUDNN_STATUS_SUCCESS) {
cudnnDestroyRNNDescriptor(desc); cudnnDestroyRNNDescriptor(desc);
PyErr_SetString(PyExc_RuntimeError, "Can't set RNN descriptor"); PyErr_SetString(PyExc_RuntimeError, "Can't set RNN descriptor");
......
...@@ -103,7 +103,7 @@ class CuDNNV6(CuDNNV51): ...@@ -103,7 +103,7 @@ class CuDNNV6(CuDNNV51):
# new in v6 # new in v6
('CUDNN_DATA_INT8', 'int8'), ('CUDNN_DATA_INT8', 'int8'),
('CUDNN_DATA_INT32', 'int32'), ('CUDNN_DATA_INT32', 'int32'),
# ('CUDNN_DATA_INT8X4', 'int8x4'), # ('CUDNN_DATA_INT8X4', 'int8x4'),
ctype='cudnnDataType_t') ctype='cudnnDataType_t')
cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'), cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'),
...@@ -133,14 +133,16 @@ class CuDNNV6(CuDNNV51): ...@@ -133,14 +133,16 @@ class CuDNNV6(CuDNNV51):
('CUDNN_REDUCE_TENSOR_NORM2', 'norm2'), ('CUDNN_REDUCE_TENSOR_NORM2', 'norm2'),
ctype='cudnnReduceTensorOp_t') ctype='cudnnReduceTensorOp_t')
class CuDNNV7(CuDNNV6): class CuDNNV7(CuDNNV6):
version = 7 version = 7
cudnnMathType_t = CEnumType(('CUDNN_DEFAULT_MATH', 'non_tensor_op'), cudnnMathType_t = CEnumType(('CUDNN_DEFAULT_MATH', 'non_tensor_op'),
('CUDNN_TENSOR_OP_MATH', 'tensor_op'), ('CUDNN_TENSOR_OP_MATH', 'tensor_op'),
ctype = 'cudnnMathType_t') ctype='cudnnMathType_t')
cudnnDeterminism_t = CEnumType(('CUDNN_NON_DETERMINISTIC', 'non_deterministic'), cudnnDeterminism_t = CEnumType(('CUDNN_NON_DETERMINISTIC', 'non_deterministic'),
('CUDNN_DETERMINISTIC', 'deterministic'), ('CUDNN_DETERMINISTIC', 'deterministic'),
ctype = 'cudnnDeterminism_t') ctype='cudnnDeterminism_t')
def get_definitions(cudnn_version=None): def get_definitions(cudnn_version=None):
""" """
...@@ -151,7 +153,10 @@ def get_definitions(cudnn_version=None): ...@@ -151,7 +153,10 @@ def get_definitions(cudnn_version=None):
if None, return definitions for the most recent supported cuDNN version. if None, return definitions for the most recent supported cuDNN version.
""" """
if cudnn_version is not None and cudnn_version // 1000 == 6: if cudnn_version is not None:
if cudnn_version // 1000 == 5:
return CuDNNV51()
if cudnn_version // 1000 == 6:
return CuDNNV6() return CuDNNV6()
# By default, we use definitions for the last supported cuDNN version. # By default, we use definitions for the last supported cuDNN version.
return CuDNNV7() return CuDNNV7()
...@@ -90,7 +90,7 @@ def _dnn_lib(): ...@@ -90,7 +90,7 @@ def _dnn_lib():
if lib_name: if lib_name:
break break
if lib_name is None: if lib_name is None:
raise RuntimeError('Could not find cudnn library (looked for v5* or v6*)') raise RuntimeError('Could not find cudnn library (looked for v5* to v7*)')
else: else:
dnn_handle = ctypes.cdll.LoadLibrary(lib_name) dnn_handle = ctypes.cdll.LoadLibrary(lib_name)
......
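Markdown format
0%
You are about to add 0 people to this discussion. Proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment.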