提交 90c50340 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4771 from nouiz/cudnn_conv_precision_fix

Make the old backend support as_input_f32 theano.config.dnn.precision.
......@@ -375,7 +375,7 @@ AddConfigVar('dnn.enabled',
" to not using it if not present."
" If True and cuDNN can not be used, raise an error."
" If False, disable cudnn",
StrParam("auto", "True", "False"),
EnumStr("auto", "True", "False"),
in_c_key=False)
# This flag determines whether or not to raise error/warning message if
......
......@@ -873,7 +873,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may
require certain versions of cuDNN to be installed. Default is
the value of :attr:`config.dnn.conv.algo_fwd`.
precision : {'as_input', 'float16', 'float32', 'float64'}
precision : {'as_input_f32', 'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution
should be done. Possible values are 'as_input', 'float16', 'float32'
and 'float64'. Default is the value of
......
......@@ -1108,7 +1108,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may require certain
versions of cuDNN to be installed. Default is the value of
:attr:`config.dnn.conv.algo_fwd`.
precision : {'as_input', 'float16', 'float32', 'float64'}
precision : {'as_input_f32', 'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution
should be done. Possible values are 'as_input', 'float16', 'float32'
and 'float64'. Default is the value of
......@@ -1122,8 +1122,12 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
if precision == 'as_input' or precision == 'as_input_f32':
nprec = theano.scalar.upcast(img.dtype, kerns.dtype)
if nprec == 'float16' and precision == 'as_input_f32':
precision = 'float32'
else:
precision = nprec
# Check if deprecated param 'workmem' is used
if workmem is not None:
......@@ -1218,8 +1222,8 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
for the conv3d. Default is the value of
:attr:`config.dnn.conv.algo_fwd`.
:param precision : dtype in which the computation of the convolution
should be done. Possible values are 'as_input', 'float16', 'float32'
and 'float64'. Default is the value of
should be done. Possible values are 'as_input_f32', 'as_input',
'float16', 'float32' and 'float64'. Default is the value of
:attr:`config.dnn.conv.precision`.
:warning: The cuDNN library only works with GPU that have a compute
......@@ -1234,8 +1238,12 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
if precision == 'as_input' or precision == 'as_input_f32':
nprec = theano.scalar.upcast(img.dtype, kerns.dtype)
if nprec == 'float16' and precision == 'as_input_f32':
precision = 'float32'
else:
precision = nprec
# Check if deprecated param 'workmem' is used
if workmem is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论