提交 0c57414e authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4654 from slefrancois/yield_test_absconv

yield test in test_abstract_conv for tensor.nnet, sandbox.cuda and gpuarray
......@@ -15,11 +15,12 @@ gpu_ftensor4 = GpuArrayType(dtype='float32', broadcastable=(False,) * 4)
class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
def setUp(self):
super(TestDnnConv2d, self).setUp()
self.shared = gpuarray_shared_constructor
@classmethod
def setup_class(cls):
test_abstract_conv.BaseTestConv2d.setup_class()
cls.shared = staticmethod(gpuarray_shared_constructor)
# provide_shape is not used by the cuDNN impementation
self.provide_shape = [False]
cls.provide_shape = [False]
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)):
if not dnn_available(test_ctx_name):
......
......@@ -23,11 +23,12 @@ else:
class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
def setUp(self):
super(TestDnnConv2d, self).setUp()
@classmethod
def setup_class(cls):
test_abstract_conv.BaseTestConv2d.setup_class()
# provide_shape is not used by the cuDNN impementation
self.provide_shape = [False]
self.shared = gpu_shared
cls.provide_shape = [False]
cls.shared = staticmethod(gpu_shared)
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)):
if fd != (1, 1):
......@@ -56,10 +57,11 @@ class TestDnnConv2d(test_abstract_conv.BaseTestConv2d):
class TestCorrMMConv2d(test_abstract_conv.BaseTestConv2d):
def setUp(self):
super(TestCorrMMConv2d, self).setUp()
self.shared = gpu_shared
self.mode = mode_with_gpu.excluding('cudnn')
@classmethod
def setup_class(cls):
test_abstract_conv.BaseTestConv2d.setup_class()
cls.shared = staticmethod(gpu_shared)
cls.mode = mode_with_gpu.excluding('cudnn')
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)):
mode = self.mode
......
......@@ -3,7 +3,8 @@ import unittest
import numpy
import numpy as np
from nose.plugins.skip import SkipTest
from nose.tools import assert_raises
from nose.tools import assert_raises, assert_true
import theano
from theano import tensor
from theano.gof.opt import check_stack_trace
......@@ -77,20 +78,21 @@ class TestGetConvOutShape(unittest.TestCase):
self.assertTrue(test4_params == (3, 4, 6, 4))
class BaseTestConv2d(unittest.TestCase):
def setUp(self):
class BaseTestConv2d:
@classmethod
def setup_class(cls):
if theano.config.blas.ldflags == '':
raise SkipTest("BLAS required for reference")
self.inputs_shapes = [(8, 1, 6, 6), (8, 1, 8, 8), (2, 1, 7, 7),
(6, 1, 10, 11), (2, 1, 6, 5), (1, 5, 9, 9)]
self.filters_shapes = [(5, 1, 2, 2), (4, 1, 3, 3), (2, 1, 3, 3),
(1, 1, 2, 3), (4, 1, 1, 3), (4, 5, 3, 2)]
self.subsamples = [(1, 1), (2, 2), (2, 4)]
self.filters_dilations = [(1, 1), (1, 2), (2, 1)]
self.border_modes = ["valid", "full", (0, 0), (1, 1), (5, 5), (5, 2)]
self.filter_flip = [True, False]
self.provide_shape = [True, False]
self.shared = theano.compile.shared
cls.inputs_shapes = [(8, 1, 6, 6), (8, 1, 8, 8), (2, 1, 7, 7),
(6, 1, 10, 11), (2, 1, 6, 5), (1, 5, 9, 9)]
cls.filters_shapes = [(5, 1, 2, 2), (4, 1, 3, 3), (2, 1, 3, 3),
(1, 1, 2, 3), (4, 1, 1, 3), (4, 5, 3, 2)]
cls.subsamples = [(1, 1), (2, 2), (2, 4)]
cls.filters_dilations = [(1, 1), (1, 2), (2, 1)]
cls.border_modes = ["valid", "full", (0, 0), (1, 1), (5, 5), (5, 2)]
cls.filter_flip = [True, False]
cls.provide_shape = [True, False]
cls.shared = staticmethod(theano.compile.shared)
def get_output_shape(self, inputs_shape, filters_shape,
subsample, border_mode, filter_dilation):
......@@ -153,7 +155,7 @@ class BaseTestConv2d(unittest.TestCase):
assert any([isinstance(n.op, target_op) for n
in f.maker.fgraph.toposort()])
if check_trace:
self.assertTrue(check_stack_trace(f, ops_to_check=target_op))
assert_true(check_stack_trace(f, ops_to_check=target_op))
res_ref = numpy.array(f_ref())
res = numpy.array(f())
......@@ -207,7 +209,7 @@ class BaseTestConv2d(unittest.TestCase):
assert any([isinstance(n.op, target_op) for n
in f.maker.fgraph.toposort()])
if check_trace:
self.assertTrue(check_stack_trace(f, ops_to_check=target_op))
assert_true(check_stack_trace(f, ops_to_check=target_op))
res_ref = numpy.array(f_ref())
res = numpy.array(f())
......@@ -260,7 +262,7 @@ class BaseTestConv2d(unittest.TestCase):
assert any([isinstance(n.op, target_op) for n
in f.maker.fgraph.toposort()])
if check_trace:
self.assertTrue(check_stack_trace(f, ops_to_check=target_op))
assert_true(check_stack_trace(f, ops_to_check=target_op))
res_ref = numpy.array(f_ref())
res = numpy.array(f())
......@@ -284,36 +286,24 @@ class BaseTestConv2d(unittest.TestCase):
db = (0, 0)
dflip = True in self.filter_flip
dprovide_shape = True in self.provide_shape
skipped = False
for (i, f) in zip(self.inputs_shapes, self.filters_shapes):
for provide_shape in self.provide_shape:
try:
self.tcase(i, f, ds, db, dflip, provide_shape)
except SkipTest as e:
skipped = e
yield (self.tcase, i, f, ds, db, dflip, provide_shape)
for fd in self.filters_dilations:
for s in self.subsamples:
for b in self.border_modes:
try:
self.tcase(i, f, s, db, dflip,
dprovide_shape, fd)
except SkipTest as e:
skipped = e
yield (self.tcase, i, f, s, db, dflip,
dprovide_shape, fd)
for flip in self.filter_flip:
try:
self.tcase(i, f, ds, db, flip,
dprovide_shape)
except SkipTest as e:
skipped = e
if skipped:
raise skipped
yield (self.tcase, i, f, ds, db, flip, dprovide_shape)
class TestCorrConv2d(BaseTestConv2d):
def setUp(self):
@classmethod
def setup_class(cls):
if theano.config.blas.ldflags == "":
raise SkipTest()
return super(TestCorrConv2d, self).setUp()
BaseTestConv2d.setup_class()
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)):
o = self.get_output_shape(i, f, s, b, fd)
......@@ -339,14 +329,16 @@ class TestCorrConv2d(BaseTestConv2d):
class TestCpuConv2d(BaseTestConv2d):
def setUp(self):
super(TestCpuConv2d, self).setUp()
self.mode = theano.compile.mode.get_default_mode().excluding('conv_gemm')
self.opt_err = theano.config.on_opt_error
@classmethod
def setup(cls):
BaseTestConv2d.setup_class()
cls.mode = theano.compile.mode.get_default_mode().excluding('conv_gemm')
cls.opt_err = theano.config.on_opt_error
theano.config.on_opt_error = 'ignore'
def tearDown(self):
theano.config.on_opt_error = self.opt_err
@classmethod
def tearDown(cls):
theano.config.on_opt_error = cls.opt_err
def tcase(self, i, f, s, b, flip, provide_shape, fd=(1, 1)):
if fd != (1, 1):
......@@ -385,18 +377,18 @@ class TestCpuConv2d(BaseTestConv2d):
check_trace=True, filter_dilation=fd)
else:
self.assertRaises(AssertionError,
self.run_fwd,
inputs_shape=i,
filters_shape=f,
subsample=s,
verify_grad=False,
mode=mode,
provide_shape=provide_shape,
border_mode=b,
filter_flip=flip,
check_trace=True,
filter_dilation=fd)
assert_raises(AssertionError,
self.run_fwd,
inputs_shape=i,
filters_shape=f,
subsample=s,
verify_grad=False,
mode=mode,
provide_shape=provide_shape,
border_mode=b,
filter_flip=flip,
check_trace=True,
filter_dilation=fd)
if gradweight_OK:
if not theano.config.blas.ldflags:
......@@ -410,19 +402,19 @@ class TestCpuConv2d(BaseTestConv2d):
check_trace=True,
filter_dilation=fd)
else:
self.assertRaises(AssertionError,
self.run_gradweight,
inputs_shape=i,
filters_shape=f,
output_shape=o,
subsample=s,
verify_grad=False,
mode=mode,
provide_shape=provide_shape,
border_mode=b,
filter_flip=flip,
check_trace=True,
filter_dilation=fd)
assert_raises(AssertionError,
self.run_gradweight,
inputs_shape=i,
filters_shape=f,
output_shape=o,
subsample=s,
verify_grad=False,
mode=mode,
provide_shape=provide_shape,
border_mode=b,
filter_flip=flip,
check_trace=True,
filter_dilation=fd)
if gradinput_OK:
if not theano.config.blas.ldflags:
......@@ -436,19 +428,19 @@ class TestCpuConv2d(BaseTestConv2d):
check_trace=True,
filter_dilation=fd)
else:
self.assertRaises(AssertionError,
self.run_gradinput,
inputs_shape=i,
filters_shape=f,
output_shape=o,
subsample=s,
verify_grad=False,
mode=mode,
provide_shape=provide_shape,
border_mode=b,
filter_flip=flip,
check_trace=True,
filter_dilation=fd)
assert_raises(AssertionError,
self.run_gradinput,
inputs_shape=i,
filters_shape=f,
output_shape=o,
subsample=s,
verify_grad=False,
mode=mode,
provide_shape=provide_shape,
border_mode=b,
filter_flip=flip,
check_trace=True,
filter_dilation=fd)
def test_constant_shapes():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论