提交 1e48b734 authored 作者: carriepl's avatar carriepl 提交者: Frederic

Update GpuDnnConvGradW for CuDNN v4 (gpua backend)

上级 64439f41
...@@ -541,18 +541,25 @@ class GpuDnnConvGradW(DnnBase): ...@@ -541,18 +541,25 @@ class GpuDnnConvGradW(DnnBase):
if self.inplace: if self.inplace:
self.destroy_map = {0: [2]} self.destroy_map = {0: [2]}
if algo is None: if algo is None:
algo = config.dnn.conv.algo_bwd algo = config.dnn.conv.algo_bwd_filter
self.algo = algo self.algo = algo
assert self.algo in ['none', 'deterministic', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once', # The small-workspace implementation is only available from CuDNN V4
'time_on_shape_change'] # onward.
if version() < 4000 and self.algo == 'small':
raise RuntimeError("CuDNN's small workspace GradW convolution "
"requires CuDNN v4 or more recent.")
assert self.algo in ['none', 'deterministic', 'fft', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
if not hasattr(self, 'algo'): if not hasattr(self, 'algo'):
self.algo = config.dnn.conv.algo_bwd self.algo = config.dnn.conv.algo_bwd_filter
def grad(self, inp, grads): def grad(self, inp, grads):
img, top, output, desc, alpha, beta = inp img, top, output, desc, alpha, beta = inp
...@@ -587,7 +594,9 @@ class GpuDnnConvGradW(DnnBase): ...@@ -587,7 +594,9 @@ class GpuDnnConvGradW(DnnBase):
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1' alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
if self.algo == 'fft': if self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT' alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT'
if self.algo == 'small':
# non-deterministic, small workspace
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3'
if self.algo in ['guess_once', 'guess_on_shape_change', if self.algo in ['guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']: 'time_once', 'time_on_shape_change']:
defs.append(('CHOOSE_ALGO', '')) defs.append(('CHOOSE_ALGO', ''))
...@@ -617,7 +626,8 @@ class GpuDnnConvGradW(DnnBase): ...@@ -617,7 +626,8 @@ class GpuDnnConvGradW(DnnBase):
raise TypeError("The number of dimensions of " raise TypeError("The number of dimensions of "
"img, topgrad and output must match") "img, topgrad and output must match")
if img.type.ndim == 5 and self.algo in ['fft', 'deterministic']: if (img.type.ndim == 5 and
self.algo in ['fft', 'deterministic', 'small']):
raise ValueError("convolution algo %s can't be used for " raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,)) "3d convolutions", (self.algo,))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论