提交 3938ba99 authored 作者: Gabe Schwartz's avatar Gabe Schwartz

Only run dnnconv tests with dilation if v>=6.

上级 e97223be
......@@ -384,6 +384,9 @@ class GpuDnnConvDesc(COp):
precision="float32"):
COp.__init__(self, ["conv_desc.c"], "APPLY_SPECIFIC(conv_desc)")
if version() < 6000 and any([d != 1 for d in dilation]):
raise RuntimeError("Dilation > 1 not supported for cuDNN version < 6.")
if isinstance(border_mode, integer_types):
border_mode = (border_mode,) * len(subsample)
if isinstance(border_mode, tuple):
......@@ -455,11 +458,6 @@ class GpuDnnConvDesc(COp):
else:
sub2 = '0'
if version() < 6000:
dil0 = '1'
dil1 = '1'
dil2 = '1'
else:
dil0 = str(self.dilation[0])
dil1 = str(self.dilation[1])
if len(self.dilation) > 2:
......
......@@ -641,7 +641,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
def test_conv(self, algo, border_mode, conv_mode):
# Currently only CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM (algo 'none')
# supports dilation > 1.
dilations = [(1, 1), (2, 2)] if algo == "none" else [(1, 1)]
dilations = [(1, 1), (2, 2)] if (algo == "none" and dnn.version() >= 6000) else [(1, 1)]
self._test_conv(T.tensor4('img'),
T.tensor4('kerns'),
......@@ -712,6 +712,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
@parameterized.expand(product(border_modes, conv_modes), utt.custom_name_func)
def test_conv_gradw(self, border_mode, conv_mode):
dilations = [(1, 1), (2, 2)] if dnn.version() >= 6000 else [(1, 1)]
self._test_conv_gradw(T.tensor4('img'),
T.tensor4('topgrad'),
T.tensor4('kerns'),
......@@ -720,7 +721,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
border_mode,
conv_mode,
[(1, 1)],
[(1, 1), (2, 2)])
dilations)
def test_conv_gradi(self):
if not dnn.dnn_available(test_ctx_name):
......@@ -737,10 +738,11 @@ class TestDnnInferShapes(utt.InferShapeTester):
dtype=theano.config.floatX
)
dilations = [(1, 1), (2, 2)] if dnn.version() >= 6000 else [(1, 1)]
for border_mode, subsample, dilation, conv_mode in product(
['valid', 'full'],
[(1, 1)],
[(1, 1), (2, 2)],
dilations,
['conv', 'cross']
):
shape = get_conv_gradinputs_shape(kern_vals.shape, out_vals.shape, border_mode, subsample, dilation)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论