提交 fb6da2d1 authored 作者: notoraptor's avatar notoraptor

Fix typos and forbid float16 precision for deterministic gradient-input computation.

上级 3041071d
......@@ -9,9 +9,9 @@ for a given cuDNN version.
Currently supported cuDNN APIs:
- v5.1
- v6.0
- v7.0
- v5.1*
- v6.0*
- v7.0*
"""
......@@ -125,7 +125,7 @@ class CuDNNV51(object):
def get_supported_dtype_configs(self, check_runtime=None):
"""
Return the tuple of data type configurations supported by this version of cuDNN.
This is currently convenient for both cuDNN V5.1 and V6, as Theano does not
This is currently convenient for all supported cuDNN versions, as Theano does not
yet support new data types (like INT8, INT8x4, etc.).
``check_runtime`` may be a function that tests if a data type configuration is supported.::
......@@ -199,7 +199,9 @@ class CuDNNV51(object):
return not is_true_half_config(dtype, precision)
if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_1:
# CUDNN_CONVOLUTION_BWD_DATA_ALGO_1: all data type configs supported.
return True
# NB: Let's avoid float16 precision, as some strange errors may be encountered
# with that precision ( see https://github.com/Theano/Theano/pull/5932/ )
return not is_true_half_config(dtype, precision)
if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT:
return ndim == 2 and (is_pseudo_half_config(dtype, precision) or is_float_config(dtype, precision))
if algo == algorithms.CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论