提交 28f72bb5 authored 作者: --global's avatar --global

Update GpuDnnConv to use new config flags

上级 d740f88e
import os
import numpy
import warnings
import theano
from theano import Apply, gof, tensor, config, Variable
......@@ -405,21 +406,30 @@ class GpuDnnConv(DnnBase, COp):
:param kernel:
:param descr: the convolution descriptor
"""
__props__ = ('workmem', 'inplace')
__props__ = ('algo', 'inplace')
__input_name__ = ('image', 'kernel', 'output',
'descriptor', 'alpha', 'beta')
def __init__(self, workmem=None, inplace=False):
def __init__(self, workmem=None, inplace=False, algo=None):
"""
:param workmem: either 'none', 'small', 'large', 'fft', 'time',
'time_once', 'guess' or 'guess_once'. Default is the value of
:attr:`config.dnn.conv.workmem`.
:param workmem: *deprecated*, use param algo instead
:param algo: either 'small', 'none', 'large', 'fft', 'guess_once',
'guess_on_shape_change', 'time_once' or 'time_on_shape_change'.
Default is the value of :attr:`config.dnn.conv.algo_fwd`.
"""
COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_fwd.c"],
"APPLY_SPECIFIC(conv_fwd)")
if workmem is None:
workmem = config.dnn.conv.workmem
self.workmem = workmem
if workmem is not None:
warnings.warn(("GpuDnnConv: parameter 'workmem' is deprecated. "
"Use 'algo' instead."), stacklevel=3)
assert algo == None
self.algo = workmem
else:
if algo is None:
algo = config.dnn.conv.algo_fwd
self.algo = algo
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
......@@ -428,18 +438,21 @@ class GpuDnnConv(DnnBase, COp):
# option to time the different implementations to get the fastest
# are both unavailable.
if version() < (3000, 3000):
if self.workmem == 'fft':
if self.algo == 'fft':
raise RuntimeError("CuDNN FFT convolution requires CuDNN v3")
elif self.workmem in ['time', 'time_once']:
elif self.algo in ['time', 'time_once']:
raise RuntimeError("CuDNN convolution timing requires CuDNN v3")
assert self.workmem in ['none', 'small', 'large', 'fft', 'time',
assert self.algo in ['none', 'small', 'large', 'fft', 'time',
'time_once', 'guess', 'guess_once']
def __setstate__(self, d):
self.__dict__.update(d)
if not hasattr(self, 'workmem'):
self.workmem = 'none'
if not hasattr(self, 'algo'):
if hasattr(self, 'workmem'):
self.algo = self.workmem
else:
self.algo = 'none'
if not hasattr(self, 'inplace'):
self.inplace = False
......@@ -455,28 +468,28 @@ class GpuDnnConv(DnnBase, COp):
if version() == -1:
alg = "0"
else:
if self.workmem == 'none':
if self.algo == 'none':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'
elif self.workmem == 'small':
elif self.algo == 'small':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
elif self.workmem == 'large':
elif self.algo == 'large':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'
elif self.workmem == 'fft':
elif self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_FFT'
elif self.workmem in ['guess', 'guess_once']:
elif self.algo in ['guess', 'guess_once']:
# The convolution implementation should be chosen according
# to a heuristic
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
choose_alg = '1'
if self.workmem == 'guess_once':
if self.algo == 'guess_once':
choose_alg_once = '1'
elif self.workmem in ['time', 'time_once']:
elif self.algo in ['time', 'time_once']:
# The convolution implementation should be chosen by timing
# every available implementation
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
choose_alg = '1'
choose_alg_time = '1'
if self.workmem == 'time_once':
if self.algo == 'time_once':
choose_alg_once = '1'
alg_def = ('CONV_ALGO', alg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论