提交 ca407db4 authored 作者: notoraptor's avatar notoraptor

Extend cudnn_defs

with list of deterministic algorithms and other utility definitions.
上级 8ef62afb
...@@ -19,6 +19,10 @@ from __future__ import absolute_import, print_function, division ...@@ -19,6 +19,10 @@ from __future__ import absolute_import, print_function, division
from theano.gof import CEnumType from theano.gof import CEnumType
HALF, FLOAT, DOUBLE = ('float16', 'float32', 'float64')
# NB: Some cuDNN algorithms are listed in cuDNN enums but not implemented. # NB: Some cuDNN algorithms are listed in cuDNN enums but not implemented.
# We still register them here because we try to exactly copy cuDNN enums # We still register them here because we try to exactly copy cuDNN enums
# in Python side, but they will have no aliases associated, to help # in Python side, but they will have no aliases associated, to help
...@@ -51,6 +55,8 @@ class CuDNNV51(object): ...@@ -51,6 +55,8 @@ class CuDNNV51(object):
conv3d_fwd_algorithms = ('none', 'small', 'fft_tiling') conv3d_fwd_algorithms = ('none', 'small', 'fft_tiling')
deterministic_fwd_algorithms = cudnnConvolutionFwdAlgo_t.get_aliases()
cudnnConvolutionBwdFilterAlgo_t = CEnumType(('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0', 'none'), cudnnConvolutionBwdFilterAlgo_t = CEnumType(('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0', 'none'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'),
...@@ -61,6 +67,8 @@ class CuDNNV51(object): ...@@ -61,6 +67,8 @@ class CuDNNV51(object):
conv3d_bwd_filter_algorithms = ('none', 'small') conv3d_bwd_filter_algorithms = ('none', 'small')
deterministic_bwd_filter_algorithms = ('deterministic', 'fft', 'winograd_non_fused')
cudnnConvolutionBwdDataAlgo_t = CEnumType(('CUDNN_CONVOLUTION_BWD_DATA_ALGO_0', 'none'), cudnnConvolutionBwdDataAlgo_t = CEnumType(('CUDNN_CONVOLUTION_BWD_DATA_ALGO_0', 'none'),
('CUDNN_CONVOLUTION_BWD_DATA_ALGO_1', 'deterministic'), ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_1', 'deterministic'),
('CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT', 'fft'), ('CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT', 'fft'),
...@@ -72,6 +80,8 @@ class CuDNNV51(object): ...@@ -72,6 +80,8 @@ class CuDNNV51(object):
conv3d_bwd_data_algorithms = ('none', 'deterministic', 'fft_tiling') conv3d_bwd_data_algorithms = ('none', 'deterministic', 'fft_tiling')
deterministic_bwd_data_algorithms = ('deterministic', 'fft', 'fft_tiling', 'winograd', 'winograd_non_fused')
cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'), cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'),
('CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING', 'average_inc_pad'), ('CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING', 'average_inc_pad'),
('CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING', 'average_exc_pad'), ('CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING', 'average_exc_pad'),
...@@ -93,6 +103,23 @@ class CuDNNV51(object): ...@@ -93,6 +103,23 @@ class CuDNNV51(object):
# Empty enum list, so that this does not crash with cuDNN 5. # Empty enum list, so that this does not crash with cuDNN 5.
cudnnReduceTensorOp_t = CEnumType() cudnnReduceTensorOp_t = CEnumType()
def supported_precisions(self, dtype):
    """
    Return the tuple of computation precisions cuDNN supports for ``dtype``.

    ``dtype`` must be one of the module-level names HALF, FLOAT or DOUBLE
    (an AssertionError is raised otherwise). The same table serves both
    cuDNN V5.1 and V6, because Theano does not yet support the newer
    data types (like INT8, INT8x4, etc.).
    """
    assert dtype in (HALF, FLOAT, DOUBLE)
    # Each row pairs an input data type with its supported precisions;
    # the trailing comment names the matching cuDNN configuration(s).
    precision_table = (
        (HALF, (HALF, FLOAT)),  # TRUE_HALF_CONFIG, PSEUDO_HALF_CONFIG
        (FLOAT, (FLOAT,)),  # FLOAT_CONFIG
        (DOUBLE, (DOUBLE,)),  # DOUBLE_CONFIG
    )
    for input_type, precisions in precision_table:
        if dtype == input_type:
            return precisions
class CuDNNV6(CuDNNV51): class CuDNNV6(CuDNNV51):
version = 6 version = 6
...@@ -123,6 +150,8 @@ class CuDNNV6(CuDNNV51): ...@@ -123,6 +150,8 @@ class CuDNNV6(CuDNNV51):
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING', 'fft_tiling'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING', 'fft_tiling'),
ctype='cudnnConvolutionBwdFilterAlgo_t') ctype='cudnnConvolutionBwdFilterAlgo_t')
deterministic_bwd_filter_algorithms = CuDNNV51.deterministic_bwd_filter_algorithms + ('fft_tiling',)
cudnnReduceTensorOp_t = CEnumType(('CUDNN_REDUCE_TENSOR_ADD', 'add'), cudnnReduceTensorOp_t = CEnumType(('CUDNN_REDUCE_TENSOR_ADD', 'add'),
('CUDNN_REDUCE_TENSOR_MUL', 'mul'), ('CUDNN_REDUCE_TENSOR_MUL', 'mul'),
('CUDNN_REDUCE_TENSOR_MIN', 'minimum'), ('CUDNN_REDUCE_TENSOR_MIN', 'minimum'),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论