提交 6f80c3be authored 作者: --global's avatar --global

Update dnn_conv to use new config flags

上级 5e6820a6
...@@ -1053,7 +1053,7 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI): ...@@ -1053,7 +1053,7 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
conv_mode='conv', direction_hint=None, workmem=None): conv_mode='conv', direction_hint=None, workmem=None, algo=None):
""" """
GPU convolution using cuDNN from NVIDIA. GPU convolution using cuDNN from NVIDIA.
...@@ -1075,17 +1075,25 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -1075,17 +1075,25 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
*not* 'forward!', it will use GpuDnnConvGradI. *not* 'forward!', it will use GpuDnnConvGradI.
This parameter is used internally by graph optimizers and may be This parameter is used internally by graph optimizers and may be
removed at any time without a deprecation period. You have been warned. removed at any time without a deprecation period. You have been warned.
:param workmem: Specify the amount of working memory allowed. :param workmem: *deprecated*, use param algo instead
More memory is usually faster. One of 'none', 'small' or :param algo: convolution implementation to use. One of 'none', 'small',
'large'. (default is None which takes its value from 'large', 'fft', 'guess_once', 'guess_on_shape_change', 'time_once' or
:attr:`config.dnn.conv.workmem`) 'time_on_shape_change'. Some of these values may require certain
versions of CuDNN to be installed. Default is the value of
:attr:`config.dnn.conv.algo_fwd.
:warning: The cuDNN library only works with GPU that have a compute :warning: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higer. This means that older GPU will not capability of 3.0 or higer. This means that older GPU will not
work with this Op. work with this Op.
""" """
# Check if deprecated param 'workmem' is used
if workmem is not None:
warnings.warn(("dnn_conv: parameter 'workmem' is deprecated. Use "
"'algo' instead."), stacklevel=3)
assert algo == None
algo = workmem
# Ensure the value of direction_hint is supported # Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward'] assert direction_hint in [None, 'bprop weights', 'forward']
...@@ -1138,7 +1146,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -1138,7 +1146,7 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
desc_op.border_mode, desc_op.border_mode,
desc_op.subsample) desc_op.subsample)
out = gpu_alloc_empty(*out_shp) out = gpu_alloc_empty(*out_shp)
return GpuDnnConv(workmem=workmem)(img, kerns, out, desc) return GpuDnnConv(algo=algo)(img, kerns, out, desc)
def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论