提交 76b4f47a authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Tests for arbitrary inputs should now raise errors.

上级 297da372
...@@ -472,13 +472,10 @@ class BaseTestConv2d(BaseTestConv): ...@@ -472,13 +472,10 @@ class BaseTestConv2d(BaseTestConv):
filter_shape = (2, 1, 3, 3) filter_shape = (2, 1, 3, 3)
for output_shape in [(2, 2, 8, 8), (2, 2, 9, 9), (2, 2, 12, 12)]: for output_shape in [(2, 2, 8, 8), (2, 2, 9, 9), (2, 2, 12, 12)]:
for border_mode in ["valid", "half", "full"]: for border_mode in ["valid", "half", "full"]:
# is this output shape large enough? computed_shape = get_conv_output_shape(
min_output_shape = self.get_output_shape( input_shape, filter_shape, border_mode, self.default_subsamples, self.default_filters_dilations)
input_shape, filter_shape, self.default_subsamples, # is this a valid combination?
border_mode, self.default_filters_dilations) if tuple(computed_shape) == output_shape:
if not all(o >= min_o for (o, min_o) in zip(output_shape, min_output_shape)):
continue
for provide_shape in self.provide_shape:
yield (self.tcase_gi, yield (self.tcase_gi,
input_shape, input_shape,
filter_shape, filter_shape,
...@@ -486,8 +483,21 @@ class BaseTestConv2d(BaseTestConv): ...@@ -486,8 +483,21 @@ class BaseTestConv2d(BaseTestConv):
self.default_subsamples, self.default_subsamples,
border_mode, border_mode,
True, True,
provide_shape, True,
self.default_filters_dilations) self.default_filters_dilations,
False)
else:
# expect an error
yield (self.tcase_gi,
input_shape,
filter_shape,
output_shape,
self.default_subsamples,
border_mode,
True,
True,
self.default_filters_dilations,
True)
def test_gradinput_impossible_output_shapes(self): def test_gradinput_impossible_output_shapes(self):
for i in range(1, 20): for i in range(1, 20):
...@@ -505,6 +515,7 @@ class BaseTestConv2d(BaseTestConv): ...@@ -505,6 +515,7 @@ class BaseTestConv2d(BaseTestConv):
# outputs that are too large or too small should be rejected # outputs that are too large or too small should be rejected
for o in (-3, -2, -1, 1, 2, 3): for o in (-3, -2, -1, 1, 2, 3):
output_shape = (1, 1, computed_shape[2] + o, computed_shape[3] + o) output_shape = (1, 1, computed_shape[2] + o, computed_shape[3] + o)
# expect an error
yield (self.tcase_gi, yield (self.tcase_gi,
image_shape, image_shape,
kernel_shape, kernel_shape,
...@@ -822,13 +833,11 @@ class BaseTestConv3d(BaseTestConv): ...@@ -822,13 +833,11 @@ class BaseTestConv3d(BaseTestConv):
filter_shape = (1, 1, 3, 3, 3) filter_shape = (1, 1, 3, 3, 3)
for output_shape in [(2, 1, 8, 8, 8), (2, 1, 9, 9, 9), (2, 1, 12, 12, 12)]: for output_shape in [(2, 1, 8, 8, 8), (2, 1, 9, 9, 9), (2, 1, 12, 12, 12)]:
for border_mode in ["valid", "half", "full"]: for border_mode in ["valid", "half", "full"]:
# is this output shape large enough? # compute the output that these inputs and parameters would produce
min_output_shape = self.get_output_shape( computed_shape = get_conv_output_shape(
input_shape, filter_shape, self.default_subsamples, input_shape, filter_shape, border_mode, self.default_subsamples, self.default_filters_dilations)
border_mode, self.default_filters_dilations) # is this a valid combination?
if not all(o >= min_o for (o, min_o) in zip(output_shape, min_output_shape)): if tuple(computed_shape) == output_shape:
continue
for provide_shape in self.provide_shape:
yield (self.tcase_gi, yield (self.tcase_gi,
input_shape, input_shape,
filter_shape, filter_shape,
...@@ -836,8 +845,21 @@ class BaseTestConv3d(BaseTestConv): ...@@ -836,8 +845,21 @@ class BaseTestConv3d(BaseTestConv):
self.default_subsamples, self.default_subsamples,
border_mode, border_mode,
True, True,
provide_shape, True,
self.default_filters_dilations) self.default_filters_dilations,
False)
else:
# expect an error
yield (self.tcase_gi,
input_shape,
filter_shape,
output_shape,
self.default_subsamples,
border_mode,
True,
True,
self.default_filters_dilations,
True)
def test_gradinput_impossible_output_shapes(self): def test_gradinput_impossible_output_shapes(self):
for i in range(1, 20): for i in range(1, 20):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论