提交 adc2387b authored 作者: Mathieu Germain's avatar Mathieu Germain

added winograd

上级 6523bbdf
...@@ -290,7 +290,7 @@ def safe_no_dnn_algo_bwd(algo): ...@@ -290,7 +290,7 @@ def safe_no_dnn_algo_bwd(algo):
# Those are the supported algorithm by Theano, # Those are the supported algorithm by Theano,
# The tests will reference those lists. # The tests will reference those lists.
SUPPORTED_DNN_CONV_ALGO_FWD = ('small', 'none', 'large', 'fft', 'fft_tiling', SUPPORTED_DNN_CONV_ALGO_FWD = ('small', 'none', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change', 'winograd', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change') 'time_once', 'time_on_shape_change')
SUPPORTED_DNN_CONV_ALGO_BWD_DATA = ('none', 'deterministic', 'fft', 'fft_tiling', SUPPORTED_DNN_CONV_ALGO_BWD_DATA = ('none', 'deterministic', 'fft', 'fft_tiling',
......
...@@ -402,7 +402,7 @@ class GpuDnnConv(DnnBase): ...@@ -402,7 +402,7 @@ class GpuDnnConv(DnnBase):
kernel kernel
descr descr
The convolution descriptor. The convolution descriptor.
algo : {'small', 'none', 'large', 'fft', 'fft_tiling', 'guess_once', algo : {'small', 'none', 'large', 'fft', 'fft_tiling', 'winograd', 'guess_once',
'guess_on_shape_change', 'time_once', 'time_on_shape_change'} 'guess_on_shape_change', 'time_once', 'time_on_shape_change'}
Default is the value of :attr:`config.dnn.conv.algo_fwd`. Default is the value of :attr:`config.dnn.conv.algo_fwd`.
...@@ -438,8 +438,12 @@ class GpuDnnConv(DnnBase): ...@@ -438,8 +438,12 @@ class GpuDnnConv(DnnBase):
raise RuntimeError("CuDNN tiled-FFT convolution requires " raise RuntimeError("CuDNN tiled-FFT convolution requires "
"CuDNN v4 or more recent") "CuDNN v4 or more recent")
if version() < 5000 and self.algo == 'winograd':
raise RuntimeError("CuDNN winograd convolution requires "
"CuDNN v5 or more recent")
assert self.algo in ['none', 'small', 'large', 'fft', 'fft_tiling', assert self.algo in ['none', 'small', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change', 'winograd', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change'] 'time_once', 'time_on_shape_change']
def __setstate__(self, d): def __setstate__(self, d):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论