提交 319cf621，作者：--global

Update GpuDnnConvGradW to use new config flags

上级 e0eaa397
...@@ -685,25 +685,42 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -685,25 +685,42 @@ class GpuDnnConvGradW(DnnBase, COp):
:param descr: the convolution descriptor :param descr: the convolution descriptor
""" """
__props__ = ('workmem', 'inplace',) __props__ = ('algo', 'inplace',)
__input_name__ = ('image', 'grad', 'output', 'descriptor', 'alpha', 'beta') __input_name__ = ('image', 'grad', 'output', 'descriptor', 'alpha', 'beta')
def __init__(self, inplace=False, workmem=None): def __init__(self, inplace=False, workmem=None, algo=None):
"""
:param workmem: *deprecated*, use param algo instead
:param algo: either 'none', 'deterministic', 'fft', 'guess_once' or
'guess_on_shape_change'.
Default is the value of :attr:`config.dnn.conv.algo_bwd`.
"""
COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_gw.c"], COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_gw.c"],
"APPLY_SPECIFIC(conv_gw)") "APPLY_SPECIFIC(conv_gw)")
if workmem is None:
workmem = config.dnn.conv.workmem_bwd if workmem is not None:
self.workmem = workmem warnings.warn(("GpuDnnConvGradW: parameter 'workmem' is "
"deprecated. Use 'algo' instead."), stacklevel=3)
assert algo == None
self.algo = workmem
else:
if algo is None:
algo = config.dnn.conv.algo_bwd
self.algo = algo
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [2]} self.destroy_map = {0: [2]}
assert self.workmem in ['none', 'deterministic', 'fft', 'guess', assert self.algo in ['none', 'deterministic', 'fft', 'guess_once',
'guess_once'] 'guess_on_shape_change']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if not hasattr(self, 'workmem'): if not hasattr(self, 'algo'):
self.workmem = 'none' if hasattr(self, 'workmem'):
self.algo = self.workmem
else:
self.algo = 'none'
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
...@@ -736,21 +753,21 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -736,21 +753,21 @@ class GpuDnnConvGradW(DnnBase, COp):
alg_def = ('CONV_ALGO', '0') alg_def = ('CONV_ALGO', '0')
alg_choose_def = ('CHOOSE_ALGO', '0') alg_choose_def = ('CHOOSE_ALGO', '0')
else: else:
if self.workmem == 'none': if self.algo == 'none':
alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0') alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
alg_choose_def = ('CHOOSE_ALGO', '0') alg_choose_def = ('CHOOSE_ALGO', '0')
elif self.workmem == 'deterministic': elif self.algo == 'deterministic':
alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1') alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1')
alg_choose_def = ('CHOOSE_ALGO', '0') alg_choose_def = ('CHOOSE_ALGO', '0')
elif self.workmem == 'fft': elif self.algo == 'fft':
alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT') alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT')
alg_choose_def = ('CHOOSE_ALGO', '0') alg_choose_def = ('CHOOSE_ALGO', '0')
elif self.workmem in ['guess', 'guess_once']: elif self.algo in ['guess_once', 'guess_on_shape_change']:
# The convolution implementation should be chosen according # The convolution implementation should be chosen according
# to a heuristic # to a heuristic
alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0') alg_def = ('CONV_ALGO', 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0')
alg_choose_def = ('CHOOSE_ALGO', '1') alg_choose_def = ('CHOOSE_ALGO', '1')
if self.workmem == 'guess_once': if self.algo == 'guess_once':
alg_choose_once_def = ('CHOOSE_ALGO_ONCE', '1') alg_choose_once_def = ('CHOOSE_ALGO_ONCE', '1')
return inplace_def + [alg_def, alg_choose_def, alg_choose_once_def] return inplace_def + [alg_def, alg_choose_def, alg_choose_once_def]
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论