提交 c36f9a29 authored 作者: carriepl's avatar carriepl 提交者: Frederic

modify cuda/dnn.py to support new convolution implementations

上级 13bad511
...@@ -401,9 +401,8 @@ class GpuDnnConv(DnnBase, COp): ...@@ -401,9 +401,8 @@ class GpuDnnConv(DnnBase, COp):
workmem workmem
*deprecated*, use parameter algo instead. *deprecated*, use parameter algo instead.
algo algo
['none', 'small', 'large', 'fft', 'guess_once', ['none', 'small', 'large', 'fft', 'fft_tiling', 'guess_once',
'guess_on_shape_change', 'time_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change']
'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_fwd`. Default is the value of :attr:`config.dnn.conv.algo_fwd`.
...@@ -445,9 +444,9 @@ class GpuDnnConv(DnnBase, COp): ...@@ -445,9 +444,9 @@ class GpuDnnConv(DnnBase, COp):
raise RuntimeError("CuDNN convolution timing requires CuDNN " raise RuntimeError("CuDNN convolution timing requires CuDNN "
"v3") "v3")
assert self.algo in ['none', 'small', 'large', 'fft', 'guess_once', assert self.algo in ['none', 'small', 'large', 'fft', 'fft_tiling',
'guess_on_shape_change', 'time_once', 'guess_once', 'guess_on_shape_change',
'time_on_shape_change'] 'time_once', 'time_on_shape_change']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
...@@ -659,8 +658,8 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -659,8 +658,8 @@ class GpuDnnConvGradW(DnnBase, COp):
The convolution descriptor. The convolution descriptor.
workmem workmem
*deprecated*, use parameter algo instead. *deprecated*, use parameter algo instead.
algo : {'none', 'deterministic', 'fft', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change'} algo : {'none', 'deterministic', 'fft', 'small', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_bwd`. Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`.
""" """
...@@ -678,15 +677,15 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -678,15 +677,15 @@ class GpuDnnConvGradW(DnnBase, COp):
self.algo = workmem self.algo = workmem
else: else:
if algo is None: if algo is None:
algo = config.dnn.conv.algo_bwd algo = config.dnn.conv.algo_bwd_filter
self.algo = algo self.algo = algo
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [2]} self.destroy_map = {0: [2]}
assert self.algo in ['none', 'deterministic', 'fft', 'guess_once', assert self.algo in ['none', 'deterministic', 'fft', 'small',
'guess_on_shape_change', 'time_once', 'guess_once', 'guess_on_shape_change',
'time_on_shape_change'] 'time_once', 'time_on_shape_change']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
...@@ -694,7 +693,7 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -694,7 +693,7 @@ class GpuDnnConvGradW(DnnBase, COp):
if hasattr(self, 'workmem'): if hasattr(self, 'workmem'):
self.algo = self.workmem self.algo = self.workmem
else: else:
self.algo = config.dnn.conv.algo_bwd self.algo = config.dnn.conv.algo_bwd_filter
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
...@@ -737,7 +736,7 @@ class GpuDnnConvGradW(DnnBase, COp): ...@@ -737,7 +736,7 @@ class GpuDnnConvGradW(DnnBase, COp):
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1' alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
elif self.algo == 'fft': elif self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT' alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT'
elif self.algo == 'none2': elif self.algo == 'small':
# need v3, non-deterministic, small workspace # need v3, non-deterministic, small workspace
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3' alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3'
elif self.algo in ['guess_once', 'guess_on_shape_change']: elif self.algo in ['guess_once', 'guess_on_shape_change']:
...@@ -799,7 +798,7 @@ class GpuDnnConv3dGradW(GpuDnnConvGradW): ...@@ -799,7 +798,7 @@ class GpuDnnConv3dGradW(GpuDnnConvGradW):
:param workmem: :param workmem:
*deprecated*, use parameter algo instead. *deprecated*, use parameter algo instead.
:param algo: ['none', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change'] :param algo: ['none', 'guess_once', 'guess_on_shape_change', 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd`. Default is the value of :attr:`config.dnn.conv.algo_bwd_filter`.
""" """
__props__ = ('algo', 'inplace',) __props__ = ('algo', 'inplace',)
...@@ -867,11 +866,11 @@ class GpuDnnConvGradI(DnnBase, COp): ...@@ -867,11 +866,11 @@ class GpuDnnConvGradI(DnnBase, COp):
workmem workmem
*deprecated*, use parameter algo instead. *deprecated*, use parameter algo instead.
algo algo
['none', 'deterministic', 'fft', 'guess_once', ['none', 'deterministic', 'fft', 'fft_tiling', 'guess_once',
'guess_on_shape_change', 'time_once', 'guess_on_shape_change', 'time_once',
'time_on_shape_change'] 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd`. Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
""" """
...@@ -890,15 +889,15 @@ class GpuDnnConvGradI(DnnBase, COp): ...@@ -890,15 +889,15 @@ class GpuDnnConvGradI(DnnBase, COp):
self.algo = workmem self.algo = workmem
else: else:
if algo is None: if algo is None:
algo = config.dnn.conv.algo_bwd algo = config.dnn.conv.algo_bwd_data
self.algo = algo self.algo = algo
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [2]} self.destroy_map = {0: [2]}
assert self.algo in ['none', 'deterministic', 'fft', 'guess_once', assert self.algo in ['none', 'deterministic', 'fft', 'fft_tiling',
'guess_on_shape_change', 'time_once', 'guess_once', 'guess_on_shape_change',
'time_on_shape_change'] 'time_once', 'time_on_shape_change']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
...@@ -906,7 +905,7 @@ class GpuDnnConvGradI(DnnBase, COp): ...@@ -906,7 +905,7 @@ class GpuDnnConvGradI(DnnBase, COp):
if hasattr(self, 'workmem'): if hasattr(self, 'workmem'):
self.algo = self.workmem self.algo = self.workmem
else: else:
self.algo = config.dnn.conv.algo_bwd self.algo = config.dnn.conv.algo_bwd_data
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
...@@ -1013,7 +1012,7 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI): ...@@ -1013,7 +1012,7 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
:param algo: ['none', 'guess_once', 'guess_on_shape_change', :param algo: ['none', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change'] 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd`. Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
""" """
__props__ = ('algo', 'inplace',) __props__ = ('algo', 'inplace',)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论