提交 6f4a125d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Last touchups on the GPU tests.

上级 238f0c87
...@@ -56,7 +56,7 @@ class TestDnnConv2d(test_abstract_conv.TestConv2d): ...@@ -56,7 +56,7 @@ class TestDnnConv2d(test_abstract_conv.TestConv2d):
class TestCorrMMConv2d(test_abstract_conv.TestConv2d): class TestCorrMMConv2d(test_abstract_conv.TestConv2d):
def setUp(self): def setUp(self):
super(TestDnnConv2d, self).setUp() super(TestCorrMMConv2d, self).setUp()
self.shared = gpu_shared self.shared = gpu_shared
def test_gpucorrmm_conv(self): def test_gpucorrmm_conv(self):
......
...@@ -3,7 +3,7 @@ import itertools ...@@ -3,7 +3,7 @@ import itertools
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from theano.tensor.nnet.tests import test_abstract_conv from theano.tensor.nnet.tests import test_abstract_conv
from ..type import GpuArrayType from ..type import GpuArrayType, gpuarray_shared_constructor
from ..dnn import dnn_available, GpuDnnConv, GpuDnnConvGradW, GpuDnnConvGradI from ..dnn import dnn_available, GpuDnnConv, GpuDnnConvGradW, GpuDnnConvGradI
from .config import mode_with_gpu, test_ctx_name from .config import mode_with_gpu, test_ctx_name
...@@ -12,6 +12,10 @@ gpu_ftensor4 = GpuArrayType(dtype='float32', broadcastable=(False,) * 4) ...@@ -12,6 +12,10 @@ gpu_ftensor4 = GpuArrayType(dtype='float32', broadcastable=(False,) * 4)
class TestDnnConv2d(test_abstract_conv.TestConv2d): class TestDnnConv2d(test_abstract_conv.TestConv2d):
def setUp(self):
super(TestDnnConv2d, self).setUp()
self.shared = gpuarray_shared_constructor
def test_dnn_conv(self): def test_dnn_conv(self):
if not dnn_available(test_ctx_name): if not dnn_available(test_ctx_name):
raise SkipTest(dnn_available.msg) raise SkipTest(dnn_available.msg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论