提交 32f11551 作者: carriepl 提交者: Frederic

Add 'as_input' option to dnn.conv.precision flag

上级 7ff6621e
...@@ -261,8 +261,9 @@ AddConfigVar('dnn.conv.algo_bwd_filter', ...@@ -261,8 +261,9 @@ AddConfigVar('dnn.conv.algo_bwd_filter',
AddConfigVar('dnn.conv.precision', AddConfigVar('dnn.conv.precision',
"Default data precision to use for the computation in CuDNN " "Default data precision to use for the computation in CuDNN "
"convolutions (defaults to the floatX).", "convolutions (defaults to the same dtype as the inputs of the "
EnumStr('floatX', 'float16', 'float32', 'float64'), "convolutions).",
EnumStr('as_input', 'float16', 'float32', 'float64'),
in_c_key=False) in_c_key=False)
......
...@@ -266,7 +266,7 @@ class GpuDnnConvDesc(GpuOp): ...@@ -266,7 +266,7 @@ class GpuDnnConvDesc(GpuOp):
return False return False
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv', def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv',
precision=None): precision="float32"):
if isinstance(border_mode, int): if isinstance(border_mode, int):
border_mode = (border_mode,) * len(subsample) border_mode = (border_mode,) * len(subsample)
if isinstance(border_mode, tuple): if isinstance(border_mode, tuple):
...@@ -284,10 +284,6 @@ class GpuDnnConvDesc(GpuOp): ...@@ -284,10 +284,6 @@ class GpuDnnConvDesc(GpuOp):
assert conv_mode in ('conv', 'cross') assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode self.conv_mode = conv_mode
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'floatX':
precision = theano.config.floatX
assert precision in ['float16', 'float32', 'float64'] assert precision in ['float16', 'float32', 'float64']
self.precision = precision self.precision = precision
...@@ -1140,13 +1136,19 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -1140,13 +1136,19 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may require certain Convolution implementation to use. Some of its values may require certain
versions of CuDNN to be installed. Default is the value of versions of CuDNN to be installed. Default is the value of
:attr:`config.dnn.conv.algo_fwd`. :attr:`config.dnn.conv.algo_fwd`.
precision : {'float16', 'float32', 'float64', 'floatX'} precision : {'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution Description of the dtype in which the computation of the convolution
should be done. Default is the value of should be done. Default is the value of
:attr:`config.dnn.conv.precision`. :attr:`config.dnn.conv.precision`.
""" """
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
# Check if deprecated param 'workmem' is used # Check if deprecated param 'workmem' is used
if workmem is not None: if workmem is not None:
warnings.warn(("dnn_conv: parameter 'workmem' is deprecated. Use " warnings.warn(("dnn_conv: parameter 'workmem' is deprecated. Use "
...@@ -1240,8 +1242,9 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1240,8 +1242,9 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
for the conv3d. Default is the value of for the conv3d. Default is the value of
:attr:`config.dnn.conv.algo_fwd`. :attr:`config.dnn.conv.algo_fwd`.
:param precision : dtype in which the computation of the convolution :param precision : dtype in which the computation of the convolution
should be done. Possible values are 'float16', 'float32', 'float64' and should be done. Possible values are 'as_input', 'float16', 'float32'
'floatX'. Default is the value of :attr:`config.dnn.conv.precision`. and 'float64'. Default is the value of
:attr:`config.dnn.conv.precision`.
:warning: The cuDNN library only works with GPU that have a compute :warning: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higher. This means that older GPU will not capability of 3.0 or higher. This means that older GPU will not
...@@ -1250,6 +1253,12 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1250,6 +1253,12 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
""" """
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
# Check if deprecated param 'workmem' is used # Check if deprecated param 'workmem' is used
if workmem is not None: if workmem is not None:
warnings.warn(("dnn_conv3d: parameter 'workmem' is deprecated. Use " warnings.warn(("dnn_conv3d: parameter 'workmem' is deprecated. Use "
......
...@@ -259,7 +259,7 @@ class GpuDnnConvDesc(COp): ...@@ -259,7 +259,7 @@ class GpuDnnConvDesc(COp):
return False return False
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv', def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv',
precision=None): precision="float32"):
COp.__init__(self, ["conv_desc.c"], "APPLY_SPECIFIC(conv_desc)") COp.__init__(self, ["conv_desc.c"], "APPLY_SPECIFIC(conv_desc)")
if isinstance(border_mode, int): if isinstance(border_mode, int):
...@@ -279,10 +279,6 @@ class GpuDnnConvDesc(COp): ...@@ -279,10 +279,6 @@ class GpuDnnConvDesc(COp):
assert conv_mode in ('conv', 'cross') assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode self.conv_mode = conv_mode
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'floatX':
precision = theano.config.floatX
assert precision in ['float16', 'float32', 'float64'] assert precision in ['float16', 'float32', 'float64']
self.precision = precision self.precision = precision
...@@ -802,7 +798,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -802,7 +798,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may Convolution implementation to use. Some of its values may
require certain versions of CuDNN to be installed. Default is require certain versions of CuDNN to be installed. Default is
the value of :attr:`config.dnn.conv.algo_fwd`. the value of :attr:`config.dnn.conv.algo_fwd`.
precision : {'float16', 'float32', 'float64', 'floatX'} precision : {'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution Description of the dtype in which the computation of the convolution
should be done. Default is the value of should be done. Default is the value of
:attr:`config.dnn.conv.precision`. :attr:`config.dnn.conv.precision`.
...@@ -812,6 +808,13 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -812,6 +808,13 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
work with this Op. work with this Op.
""" """
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
if workmem is not None: if workmem is not None:
if algo is not None: if algo is not None:
raise ValueError("You can't use both algo and workmem") raise ValueError("You can't use both algo and workmem")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论