提交 cf26ae5c authored 作者: carriepl's avatar carriepl 提交者: Frederic

Limit new convolution implementations to CuDNN V4

上级 2a84aa6b
......@@ -460,6 +460,12 @@ class GpuDnnConv(DnnBase, COp):
raise RuntimeError("CuDNN convolution timing requires CuDNN "
"v3")
# The fft_tiling implementation is only available from CuDNN V4 onward
if version() < (4000, 4000):
if self.algo == 'fft_tiling':
raise RuntimeError("CuDNN tiled-FFT convolution requires "
"CuDNN v4 or more recent")
assert self.algo in ['none', 'small', 'large', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
......@@ -699,6 +705,14 @@ class GpuDnnConvGradW(DnnBase, COp):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
# The small-workspace implementation is only available from CuDNN V4
# onward.
if version() < (4000, 4000):
if self.algo == 'small':
raise RuntimeError("CuDNN's small workspace GradW convolution "
"requires CuDNN v4 or more recent.")
assert self.algo in ['none', 'deterministic', 'fft', 'small',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
......@@ -911,6 +925,14 @@ class GpuDnnConvGradI(DnnBase, COp):
self.inplace = inplace
if self.inplace:
self.destroy_map = {0: [2]}
# The small-workspace implementation is only available from CuDNN V4
# onward.
if version() < (4000, 4000):
if self.algo == 'fft_tiling':
raise RuntimeError("CuDNN's tiled-FFT convolution requires "
"CuDNN v4 or more recent")
assert self.algo in ['none', 'deterministic', 'fft', 'fft_tiling',
'guess_once', 'guess_on_shape_change',
'time_once', 'time_on_shape_change']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论