提交 d1ca9a09 authored 作者: Frederic Bastien's avatar Frederic Bastien

Correctly skip test when cudnn isn't available.

上级 56bd5b80
...@@ -170,7 +170,7 @@ def test_pooling(): ...@@ -170,7 +170,7 @@ def test_pooling():
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
# 'average_exc_pad' is disabled for versions < 4004 # 'average_exc_pad' is disabled for versions < 4004
if dnn.version() < 4004: if dnn.version(False) < 4004:
modes = ('max', 'average_inc_pad') modes = ('max', 'average_inc_pad')
else: else:
modes = ('max', 'average_inc_pad', 'average_exc_pad') modes = ('max', 'average_inc_pad', 'average_exc_pad')
...@@ -467,7 +467,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -467,7 +467,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
[conv_modes[0]])), [conv_modes[0]])),
testcase_func_name=utt.custom_name_func) testcase_func_name=utt.custom_name_func)
def test_conv(self, algo, border_mode, conv_mode): def test_conv(self, algo, border_mode, conv_mode):
if algo == 'winograd' and dnn.version() < 5000: if algo == 'winograd' and dnn.version(False) < 5000:
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
self._test_conv(T.ftensor4('img'), self._test_conv(T.ftensor4('img'),
...@@ -600,7 +600,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -600,7 +600,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
# 'average_exc_pad' is disabled for versions < 4004 # 'average_exc_pad' is disabled for versions < 4004
if dnn.version() < 4004: if dnn.version(False) < 4004:
modes = ['max', 'average_inc_pad'] modes = ['max', 'average_inc_pad']
else: else:
modes = ['max', 'average_inc_pad', 'average_exc_pad'] modes = ['max', 'average_inc_pad', 'average_exc_pad']
...@@ -906,7 +906,7 @@ class test_SoftMax(test_nnet.test_SoftMax): ...@@ -906,7 +906,7 @@ class test_SoftMax(test_nnet.test_SoftMax):
def test_log_softmax(self): def test_log_softmax(self):
# This is a test for an optimization that depends on CuDNN v3 or # This is a test for an optimization that depends on CuDNN v3 or
# more recent. Don't test if the CuDNN version is too old. # more recent. Don't test if the CuDNN version is too old.
if dnn.version() < 3000: if dnn.version(False) < 3000:
raise SkipTest("Log-softmax is only in cudnn v3+") raise SkipTest("Log-softmax is only in cudnn v3+")
x = T.ftensor4() x = T.ftensor4()
...@@ -947,7 +947,7 @@ class test_SoftMax(test_nnet.test_SoftMax): ...@@ -947,7 +947,7 @@ class test_SoftMax(test_nnet.test_SoftMax):
# This is a test for an optimization that depends on CuDNN v3 or # This is a test for an optimization that depends on CuDNN v3 or
# more recent. Don't test if the CuDNN version is too old. # more recent. Don't test if the CuDNN version is too old.
if dnn.version() < 3000: if dnn.version(False) < 3000:
raise SkipTest("Log-softmax is only in cudnn v3+") raise SkipTest("Log-softmax is only in cudnn v3+")
# Compile a reference function, on the CPU, to be used to validate the # Compile a reference function, on the CPU, to be used to validate the
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论