提交 c36f9a29 authored 作者: carriepl's avatar carriepl 提交者: Frederic

modify cuda/dnn.py to support new convolution implementations

上级 13bad511
......@@ -401,9 +401,8 @@ class GpuDnnConv(DnnBase, COp):
workmem
*deprecated*, use parameter algo instead.
algo
['none', 'small', 'large', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change']
['none', 'small', 'large', 'fft', 'fft_tiling', 'guess_once',
'guess_on_shape_change', 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_fwd`.
......@@ -445,9 +444,9 @@ class GpuDnnConv(DnnBase, COp):
raise RuntimeError("CuDNN convolution timing requires CuDNN "
"v3")
assert self.algo in ['none', 'small', 'large', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change']
assert self.algo in ['none', 'small', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
def __setstate__(self, d):
self.__dict__.update(d)
......@@ -659,8 +658,8 @@ class GpuDnnConvGradW(DnnBase, COp):
The convolution descriptor.
workmem
*deprecated*, use parameter algo instead.
algo : {'none', 'deterministic', 'fft', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_bwd`.
algo : {'none', 'deterministic', 'fft', 'small', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`.
"""
......@@ -678,15 +677,15 @@ class GpuDnnConvGradW(DnnBase, COp):
self.algo = workmem
else:
if algo is None:
algo = config.dnn.conv.algo_bwd
algo = config.dnn.conv.algo_bwd_filter
self.algo = algo
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
assert self.algo in ['none', 'deterministic', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change']
assert self.algo in ['none', 'deterministic', 'fft', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
def __setstate__(self, d):
self.__dict__.update(d)
......@@ -694,7 +693,7 @@ class GpuDnnConvGradW(DnnBase, COp):
if hasattr(self, 'workmem'):
self.algo = self.workmem
else:
self.algo = config.dnn.conv.algo_bwd
self.algo = config.dnn.conv.algo_bwd_filter
if not hasattr(self, 'inplace'):
self.inplace = False
......@@ -737,7 +736,7 @@ class GpuDnnConvGradW(DnnBase, COp):
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
elif self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT'
elif self.algo == 'none2':
elif self.algo == 'small':
# need v3, non-deterministic, small workspace
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3'
elif self.algo in ['guess_once', 'guess_on_shape_change']:
......@@ -799,7 +798,7 @@ class GpuDnnConv3dGradW(GpuDnnConvGradW):
:param workmem:
*deprecated*, use parameter algo instead.
:param algo: ['none', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd`.
Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`.
"""
__props__ = ('algo', 'inplace',)
......@@ -867,11 +866,11 @@ class GpuDnnConvGradI(DnnBase, COp):
workmem
*deprecated*, use parameter algo instead.
algo
['none', 'deterministic', 'fft', 'guess_once',
['none', 'deterministic', 'fft', 'fft_tiling', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd`.
Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
"""
......@@ -890,15 +889,15 @@ class GpuDnnConvGradI(DnnBase, COp):
self.algo = workmem
else:
if algo is None:
algo = config.dnn.conv.algo_bwd
algo = config.dnn.conv.algo_bwd_data
self.algo = algo
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
assert self.algo in ['none', 'deterministic', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change']
assert self.algo in ['none', 'deterministic', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
def __setstate__(self, d):
self.__dict__.update(d)
......@@ -906,7 +905,7 @@ class GpuDnnConvGradI(DnnBase, COp):
if hasattr(self, 'workmem'):
self.algo = self.workmem
else:
self.algo = config.dnn.conv.algo_bwd
self.algo = config.dnn.conv.algo_bwd_data
if not hasattr(self, 'inplace'):
self.inplace = False
......@@ -1013,7 +1012,7 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
:param algo: ['none', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd`.
Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
"""
__props__ = ('algo', 'inplace',)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论