提交 dc2277cc authored 作者: Mathieu Germain's avatar Mathieu Germain

added winograd to old backend

上级 bec034eb
......@@ -296,7 +296,7 @@ class GpuDnnConv(DnnBase, COp):
workmem
*deprecated*, use parameter algo instead.
algo
['none', 'small', 'large', 'fft', 'fft_tiling', 'guess_once',
['none', 'small', 'large', 'fft', 'fft_tiling', 'guess_once', 'winograd',
'guess_on_shape_change', 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_fwd`.
......@@ -345,8 +345,13 @@ class GpuDnnConv(DnnBase, COp):
raise RuntimeError("CuDNN tiled-FFT convolution requires "
"CuDNN v4 or more recent")
if version() < (5000, 5000):
if self.algo == 'winograd':
raise RuntimeError("CuDNN winograd convolution requires "
"CuDNN v5 or more recent")
assert self.algo in ['none', 'small', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'winograd', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
def __setstate__(self, d):
......@@ -481,7 +486,7 @@ class GpuDnnConv3d(GpuDnnConv):
:param descr: the convolution descriptor
:param workmem:
*deprecated*, use parameter algo instead.
:param algo: ['none', 'small', 'fft_tiling',
:param algo: ['none', 'small', 'fft_tiling', 'winograd',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_fwd`.
......@@ -497,7 +502,8 @@ class GpuDnnConv3d(GpuDnnConv):
"Use 'algo' instead."), stacklevel=3)
assert algo is None
algo = workmem
good_algo = ['none', 'small', 'fft_tiling',
good_algo = ['none', 'small', 'fft_tiling', 'winograd',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
if algo is None and config.dnn.conv.algo_fwd not in good_algo:
......@@ -505,11 +511,16 @@ class GpuDnnConv3d(GpuDnnConv):
elif algo is not None and algo not in good_algo:
algo = 'guess_once'
super(GpuDnnConv3d, self).__init__(inplace=inplace, algo=algo)
assert self.algo in good_algo
if version() < (5000, 5000):
if self.algo == 'fft_tiling':
raise RuntimeError("CuDNN 3d tiled-FFT convolution requires "
"CuDNN v5 or more recent")
elif self.algo == 'winograd':
raise RuntimeError("CuDNN 3d winograd convolution requires "
"CuDNN v5 or more recent")
def make_node(self, img, kern, output, desc, alpha=None, beta=None):
......@@ -791,9 +802,8 @@ class GpuDnnConvGradI(DnnBase, COp):
workmem
*deprecated*, use parameter algo instead.
algo
['none', 'deterministic', 'fft', 'fft_tiling', 'guess_once',
'guess_on_shape_change', 'time_once',
'time_on_shape_change']
['none', 'deterministic', 'fft', 'fft_tiling', 'winograd', 'guess_once',
'guess_on_shape_change', 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
......@@ -828,8 +838,13 @@ class GpuDnnConvGradI(DnnBase, COp):
raise RuntimeError("CuDNN's tiled-FFT convolution requires "
"CuDNN v4 or more recent")
if version() < (5000, 5000):
if self.algo == 'winograd':
raise RuntimeError("CuDNN's winograd convolution requires "
"CuDNN v5 or more recent")
assert self.algo in ['none', 'deterministic', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'winograd', 'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
def __setstate__(self, d):
......@@ -946,9 +961,8 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
:param descr: the convolution descriptor
:param workmem:
*deprecated*, use parameter algo instead.
:param algo: ['none', 'deterministic, 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
:param algo: ['none', 'deterministic, 'fft_tiling', 'winograd', 'guess_once',
'guess_on_shape_change', 'time_once', 'time_on_shape_change']
Default is the value of :attr:`config.dnn.conv.algo_bwd_data`.
......@@ -963,9 +977,11 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
"deprecated. Use 'algo' instead."), stacklevel=3)
assert algo is None
algo = workmem
good_algo = ['none', 'deterministic', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
good_algo = ['none', 'deterministic', 'fft_tiling', 'winograd',
'guess_once', 'guess_on_shape_change', 'time_once',
'time_on_shape_change']
if algo is None and config.dnn.conv.algo_bwd_data not in good_algo:
algo = 'guess_once'
elif algo is not None and algo not in good_algo:
......@@ -977,6 +993,9 @@ class GpuDnnConv3dGradI(GpuDnnConvGradI):
if self.algo == 'fft_tiling':
raise RuntimeError("CuDNN 3d tiled-FFT convolution requires "
"CuDNN v5 or more recent")
elif self.algo == 'winograd':
raise RuntimeError("CuDNN 3d winograd convolution requires "
"CuDNN v5 or more recent")
def grad(self, inp, grads):
kerns, top, output, desc, alpha, beta = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论