提交 8929a515 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix test_dnn.py and make it pass.

上级 2bd44678
...@@ -15,12 +15,12 @@ from theano.tensor.signal.downsample import MaxPoolGrad, AveragePoolGrad ...@@ -15,12 +15,12 @@ from theano.tensor.signal.downsample import MaxPoolGrad, AveragePoolGrad
from .. import dnn from .. import dnn
from ..basic_ops import GpuAllocEmpty from ..basic_ops import GpuAllocEmpty
from .config import mode_with_gpu, mode_without_gpu from .config import mode_with_gpu, mode_without_gpu, test_ctx_name
from . import test_nnet from . import test_nnet
def test_dnn_conv_desc_merge(): def test_dnn_conv_desc_merge():
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
kern_shp = T.as_tensor_variable( kern_shp = T.as_tensor_variable(
numpy.asarray([3, 1, 2, 2]).astype('int64')) numpy.asarray([3, 1, 2, 2]).astype('int64'))
...@@ -41,7 +41,7 @@ def test_dnn_conv_desc_merge(): ...@@ -41,7 +41,7 @@ def test_dnn_conv_desc_merge():
def test_dnn_conv_merge(): def test_dnn_conv_merge():
# This test that we merge correctly multiple dnn_conv. # This test that we merge correctly multiple dnn_conv.
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img_shp = [2, 5, 6, 8] img_shp = [2, 5, 6, 8]
kern_shp = [3, 5, 5, 6] kern_shp = [3, 5, 5, 6]
...@@ -80,7 +80,7 @@ def test_dnn_conv_inplace(): ...@@ -80,7 +80,7 @@ def test_dnn_conv_inplace():
GpuAllocEmpty get merged together. GpuAllocEmpty get merged together.
""" """
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img_shp = [2, 5, 6, 8] img_shp = [2, 5, 6, 8]
kern_shp = [3, 5, 5, 6] kern_shp = [3, 5, 5, 6]
...@@ -105,7 +105,7 @@ def test_dnn_conv_inplace(): ...@@ -105,7 +105,7 @@ def test_dnn_conv_inplace():
assert len([n for n in topo if isinstance(n.op, GpuAllocEmpty)]) == 2 assert len([n for n in topo if isinstance(n.op, GpuAllocEmpty)]) == 2
# Test grad w op # Test grad w op
out = GpuAllocEmpty(kern.dtype)(*kern.shape) out = GpuAllocEmpty(kern.dtype, test_ctx_name)(*kern.shape)
o1 = dnn.GpuDnnConvGradW()(img, kern, out, desc1) o1 = dnn.GpuDnnConvGradW()(img, kern, out, desc1)
o2 = dnn.GpuDnnConvGradW()(img, kern, out, desc2) o2 = dnn.GpuDnnConvGradW()(img, kern, out, desc2)
f = theano.function([img, kern], [o1, o2], mode=mode_with_gpu) f = theano.function([img, kern], [o1, o2], mode=mode_with_gpu)
...@@ -116,7 +116,7 @@ def test_dnn_conv_inplace(): ...@@ -116,7 +116,7 @@ def test_dnn_conv_inplace():
assert len([n for n in topo if isinstance(n.op, GpuAllocEmpty)]) == 2 assert len([n for n in topo if isinstance(n.op, GpuAllocEmpty)]) == 2
# Test grad i op # Test grad i op
out = GpuAllocEmpty(img.dtype)(*img.shape) out = GpuAllocEmpty(img.dtype, test_ctx_name)(*img.shape)
o1 = dnn.GpuDnnConvGradI()(img, kern, out, desc1) o1 = dnn.GpuDnnConvGradI()(img, kern, out, desc1)
o2 = dnn.GpuDnnConvGradI()(img, kern, out, desc2) o2 = dnn.GpuDnnConvGradI()(img, kern, out, desc2)
f = theano.function([img, kern], [o1, o2], mode=mode_with_gpu) f = theano.function([img, kern], [o1, o2], mode=mode_with_gpu)
...@@ -163,7 +163,7 @@ def pool_2d_i2n(input, ds=(2, 2), strides=None, ...@@ -163,7 +163,7 @@ def pool_2d_i2n(input, ds=(2, 2), strides=None,
def test_pooling(): def test_pooling():
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
x = T.ftensor4() x = T.ftensor4()
...@@ -269,7 +269,7 @@ def test_pooling(): ...@@ -269,7 +269,7 @@ def test_pooling():
def test_pooling_opt(): def test_pooling_opt():
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
x = T.fmatrix() x = T.fmatrix()
...@@ -318,7 +318,7 @@ def test_dnn_tag(): ...@@ -318,7 +318,7 @@ def test_dnn_tag():
max_pool_2d(x, ds=(2, 2), ignore_border=True), max_pool_2d(x, ds=(2, 2), ignore_border=True),
mode=mode_with_gpu.including("cudnn")) mode=mode_with_gpu.including("cudnn"))
except (AssertionError, RuntimeError): except (AssertionError, RuntimeError):
assert not dnn.dnn_available() assert not dnn.dnn_available(test_ctx_name)
raised = True raised = True
finally: finally:
theano.config.on_opt_error = old theano.config.on_opt_error = old
...@@ -327,7 +327,7 @@ def test_dnn_tag(): ...@@ -327,7 +327,7 @@ def test_dnn_tag():
logging.getLogger('theano').addHandler(theano.logging_default_handler) logging.getLogger('theano').addHandler(theano.logging_default_handler)
if not raised: if not raised:
assert dnn.dnn_available() assert dnn.dnn_available(test_ctx_name)
assert any([isinstance(n.op, dnn.GpuDnnPool) assert any([isinstance(n.op, dnn.GpuDnnPool)
for n in f.maker.fgraph.toposort()]) for n in f.maker.fgraph.toposort()])
...@@ -338,7 +338,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -338,7 +338,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
self.mode = mode_with_gpu self.mode = mode_with_gpu
def test_softmax(self): def test_softmax(self):
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
t = T.ftensor4('t') t = T.ftensor4('t')
rand_tensor = numpy.asarray( rand_tensor = numpy.asarray(
...@@ -368,7 +368,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -368,7 +368,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_conv(self): def test_conv(self):
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
kerns = T.ftensor4('kerns') kerns = T.ftensor4('kerns')
...@@ -406,7 +406,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -406,7 +406,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_conv_gradw(self): def test_conv_gradw(self):
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
kerns = T.ftensor4('kerns') kerns = T.ftensor4('kerns')
...@@ -455,7 +455,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -455,7 +455,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_conv_gradi(self): def test_conv_gradi(self):
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
kerns = T.ftensor4('kerns') kerns = T.ftensor4('kerns')
...@@ -499,7 +499,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -499,7 +499,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_pool(self): def test_pool(self):
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
img_val = numpy.asarray( img_val = numpy.asarray(
...@@ -524,7 +524,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -524,7 +524,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
) )
def test_pool_grad(self): def test_pool_grad(self):
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4('img') img = T.ftensor4('img')
img_grad = T.ftensor4('img_grad') img_grad = T.ftensor4('img_grad')
...@@ -568,7 +568,7 @@ class TestDnnInferShapes(utt.InferShapeTester): ...@@ -568,7 +568,7 @@ class TestDnnInferShapes(utt.InferShapeTester):
# this has been a problem in the past # this has been a problem in the past
def test_dnn_conv_border_mode(): def test_dnn_conv_border_mode():
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4() img = T.ftensor4()
kern = T.ftensor4() kern = T.ftensor4()
...@@ -580,7 +580,7 @@ def test_dnn_conv_border_mode(): ...@@ -580,7 +580,7 @@ def test_dnn_conv_border_mode():
def test_dnn_conv_alpha_output_merge(): def test_dnn_conv_alpha_output_merge():
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
img = T.ftensor4() img = T.ftensor4()
kern = T.ftensor4() kern = T.ftensor4()
...@@ -678,7 +678,7 @@ def test_dnn_conv_grad(): ...@@ -678,7 +678,7 @@ def test_dnn_conv_grad():
def test_version(): def test_version():
if not dnn.dnn_available(): if not dnn.dnn_available(test_ctx_name):
raise SkipTest(dnn.dnn_available.msg) raise SkipTest(dnn.dnn_available.msg)
assert isinstance(dnn.version(), int) assert isinstance(dnn.version(), int)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论