提交 6bbe6a92 authored 作者: Boris Fomitchev's avatar Boris Fomitchev 提交者: notoraptor

Tensor op implementation for Volta, first cut at algo cache

上级 3a6c0f70
...@@ -11,6 +11,7 @@ Currently supported cuDNN APIs: ...@@ -11,6 +11,7 @@ Currently supported cuDNN APIs:
- v5.1 - v5.1
- v6.0 - v6.0
- v7.0
""" """
...@@ -102,8 +103,7 @@ class CuDNNV6(CuDNNV51): ...@@ -102,8 +103,7 @@ class CuDNNV6(CuDNNV51):
# new in v6 # new in v6
('CUDNN_DATA_INT8', 'int8'), ('CUDNN_DATA_INT8', 'int8'),
('CUDNN_DATA_INT32', 'int32'), ('CUDNN_DATA_INT32', 'int32'),
# Also in v6, but restrictions make this fail # ('CUDNN_DATA_INT8X4', 'int8x4'),
# CUDNN_DATA_INT8x4
ctype='cudnnDataType_t') ctype='cudnnDataType_t')
cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'), cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'),
...@@ -117,10 +117,8 @@ class CuDNNV6(CuDNNV51): ...@@ -117,10 +117,8 @@ class CuDNNV6(CuDNNV51):
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3', 'small'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3', 'small'),
# not implemented:
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
# TODO: not yet tested/documented:
# new in v6: # new in v6:
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING', 'fft_tiling'), ('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING', 'fft_tiling'),
ctype='cudnnConvolutionBwdFilterAlgo_t') ctype='cudnnConvolutionBwdFilterAlgo_t')
...@@ -135,7 +133,15 @@ class CuDNNV6(CuDNNV51): ...@@ -135,7 +133,15 @@ class CuDNNV6(CuDNNV51):
('CUDNN_REDUCE_TENSOR_NORM2', 'norm2'), ('CUDNN_REDUCE_TENSOR_NORM2', 'norm2'),
ctype='cudnnReduceTensorOp_t') ctype='cudnnReduceTensorOp_t')
class CuDNNV7(CuDNNV6):
    """
    cuDNN definitions for v7.

    Inherits everything declared for v6 and adds the two enumerations
    declared below.
    """

    version = 7

    # Each entry pairs a cuDNN C constant name with a short string name
    # (NOTE(review): presumably the alias accepted on the Python side --
    # confirm against CEnumType).
    cudnnMathType_t = CEnumType(('CUDNN_DEFAULT_MATH', 'non_tensor_op'),
                                ('CUDNN_TENSOR_OP_MATH', 'tensor_op'),
                                ctype='cudnnMathType_t')

    cudnnDeterminism_t = CEnumType(('CUDNN_NON_DETERMINISTIC', 'non_deterministic'),
                                   ('CUDNN_DETERMINISTIC', 'deterministic'),
                                   ctype='cudnnDeterminism_t')
def get_definitions(cudnn_version=None):
    """
    Return cuDNN definitions to be used by Theano for the given cuDNN version.

    ``cudnn_version`` is either None or a number whose major version is
    obtained with ``cudnn_version // 1000`` (e.g. 5110 -> 5, 6021 -> 6).
    If None, return definitions for the most recent supported cuDNN version.

    """
    if cudnn_version is not None:
        major = cudnn_version // 1000
        # v5 must keep its own definitions: without this branch a v5
        # library would silently be given the v7 definitions.
        if major == 5:
            return CuDNNV51()
        if major == 6:
            return CuDNNV6()
    # By default, we use definitions for the last supported cuDNN version.
    return CuDNNV7()
...@@ -58,7 +58,7 @@ except ImportError: ...@@ -58,7 +58,7 @@ except ImportError:
pass pass
# Update these names when new versions of cudnn are supported. # Update these names when new versions of cudnn are supported.
WIN32_CUDNN_NAMES = ['cudnn64_6.dll', 'cudnn64_5.dll'] WIN32_CUDNN_NAMES = ['cudnn64_7.dll', 'cudnn64_6.dll', 'cudnn64_5.dll']
def _load_lib(name): def _load_lib(name):
...@@ -166,11 +166,11 @@ def _dnn_check_version(): ...@@ -166,11 +166,11 @@ def _dnn_check_version():
v = version() v = version()
if v < 5000: if v < 5000:
return False, "cuDNN version is too old. Update to v5* or higher, was %d." % v return False, "cuDNN version is too old. Update to v5* or higher, was %d." % v
if v >= 6100: if v >= 7200:
warnings.warn("Your cuDNN version is more recent than " warnings.warn("Your cuDNN version is more recent than "
"Theano. If you encounter problems, try " "Theano. If you encounter problems, try "
"updating Theano or downgrading cuDNN to " "updating Theano or downgrading cuDNN to "
"a version >= v5 and <= v6.") "a version >= v5 and <= v7.")
return True, None return True, None
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论