提交 5a116137 authored 作者: Frederic Bastien's avatar Frederic Bastien

Again skip tests when cudnn isn't there

上级 2bdfda4a
......@@ -584,6 +584,8 @@ class TestDnnInferShapes(utt.InferShapeTester):
conv_modes = ['conv', 'cross']
def setUp(self):
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
super(TestDnnInferShapes, self).setUp()
self.mode = mode_with_gpu
......@@ -1032,6 +1034,8 @@ def test_dnn_conv_grad():
def get_conv3d_test_cases():
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
# Every element of test_shapes follows the format
# [input_shape, filter_shape, subsample, dilation]
test_shapes = [[(128, 3, 5, 5, 5), (64, 3, 1, 2, 4), (1, 1, 1), (1, 1, 1)],
......@@ -1117,14 +1121,18 @@ def run_conv_small_batched_vs_multicall(inputs_shape, filters_shape, batch_sub):
def test_batched_conv_small():
# OK
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
yield (run_conv_small_batched_vs_multicall, (65536, 2, 2, 2), (1, 2, 2, 2), 5)
# Should fail with cuDNN < V6020, but there's currently a workaround in `dnn_fwd.c` for that case.
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2), (1, 2, 2, 2), 5)
def test_batched_conv3d_small():
# OK
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
yield (run_conv_small_batched_vs_multicall, (65536, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5)
# Should fail with cuDNN < V6020, but there's currently a workaround in `dnn_fwd.c` for that case.
yield (run_conv_small_batched_vs_multicall, (65537, 2, 2, 2, 2), (1, 2, 2, 2, 2), 5)
......@@ -1529,11 +1537,15 @@ def dnn_reduction_strides(shp, shuffle, slice):
def test_dnn_reduction_strides():
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
yield dnn_reduction_strides, (2, 3, 2), (1, 0, 2), slice(None, None, None)
yield dnn_reduction_strides, (2, 3, 2), (0, 1, 2), slice(None, None, -1)
def test_dnn_reduction_error():
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
nLoops = 5
vec = np.arange(0, 10, dtype=np.float32)
slow_output = np.zeros((5, 10))
......@@ -2708,6 +2720,8 @@ class TestDnnConv3DRuntimeAlgorithms(TestDnnConv2DRuntimeAlgorithms):
def test_conv_guess_once_with_dtypes():
# This test checks that runtime conv algorithm selection does not raise any exception
# when consecutive functions with different dtypes and precisions are executed.
if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg)
utt.seed_rng()
inputs_shape = (2, 3, 5, 5)
filters_shape = (2, 3, 40, 4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论