提交 c2703ca1 authored 作者: --global's avatar --global

Update CuDNN conv3d tests conditions for skipping

上级 9b757b1d
...@@ -592,8 +592,8 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -592,8 +592,8 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_conv3d(self): def test_conv3d(self):
if not dnn.dnn_available(): if not (cuda.dnn.dnn_available() and dnn.version() >= (2000, 2000)):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest('"CuDNN 3D convolution requires CuDNN v2')
ftensor5 = T.TensorType(dtype="float32", broadcastable=(False,) * 5) ftensor5 = T.TensorType(dtype="float32", broadcastable=(False,) * 5)
img = ftensor5('img') img = ftensor5('img')
kerns = ftensor5('kerns') kerns = ftensor5('kerns')
...@@ -680,8 +680,8 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -680,8 +680,8 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_conv3d_gradw(self): def test_conv3d_gradw(self):
if not dnn.dnn_available(): if not (cuda.dnn.dnn_available() and dnn.version() >= (2000, 2000)):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest('"CuDNN 3D convolution requires CuDNN v2')
ftensor5 = T.TensorType(dtype="float32", broadcastable=(False,) * 5) ftensor5 = T.TensorType(dtype="float32", broadcastable=(False,) * 5)
img = ftensor5('img') img = ftensor5('img')
kerns = ftensor5('kerns') kerns = ftensor5('kerns')
...@@ -725,8 +725,8 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -725,8 +725,8 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_conv_gradi(self): def test_conv_gradi(self):
if not dnn.dnn_available(): if not (cuda.dnn.dnn_available() and dnn.version() >= (2000, 2000)):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest('"CuDNN 3D convolution requires CuDNN v2')
img = T.ftensor4('img') img = T.ftensor4('img')
kerns = T.ftensor4('kerns') kerns = T.ftensor4('kerns')
out = T.ftensor4('out') out = T.ftensor4('out')
...@@ -1089,8 +1089,8 @@ def get_conv3d_test_cases(): ...@@ -1089,8 +1089,8 @@ def get_conv3d_test_cases():
def test_conv3d_fwd(): def test_conv3d_fwd():
if not cuda.dnn.dnn_available() and dnn.version()[0] >= 3000: if not (cuda.dnn.dnn_available() and dnn.version() >= (2000, 2000)):
raise SkipTest('"CuDNN 3D convolution requires CuDNN v3') raise SkipTest('"CuDNN 3D convolution requires CuDNN v2')
def run_conv3d_fwd(inputs_shape, filters_shape, subsample, def run_conv3d_fwd(inputs_shape, filters_shape, subsample,
border_mode, conv_mode): border_mode, conv_mode):
...@@ -1153,8 +1153,8 @@ def test_conv3d_fwd(): ...@@ -1153,8 +1153,8 @@ def test_conv3d_fwd():
def test_conv3d_bwd(): def test_conv3d_bwd():
if not cuda.dnn.dnn_available() and dnn.version()[0] >= 3000: if not (cuda.dnn.dnn_available() and dnn.version() >= (2000, 2000)):
raise SkipTest('"CuDNN 3D convolution requires CuDNN v3') raise SkipTest('"CuDNN 3D convolution requires CuDNN v2')
def run_conv3d_bwd(inputs_shape, filters_shape, subsample, def run_conv3d_bwd(inputs_shape, filters_shape, subsample,
border_mode, conv_mode): border_mode, conv_mode):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论