提交 2b6ef54e authored 作者: --global's avatar --global

Update dnn_conv3d to use new config flags

上级 6f80c3be
......@@ -1150,7 +1150,8 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
conv_mode='conv', direction_hint=None, workmem='none'):
conv_mode='conv', direction_hint=None, workmem=None,
algo='none'):
"""
GPU convolution using cuDNN from NVIDIA.
......@@ -1170,8 +1171,10 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
'bprop weights', it will use GpuDnnConvGradW.
This parameter is used internally by graph optimizers and may be
removed at any time without a deprecation period. You have been warned.
:param workmem: Specify the amount of working memory allowed.
Only 'none' is implemented for the conv3d
:param workmem: *deprecated*, use param algo instead
:param algo: convolution implementation to use. Only 'none' is implemented
for the conv3d. Default is the value of
:attr:`config.dnn.conv.algo_fwd.
:warning: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higer. This means that older GPU will not
......@@ -1180,6 +1183,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
"""
# Check if deprecated param 'workmem' is used
if workmem is not None:
warnings.warn(("dnn_conv3d: parameter 'workmem' is deprecated. Use "
"'algo' instead."), stacklevel=3)
assert algo == None
algo = workmem
# Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward']
......@@ -1216,7 +1226,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
desc_op.border_mode,
desc_op.subsample)
out = gpu_alloc_empty(*out_shp)
return GpuDnnConv3d(workmem=workmem)(img, kerns, out, desc)
return GpuDnnConv3d(algo=algo)(img, kerns, out, desc)
class GpuDnnPoolDesc(GpuOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论