提交 3e9bf5bd authored 作者: Mathieu Germain's avatar Mathieu Germain

Made public list of available conv algo

上级 36c90b22
......@@ -287,6 +287,20 @@ def safe_no_dnn_algo_bwd(algo):
'`dnn.conv.algo_bwd_filter` and `dnn.conv.algo_bwd_data` instead.')
return True
# Those are the supported algorithm by Theano,
# The tests will reference those lists.
SUPPORTED_DNN_CONV_ALGO_FWD = ('small', 'none', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change')
SUPPORTED_DNN_CONV_ALGO_BWD_DATA = ('none', 'deterministic', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change')
SUPPORTED_DNN_CONV_ALGO_BWD_FILTER = ('none', 'deterministic', 'fft', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change')
AddConfigVar('dnn.conv.algo_bwd',
"This flag is deprecated; use dnn.conv.algo_bwd_data and "
"dnn.conv.algo_bwd_filter.",
......@@ -296,26 +310,20 @@ AddConfigVar('dnn.conv.algo_bwd',
AddConfigVar('dnn.conv.algo_fwd',
"Default implementation to use for CuDNN forward convolution.",
EnumStr('small', 'none', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change'),
EnumStr(*SUPPORTED_DNN_CONV_ALGO_FWD),
in_c_key=False)
AddConfigVar('dnn.conv.algo_bwd_data',
"Default implementation to use for CuDNN backward convolution to "
"get the gradients of the convolution with regard to the inputs.",
EnumStr('none', 'deterministic', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_DATA),
in_c_key=False)
AddConfigVar('dnn.conv.algo_bwd_filter',
"Default implementation to use for CuDNN backward convolution to "
"get the gradients of the convolution with regard to the "
"filters.",
EnumStr('none', 'deterministic', 'fft', 'small', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change'),
EnumStr(*SUPPORTED_DNN_CONV_ALGO_BWD_FILTER),
in_c_key=False)
AddConfigVar('dnn.conv.precision',
......
......@@ -34,6 +34,8 @@ from .nnet import GpuSoftmax
from .opt import gpu_seqopt, register_opt, conv_groupopt, op_lifter
from .opt_util import alpha_merge, output_merge, inplace_allocempty
from theano.configdefaults import SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
def raise_no_cudnn(msg="CuDNN is required for convolution and pooling"):
raise RuntimeError(msg)
......@@ -583,9 +585,7 @@ class GpuDnnConvGradW(DnnBase):
algo = config.dnn.conv.algo_bwd_filter
self.algo = algo
assert self.algo in ['none', 'deterministic', 'fft', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
assert self.algo in SUPPORTED_DNN_CONV_ALGO_BWD_FILTER
def __setstate__(self, d):
self.__dict__.update(d)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论