提交 90c50340 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4771 from nouiz/cudnn_conv_precision_fix

Make the old backend support as_input_f32 theano.config.dnn.precision.
...@@ -375,7 +375,7 @@ AddConfigVar('dnn.enabled', ...@@ -375,7 +375,7 @@ AddConfigVar('dnn.enabled',
" to not using it if not present." " to not using it if not present."
" If True and cuDNN can not be used, raise an error." " If True and cuDNN can not be used, raise an error."
" If False, disable cudnn", " If False, disable cudnn",
StrParam("auto", "True", "False"), EnumStr("auto", "True", "False"),
in_c_key=False) in_c_key=False)
# This flag determines whether or not to raise error/warning message if # This flag determines whether or not to raise error/warning message if
......
...@@ -873,7 +873,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -873,7 +873,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may Convolution implementation to use. Some of its values may
require certain versions of cuDNN to be installed. Default is require certain versions of cuDNN to be installed. Default is
the value of :attr:`config.dnn.conv.algo_fwd`. the value of :attr:`config.dnn.conv.algo_fwd`.
precision : {'as_input', 'float16', 'float32', 'float64'} precision : {'as_input_f32', 'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution Description of the dtype in which the computation of the convolution
should be done. Possible values are 'as_input', 'float16', 'float32' should be done. Possible values are 'as_input', 'float16', 'float32'
and 'float64'. Default is the value of and 'float64'. Default is the value of
......
...@@ -1108,7 +1108,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -1108,7 +1108,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may require certain Convolution implementation to use. Some of its values may require certain
versions of cuDNN to be installed. Default is the value of versions of cuDNN to be installed. Default is the value of
:attr:`config.dnn.conv.algo_fwd`. :attr:`config.dnn.conv.algo_fwd`.
precision : {'as_input', 'float16', 'float32', 'float64'} precision : {'as_input_f32', 'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution Description of the dtype in which the computation of the convolution
should be done. Possible values are 'as_input', 'float16', 'float32' should be done. Possible values are 'as_input', 'float16', 'float32'
and 'float64'. Default is the value of and 'float64'. Default is the value of
...@@ -1122,8 +1122,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -1122,8 +1122,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
# Establish dtype in which to perform the computation of the convolution # Establish dtype in which to perform the computation of the convolution
if precision is None: if precision is None:
precision = theano.config.dnn.conv.precision precision = theano.config.dnn.conv.precision
if precision == 'as_input': if precision == 'as_input' or precision == 'as_input_f32':
precision = theano.scalar.upcast(img.dtype, kerns.dtype) nprec = theano.scalar.upcast(img.dtype, kerns.dtype)
if nprec == 'float16' and precision == 'as_input_f32':
precision = 'float32'
else:
precision = nprec
# Check if deprecated param 'workmem' is used # Check if deprecated param 'workmem' is used
if workmem is not None: if workmem is not None:
...@@ -1218,8 +1222,8 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1218,8 +1222,8 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
for the conv3d. Default is the value of for the conv3d. Default is the value of
:attr:`config.dnn.conv.algo_fwd`. :attr:`config.dnn.conv.algo_fwd`.
:param precision : dtype in which the computation of the convolution :param precision : dtype in which the computation of the convolution
should be done. Possible values are 'as_input', 'float16', 'float32' should be done. Possible values are 'as_input_f32', 'as_input',
and 'float64'. Default is the value of 'float16', 'float32' and 'float64'. Default is the value of
:attr:`config.dnn.conv.precision`. :attr:`config.dnn.conv.precision`.
:warning: The cuDNN library only works with GPU that have a compute :warning: The cuDNN library only works with GPU that have a compute
...@@ -1234,8 +1238,12 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1234,8 +1238,12 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
# Establish dtype in which to perform the computation of the convolution # Establish dtype in which to perform the computation of the convolution
if precision is None: if precision is None:
precision = theano.config.dnn.conv.precision precision = theano.config.dnn.conv.precision
if precision == 'as_input': if precision == 'as_input' or precision == 'as_input_f32':
precision = theano.scalar.upcast(img.dtype, kerns.dtype) nprec = theano.scalar.upcast(img.dtype, kerns.dtype)
if nprec == 'float16' and precision == 'as_input_f32':
precision = 'float32'
else:
precision = nprec
# Check if deprecated param 'workmem' is used # Check if deprecated param 'workmem' is used
if workmem is not None: if workmem is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论