提交 3e9bf5bd authored 作者: Mathieu Germain's avatar Mathieu Germain

Made public list of available conv algo

上级 36c90b22
...@@ -287,6 +287,20 @@ def safe_no_dnn_algo_bwd(algo): ...@@ -287,6 +287,20 @@ def safe_no_dnn_algo_bwd(algo):
'`dnn.conv.algo_bwd_filter` and `dnn.conv.algo_bwd_data` instead.') '`dnn.conv.algo_bwd_filter` and `dnn.conv.algo_bwd_data` instead.')
return True return True
# Those are the supported algorithm by Theano,
# The tests will reference those lists.
SUPPORTED_DNN_CONV_ALGO_FWD = ('small', 'none', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change')
SUPPORTED_DNN_CONV_ALGO_BWD_DATA = ('none', 'deterministic', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change')
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER = ('none', 'deterministic', 'fft', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change')
AddConfigVar('dnn.conv.algo_bwd', AddConfigVar('dnn.conv.algo_bwd',
"This flag is deprecated; use dnn.conv.algo_bwd_data and " "This flag is deprecated; use dnn.conv.algo_bwd_data and "
"dnn.conv.algo_bwd_filter.", "dnn.conv.algo_bwd_filter.",
...@@ -296,26 +310,20 @@ AddConfigVar('dnn.conv.algo_bwd', ...@@ -296,26 +310,20 @@ AddConfigVar('dnn.conv.algo_bwd',
AddConfigVar('dnn.conv.algo_fwd', AddConfigVar('dnn.conv.algo_fwd',
"Default implementation to use for CuDNN forward convolution.", "Default implementation to use for CuDNN forward convolution.",
EnumStr('small', 'none', 'large', 'fft', 'fft_tiling', EnumStr(*SUPPORTED_DNN_CONV_ALGO_FWD),
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change'),
in_c_key=False) in_c_key=False)
AddConfigVar('dnn.conv.algo_bwd_data', AddConfigVar('dnn.conv.algo_bwd_data',
"Default implementation to use for CuDNN backward convolution to " "Default implementation to use for CuDNN backward convolution to "
"get the gradients of the convolution with regard to the inputs.", "get the gradients of the convolution with regard to the inputs.",
EnumStr('none', 'deterministic', 'fft', 'fft_tiling', EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_DATA),
'guess_once', 'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
in_c_key=False) in_c_key=False)
AddConfigVar('dnn.conv.algo_bwd_filter', AddConfigVar('dnn.conv.algo_bwd_filter',
"Default implementation to use for CuDNN backward convolution to " "Default implementation to use for CuDNN backward convolution to "
"get the gradients of the convolution with regard to the " "get the gradients of the convolution with regard to the "
"filters.", "filters.",
EnumStr('none', 'deterministic', 'fft', 'small', 'guess_once', EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_FILTER),
'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
in_c_key=False) in_c_key=False)
AddConfigVar('dnn.conv.precision', AddConfigVar('dnn.conv.precision',
......
...@@ -34,6 +34,8 @@ from .nnet import GpuSoftmax ...@@ -34,6 +34,8 @@ from .nnet import GpuSoftmax
from .opt import gpu_seqopt, register_opt, conv_groupopt, op_lifter from .opt import gpu_seqopt, register_opt, conv_groupopt, op_lifter
from .opt_util import alpha_merge, output_merge, inplace_allocempty from .opt_util import alpha_merge, output_merge, inplace_allocempty
from theano.configdefaults import SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
def raise_no_cudnn(msg="CuDNN is required for convolution and pooling"): def raise_no_cudnn(msg="CuDNN is required for convolution and pooling"):
raise RuntimeError(msg) raise RuntimeError(msg)
...@@ -583,9 +585,7 @@ class GpuDnnConvGradW(DnnBase): ...@@ -583,9 +585,7 @@ class GpuDnnConvGradW(DnnBase):
algo = config.dnn.conv.algo_bwd_filter algo = config.dnn.conv.algo_bwd_filter
self.algo = algo self.algo = algo
assert self.algo in ['none', 'deterministic', 'fft', 'small', assert self.algo in SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论