提交 9da84e89 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Mathieu Germain

GpuDnnConv3dGradI algo deterministic for v4 and v5. fft_tiling for v5

上级 96fe4301
...@@ -880,6 +880,7 @@ class GpuDnnConvGradI(DnnBase, COp): ...@@ -880,6 +880,7 @@ class GpuDnnConvGradI(DnnBase, COp):
alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT' alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT'
elif self.algo == 'fft_tiling': elif self.algo == 'fft_tiling':
# need v4, big workspace, but less than fft # need v4, big workspace, but less than fft
# need v5, for conv3d.
alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING' alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING'
elif self.algo in ['guess_once', 'guess_on_shape_change']: elif self.algo in ['guess_once', 'guess_on_shape_change']:
# The convolution implementation should be chosen according # The convolution implementation should be chosen according
...@@ -939,7 +940,8 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI): ...@@ -939,7 +940,8 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
:param descr: the convolution descriptor :param descr: the convolution descriptor
:param workmem: :param workmem:
*deprecated*, use parameter algo instead. *deprecated*, use parameter algo instead.
:param algo: ['none', 'guess_once', 'guess_on_shape_change', :param algo: ['none', 'deterministic', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change'] 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd_data`. Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
...@@ -955,11 +957,20 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI): ...@@ -955,11 +957,20 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
"deprecated. Use 'algo' instead."), stacklevel=3) "deprecated. Use 'algo' instead."), stacklevel=3)
assert algo is None assert algo is None
algo = workmem algo = workmem
good_algo = ['none', 'deterministic', 'fft_tiling',
super(GpuDnnConv3dGradI, self).__init__(inplace=inplace, 'guess_once', 'guess_on_shape_change',
algo="none")
assert self.algo in ['none', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change'] 'time_once', 'time_on_shape_change']
if algo is None and config.dnn.conv.algo_bwd_data not in good_algo:
algo = 'guess_once'
elif algo is not None and algo not in good_algo:
algo = 'guess_once'
super(GpuDnnConv3dGradI, self).__init__(inplace=inplace,
algo=algo)
assert self.algo in good_algo
if version() < (5000, 5000):
if self.algo == 'fft_tiling':
raise RuntimeError("CuDNN 3d tiled-FFT convolution requires "
"CuDNN v5 or more recent")
def grad(self, inp, grads): def grad(self, inp, grads):
kerns, top, output, desc, alpha, beta = inp kerns, top, output, desc, alpha, beta = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论