提交 2b6ef54e authored 作者: --global's avatar --global

Update dnn_conv3d to use new config flags

上级 6f80c3be
...@@ -1150,7 +1150,8 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -1150,7 +1150,8 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
conv_mode='conv', direction_hint=None, workmem='none'): conv_mode='conv', direction_hint=None, workmem=None,
algo='none'):
""" """
GPU convolution using cuDNN from NVIDIA. GPU convolution using cuDNN from NVIDIA.
...@@ -1170,8 +1171,10 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1170,8 +1171,10 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
'bprop weights', it will use GpuDnnConvGradW. 'bprop weights', it will use GpuDnnConvGradW.
This parameter is used internally by graph optimizers and may be This parameter is used internally by graph optimizers and may be
removed at any time without a deprecation period. You have been warned. removed at any time without a deprecation period. You have been warned.
:param workmem: Specify the amount of working memory allowed. :param workmem: *deprecated*, use param algo instead
Only 'none' is implemented for the conv3d :param algo: convolution implementation to use. Only 'none' is implemented
for the conv3d. Default is the value of
:attr:`config.dnn.conv.algo_fwd.
:warning: The cuDNN library only works with GPU that have a compute :warning: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higer. This means that older GPU will not capability of 3.0 or higer. This means that older GPU will not
...@@ -1180,6 +1183,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1180,6 +1183,13 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
""" """
# Check if deprecated param 'workmem' is used
if workmem is not None:
warnings.warn(("dnn_conv3d: parameter 'workmem' is deprecated. Use "
"'algo' instead."), stacklevel=3)
assert algo == None
algo = workmem
# Ensure the value of direction_hint is supported # Ensure the value of direction_hint is supported
assert direction_hint in [None, 'bprop weights', 'forward'] assert direction_hint in [None, 'bprop weights', 'forward']
...@@ -1216,7 +1226,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1), ...@@ -1216,7 +1226,7 @@ def dnn_conv3d(img, kerns, border_mode='valid', subsample=(1, 1, 1),
desc_op.border_mode, desc_op.border_mode,
desc_op.subsample) desc_op.subsample)
out = gpu_alloc_empty(*out_shp) out = gpu_alloc_empty(*out_shp)
return GpuDnnConv3d(workmem=workmem)(img, kerns, out, desc) return GpuDnnConv3d(algo=algo)(img, kerns, out, desc)
class GpuDnnPoolDesc(GpuOp): class GpuDnnPoolDesc(GpuOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论