提交 6098369f，作者: Frederic Bastien

Update the 3D convolution algorithms available in cuDNN v5

上级 58e93f9b
......@@ -461,9 +461,9 @@ class GpuDnnConv(DnnBase):
defs.append(('CONV_INPLACE', '1'))
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
if self.algo == 'none':
if self.algo == 'none': # 3d (at least in v4)
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'
elif self.algo == 'small':
elif self.algo == 'small': # 3d (at least in v4)
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
elif self.algo == 'large':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'
......@@ -471,7 +471,7 @@ class GpuDnnConv(DnnBase):
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_DIRECT'
elif self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_FFT'
elif self.algo == 'fft_tiling':
elif self.algo == 'fft_tiling': # 3d (not in v4, in v5)
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING'
elif self.algo == 'winograd':
# need v5
......@@ -505,10 +505,13 @@ class GpuDnnConv(DnnBase):
raise TypeError("The number of dimensions of "
"img, kern and output must match")
if (img.type.ndim == 5 and
self.algo in ['small', 'large', 'fft', 'fft_tiling']):
if img.type.ndim == 5 and self.algo in ['large', 'fft']:
raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,))
if (img.type.ndim == 5 and
self.algo in ['fft_tiling'] and
version() < 5000):
raise ValueError("3d convolution algo fft_tiling need cudnn v5")
if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnConvolutionDescriptor_t'):
......@@ -634,13 +637,13 @@ class GpuDnnConvGradW(DnnBase):
defs.append(('CONV_INPLACE', '1'))
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0'
if self.algo == 'none':
if self.algo == 'none': # 3d in at least v4
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0'
if self.algo == 'deterministic':
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
if self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT'
if self.algo == 'small':
if self.algo == 'small': # 3d in at least v4
# non-deterministic, small workspace
alg = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3'
if self.algo in ['guess_once', 'guess_on_shape_change',
......@@ -673,7 +676,7 @@ class GpuDnnConvGradW(DnnBase):
"img, topgrad and output must match")
if (img.type.ndim == 5 and
self.algo in ['fft', 'deterministic', 'small']):
self.algo in ['fft', 'deterministic']):
raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,))
......@@ -766,13 +769,13 @@ class GpuDnnConvGradI(DnnBase):
defs.append(('CONV_INPLACE', '1'))
alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0'
if self.algo == 'none':
if self.algo == 'none': # 3d at least v4
alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_0'
elif self.algo == 'deterministic':
elif self.algo == 'deterministic': # 3d at least v4
alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_1'
elif self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT'
elif self.algo == 'fft_tiling':
elif self.algo == 'fft_tiling': # 3d not v4, since v5
# big workspace but less than fft
alg = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING'
elif self.algo == 'winograd':
......@@ -808,10 +811,13 @@ class GpuDnnConvGradI(DnnBase):
raise TypeError("The number of dimensions of "
"kern, topgrad and output must match")
if (kern.type.ndim == 5 and
self.algo in ['fft', 'deterministic', 'fft_tiling']):
if kern.type.ndim == 5 and self.algo in ['fft']:
raise ValueError("convolution algo %s can't be used for "
"3d convolutions", (self.algo,))
if (kern.type.ndim == 5 and
self.algo == 'fft_tiling' and
version() < 5000):
raise ValueError("3d convolution algo fft_tiling need cudnn v5")
if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnConvolutionDescriptor_t'):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论