提交 49c79247 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Mathieu Germain

Add small and fft_tiling to GpuDnnConv3d

上级 420a832b
......@@ -384,7 +384,7 @@ class GpuDnnConv(DnnBase, COp):
# need v3
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_FFT'
elif self.algo == 'fft_tiling':
# need v4
# need v4 for conv2d, need v5 for conv3d
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING'
elif self.algo in ['guess_once', 'guess_on_shape_change']:
# The convolution implementation should be choosen according
......@@ -478,7 +478,9 @@ class GpuDnnConv3d(GpuDnnConv):
:param descr: the convolution descriptor
:param workmem:
*deprecated*, use parameter algo instead.
:param algo: ['none', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change']
:param algo: ['none', 'small', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_fwd`.
"""
......@@ -492,10 +494,22 @@ class GpuDnnConv3d(GpuDnnConv):
"Use 'algo' instead."), stacklevel=3)
assert algo is None
algo = workmem
super(GpuDnnConv3d, self).__init__(inplace=inplace, algo='none')
assert self.algo in ['none', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
good_algo = ['none', 'small', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
if version() < (5000, 5000):
# Need to confirm when small was added for conv3d.
algo = 'none'
elif algo is None and config.dnn.conv.algo_fwd not in good_algo:
algo = 'guess_once'
elif algo is not None and algo not in good_algo:
algo = 'guess_once'
super(GpuDnnConv3d, self).__init__(inplace=inplace, algo=algo)
assert self.algo in good_algo
if version() < (5000, 5000):
if self.algo == 'fft_tiling':
raise RuntimeError("CuDNN 3d tiled-FFT convolution requires "
"CuDNN v5 or more recent")
def make_node(self, img, kern, output, desc, alpha=None, beta=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论