提交 96fe4301 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Mathieu Germain

GpuDnnConv3dGradW for v5 support small algo

上级 9a9c5fb8
......@@ -710,7 +710,9 @@ class GpuDnnConv3dGradW(GpuDnnConvGradW):
:param descr: the convolution descriptor
:param workmem:
*deprecated*, use parameter algo instead.
:param algo: ['none', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change']
:param algo: ['none', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`.
"""
......@@ -723,11 +725,18 @@ class GpuDnnConv3dGradW(GpuDnnConvGradW):
"deprecated. Use 'algo' instead."), stacklevel=3)
assert algo is None
algo = workmem
super(GpuDnnConv3dGradW, self).__init__(inplace=inplace,
algo='none')
assert self.algo in ['none', 'guess_once', 'guess_on_shape_change',
good_algo = ['none', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
if version() < (5000, 5000) and algo == 'small':
algo = 'guess_once'
elif algo is None and config.dnn.conv.algo_bwd_filter not in good_algo:
algo = 'guess_once'
elif algo is not None and algo not in good_algo:
algo = 'guess_once'
super(GpuDnnConv3dGradW, self).__init__(inplace=inplace,
algo=algo)
assert self.algo in good_algo
def grad(self, inp, grads):
img, top, output, desc, alpha, beta = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论