提交 32f11551 authored 作者: carriepl's avatar carriepl 提交者: Frederic

Add 'as_input' option to dnn.conv.precision flag

上级 7ff6621e
......@@ -261,8 +261,9 @@ AddConfigVar('dnn.conv.algo_bwd_filter',
AddConfigVar('dnn.conv.precision',
"Default data precision to use for the computation in CuDNN "
"convolutions (defaults to the floatX).",
EnumStr('floatX', 'float16', 'float32', 'float64'),
"convolutions (defaults to the same dtype as the inputs of the "
"convolutions).",
EnumStr('as_input', 'float16', 'float32', 'float64'),
in_c_key=False)
......
......@@ -266,7 +266,7 @@ class GpuDnnConvDesc(GpuOp):
return False
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv',
precision=None):
precision="float32"):
if isinstance(border_mode, int):
border_mode = (border_mode,) * len(subsample)
if isinstance(border_mode, tuple):
......@@ -284,10 +284,6 @@ class GpuDnnConvDesc(GpuOp):
assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'floatX':
precision = theano.config.floatX
assert precision in ['float16', 'float32', 'float64']
self.precision = precision
......@@ -1140,13 +1136,19 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may require certain
versions of CuDNN to be installed. Default is the value of
:attr:`config.dnn.conv.algo_fwd`.
precision : {'float16', 'float32', 'float64', 'floatX'}
precision : {'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution
should be done. Default is the value of
:attr:`config.dnn.conv.precision`.
"""
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
# Check if deprecated param 'workmem' is used
if workmem is not None:
warnings.warn(("dnn_conv: parameter 'workmem' is deprecated. Use "
......@@ -1240,8 +1242,9 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
for the conv3d. Default is the value of
:attr:`config.dnn.conv.algo_fwd`.
:param precision : dtype in which the computation of the convolution
should be done. Possible values are 'float16', 'float32', 'float64' and
'floatX'. Default is the value of :attr:`config.dnn.conv.precision`.
should be done. Possible values are 'as_input', 'float16', 'float32'
and 'float64'. Default is the value of
:attr:`config.dnn.conv.precision`.
:warning: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higer. This means that older GPU will not
......@@ -1250,6 +1253,12 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
"""
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
# Check if deprecated param 'workmem' is used
if workmem is not None:
warnings.warn(("dnn_conv3d: parameter 'workmem' is deprecated. Use "
......
......@@ -259,7 +259,7 @@ class GpuDnnConvDesc(COp):
return False
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv',
precision=None):
precision="float32"):
COp.__init__(self, ["conv_desc.c"], "APPLY_SPECIFIC(conv_desc)")
if isinstance(border_mode, int):
......@@ -279,10 +279,6 @@ class GpuDnnConvDesc(COp):
assert conv_mode in ('conv', 'cross')
self.conv_mode = conv_mode
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'floatX':
precision = theano.config.floatX
assert precision in ['float16', 'float32', 'float64']
self.precision = precision
......@@ -802,7 +798,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
Convolution implementation to use. Some of its values may
require certain versions of CuDNN to be installed. Default is
the value of :attr:`config.dnn.conv.algo_fwd`.
precision : {'float16', 'float32', 'float64', 'floatX'}
precision : {'as_input', 'float16', 'float32', 'float64'}
Description of the dtype in which the computation of the convolution
should be done. Default is the value of
:attr:`config.dnn.conv.precision`.
......@@ -812,6 +808,13 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
work with this Op.
"""
# Establish dtype in which to perform the computation of the convolution
if precision is None:
precision = theano.config.dnn.conv.precision
if precision == 'as_input':
precision = theano.scalar.upcast(img.dtype, kerns.dtype)
if workmem is not None:
if algo is not None:
raise ValueError("You can't use both algo and workmem")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论