提交 0a361bbc authored 作者: Gijs van Tulder's avatar Gijs van Tulder

AbstractConv tests for empty dimensions.

上级 e8fdf903
...@@ -289,7 +289,7 @@ class BaseTestConv(object): ...@@ -289,7 +289,7 @@ class BaseTestConv(object):
res_ref = numpy.array(f_ref()) res_ref = numpy.array(f_ref())
res = numpy.array(f()) res = numpy.array(f())
utt.assert_allclose(res_ref, res) utt.assert_allclose(res_ref, res)
if verify_grad: if verify_grad and inputs_val.size > 0 and filters_val.size > 0 and res.size > 0:
utt.verify_grad(conv_op(border_mode=border_mode, utt.verify_grad(conv_op(border_mode=border_mode,
imshp=imshp, kshp=kshp, imshp=imshp, kshp=kshp,
subsample=subsample, subsample=subsample,
...@@ -355,7 +355,7 @@ class BaseTestConv(object): ...@@ -355,7 +355,7 @@ class BaseTestConv(object):
filter_dilation=filter_dilation) filter_dilation=filter_dilation)
return conv_op(inputs_val, output_val, filters_shape[2:]) return conv_op(inputs_val, output_val, filters_shape[2:])
if verify_grad: if verify_grad and inputs_val.size > 0 and output_val.size > 0 and res.size > 0:
utt.verify_grad(abstract_conv_gradweight, utt.verify_grad(abstract_conv_gradweight,
[inputs_val, output_val], [inputs_val, output_val],
mode=mode, eps=1) mode=mode, eps=1)
...@@ -421,7 +421,7 @@ class BaseTestConv(object): ...@@ -421,7 +421,7 @@ class BaseTestConv(object):
filter_dilation=filter_dilation) filter_dilation=filter_dilation)
return conv_op(filters_val, output_val, inputs_shape[2:]) return conv_op(filters_val, output_val, inputs_shape[2:])
if verify_grad: if verify_grad and filters_val.size > 0 and output_val.size > 0 and res.size > 0:
utt.verify_grad(abstract_conv_gradinputs, utt.verify_grad(abstract_conv_gradinputs,
[filters_val, output_val], [filters_val, output_val],
mode=mode, eps=1) mode=mode, eps=1)
...@@ -436,13 +436,14 @@ class BaseTestConv(object): ...@@ -436,13 +436,14 @@ class BaseTestConv(object):
for (i, f) in zip(self.inputs_shapes, self.filters_shapes): for (i, f) in zip(self.inputs_shapes, self.filters_shapes):
for provide_shape in self.provide_shape: for provide_shape in self.provide_shape:
yield (self.tcase, i, f, ds, db, dflip, provide_shape) yield (self.tcase, i, f, ds, db, dflip, provide_shape)
for fd in self.filters_dilations: if min(i) > 0 and min(f) > 0:
for s in self.subsamples: for fd in self.filters_dilations:
for b in self.border_modes: for s in self.subsamples:
yield (self.tcase, i, f, s, b, dflip, for b in self.border_modes:
dprovide_shape, fd) yield (self.tcase, i, f, s, b, dflip,
for flip in self.filter_flip: dprovide_shape, fd)
yield (self.tcase, i, f, ds, db, flip, dprovide_shape) for flip in self.filter_flip:
yield (self.tcase, i, f, ds, db, flip, dprovide_shape)
class BaseTestConv2d(BaseTestConv): class BaseTestConv2d(BaseTestConv):
...@@ -450,9 +451,11 @@ class BaseTestConv2d(BaseTestConv): ...@@ -450,9 +451,11 @@ class BaseTestConv2d(BaseTestConv):
def setup_class(cls): def setup_class(cls):
# This tests can run even when theano.config.blas.ldflags is empty. # This tests can run even when theano.config.blas.ldflags is empty.
cls.inputs_shapes = [(8, 1, 6, 6), (8, 1, 8, 8), (2, 1, 7, 7), cls.inputs_shapes = [(8, 1, 6, 6), (8, 1, 8, 8), (2, 1, 7, 7),
(6, 1, 10, 11), (2, 1, 6, 5), (1, 5, 9, 9)] (6, 1, 10, 11), (2, 1, 6, 5), (1, 5, 9, 9),
(0, 1, 6, 6), (1, 0, 6, 6), (1, 1, 6, 6)]
cls.filters_shapes = [(5, 1, 2, 2), (4, 1, 3, 3), (2, 1, 3, 3), cls.filters_shapes = [(5, 1, 2, 2), (4, 1, 3, 3), (2, 1, 3, 3),
(1, 1, 2, 3), (4, 1, 1, 3), (4, 5, 3, 2)] (1, 1, 2, 3), (4, 1, 1, 3), (4, 5, 3, 2),
(1, 1, 2, 2), (1, 0, 2, 2), (0, 1, 2, 2)]
cls.subsamples = [(1, 1), (2, 2), (2, 4)] cls.subsamples = [(1, 1), (2, 2), (2, 4)]
cls.default_subsamples = (1, 1) cls.default_subsamples = (1, 1)
cls.filters_dilations = [(1, 1), (1, 2), (2, 1)] cls.filters_dilations = [(1, 1), (1, 2), (2, 1)]
...@@ -806,8 +809,10 @@ class BaseTestConv3d(BaseTestConv): ...@@ -806,8 +809,10 @@ class BaseTestConv3d(BaseTestConv):
@classmethod @classmethod
def setup_class(cls): def setup_class(cls):
# This tests can run even when theano.config.blas.ldflags is empty. # This tests can run even when theano.config.blas.ldflags is empty.
cls.inputs_shapes = [(2, 1, 5, 5, 5), (1, 2, 7, 5, 6)] cls.inputs_shapes = [(2, 1, 5, 5, 5), (1, 2, 7, 5, 6),
cls.filters_shapes = [(2, 1, 2, 2, 2), (1, 2, 2, 1, 3)] (0, 1, 5, 5, 5), (1, 0, 5, 5, 5), (1, 1, 5, 5, 5)]
cls.filters_shapes = [(2, 1, 2, 2, 2), (1, 2, 2, 1, 3),
(1, 1, 2, 2, 2), (1, 0, 2, 2, 2), (0, 1, 2, 2, 2)]
cls.subsamples = [(1, 1, 1), (2, 2, 2), (1, 2, 3)] cls.subsamples = [(1, 1, 1), (2, 2, 2), (1, 2, 3)]
cls.default_subsamples = (1, 1, 1) cls.default_subsamples = (1, 1, 1)
cls.filters_dilations = [(1, 1, 1), (1, 2, 1), (2, 1, 2)] cls.filters_dilations = [(1, 1, 1), (1, 2, 1), (2, 1, 2)]
...@@ -975,6 +980,9 @@ class TestCpuConv3d(BaseTestConv3d): ...@@ -975,6 +980,9 @@ class TestCpuConv3d(BaseTestConv3d):
raise SkipTest("No dilation implementation for basic cpu Conv3D.") raise SkipTest("No dilation implementation for basic cpu Conv3D.")
if not theano.config.cxx: if not theano.config.cxx:
raise SkipTest("Need cxx to test conv2d") raise SkipTest("Need cxx to test conv2d")
if min(i) == 0 or min(f) == 0:
raise SkipTest('Not tested for old cpu Conv3D.')
mode = self.mode mode = self.mode
o = self.get_output_shape(i, f, s, b, fd) o = self.get_output_shape(i, f, s, b, fd)
fwd_OK = True fwd_OK = True
...@@ -1062,6 +1070,8 @@ class TestCpuConv3d(BaseTestConv3d): ...@@ -1062,6 +1070,8 @@ class TestCpuConv3d(BaseTestConv3d):
if fd != (1, 1, 1): if fd != (1, 1, 1):
raise SkipTest("No dilation implementation for basic cpu Conv3D.") raise SkipTest("No dilation implementation for basic cpu Conv3D.")
mode = self.mode mode = self.mode
if min(i) == 0 or min(f) == 0 or min(o) == 0:
raise SkipTest('Not tested for old cpu Conv3D.')
if b not in ((0, 0, 0), 'valid'): if b not in ((0, 0, 0), 'valid'):
return return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论