提交 28f72bb5 authored 作者: --global's avatar --global

Update GpuDnnConv to use new config flags

上级 d740f88e
import os import os
import numpy import numpy
import warnings
import theano import theano
from theano import Apply, gof, tensor, config, Variable from theano import Apply, gof, tensor, config, Variable
...@@ -405,21 +406,30 @@ class GpuDnnConv(DnnBase, COp): ...@@ -405,21 +406,30 @@ class GpuDnnConv(DnnBase, COp):
:param kernel: :param kernel:
:param descr: the convolution descriptor :param descr: the convolution descriptor
""" """
__props__ = ('workmem', 'inplace') __props__ = ('algo', 'inplace')
__input_name__ = ('image', 'kernel', 'output', __input_name__ = ('image', 'kernel', 'output',
'descriptor', 'alpha', 'beta') 'descriptor', 'alpha', 'beta')
def __init__(self, workmem=None, inplace=False): def __init__(self, workmem=None, inplace=False, algo=None):
""" """
:param workmem: either 'none', 'small', 'large', 'fft', 'time', :param workmem: *deprecated*, use param algo instead
'time_once', 'guess' or 'guess_once'. Default is the value of :param algo: either 'small', 'none', 'large', 'fft', 'guess_once',
:attr:`config.dnn.conv.workmem`. 'guess_on_shape_change', 'time_once' or 'time_on_shape_change'.
Default is the value of :attr:`config.dnn.conv.algo_fwd`.
""" """
COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_fwd.c"], COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_fwd.c"],
"APPLY_SPECIFIC(conv_fwd)") "APPLY_SPECIFIC(conv_fwd)")
if workmem is None:
workmem = config.dnn.conv.workmem if workmem is not None:
self.workmem = workmem warnings.warn(("GpuDnnConv: parameter 'workmem' is deprecated. "
"Use 'algo' instead."), stacklevel=3)
assert algo == None
self.algo = workmem
else:
if algo is None:
algo = config.dnn.conv.algo_fwd
self.algo = algo
self.inplace = inplace self.inplace = inplace
if self.inplace: if self.inplace:
self.destroy_map = {0: [2]} self.destroy_map = {0: [2]}
...@@ -428,18 +438,21 @@ class GpuDnnConv(DnnBase, COp): ...@@ -428,18 +438,21 @@ class GpuDnnConv(DnnBase, COp):
# option to time the different implementations to get the fastest # option to time the different implementations to get the fastest
# are both unavailable. # are both unavailable.
if version() < (3000, 3000): if version() < (3000, 3000):
if self.workmem == 'fft': if self.algo == 'fft':
raise RuntimeError("CuDNN FFT convolution requires CuDNN v3") raise RuntimeError("CuDNN FFT convolution requires CuDNN v3")
elif self.workmem in ['time', 'time_once']: elif self.algo in ['time', 'time_once']:
raise RuntimeError("CuDNN convolution timing requires CuDNN v3") raise RuntimeError("CuDNN convolution timing requires CuDNN v3")
assert self.workmem in ['none', 'small', 'large', 'fft', 'time', assert self.algo in ['none', 'small', 'large', 'fft', 'time',
'time_once', 'guess', 'guess_once'] 'time_once', 'guess', 'guess_once']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
if not hasattr(self, 'workmem'): if not hasattr(self, 'algo'):
self.workmem = 'none' if hasattr(self, 'workmem'):
self.algo = self.workmem
else:
self.algo = 'none'
if not hasattr(self, 'inplace'): if not hasattr(self, 'inplace'):
self.inplace = False self.inplace = False
...@@ -455,28 +468,28 @@ class GpuDnnConv(DnnBase, COp): ...@@ -455,28 +468,28 @@ class GpuDnnConv(DnnBase, COp):
if version() == -1: if version() == -1:
alg = "0" alg = "0"
else: else:
if self.workmem == 'none': if self.algo == 'none':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM'
elif self.workmem == 'small': elif self.algo == 'small':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
elif self.workmem == 'large': elif self.algo == 'large':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_GEMM'
elif self.workmem == 'fft': elif self.algo == 'fft':
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_FFT' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_FFT'
elif self.workmem in ['guess', 'guess_once']: elif self.algo in ['guess', 'guess_once']:
# The convolution implementation should be chosen according # The convolution implementation should be chosen according
# to a heuristic # to a heuristic
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
choose_alg = '1' choose_alg = '1'
if self.workmem == 'guess_once': if self.algo == 'guess_once':
choose_alg_once = '1' choose_alg_once = '1'
elif self.workmem in ['time', 'time_once']: elif self.algo in ['time', 'time_once']:
# The convolution implementation should be chosen by timing # The convolution implementation should be chosen by timing
# every available implementation # every available implementation
alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM' alg = 'CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM'
choose_alg = '1' choose_alg = '1'
choose_alg_time = '1' choose_alg_time = '1'
if self.workmem == 'time_once': if self.algo == 'time_once':
choose_alg_once = '1' choose_alg_once = '1'
alg_def = ('CONV_ALGO', alg) alg_def = ('CONV_ALGO', alg)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论