提交 e0cdd248 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Test if conv helper functions call assert_shape.

上级 5a730afe
......@@ -243,6 +243,92 @@ class TestAssertShape(unittest.TestCase):
assert_raises(AssertionError, f, v, 0, 7)
assert_raises(AssertionError, f, v, 7, 7)
def test_shape_check_conv2d(self):
input = tensor.tensor4()
filters = tensor.tensor4()
out = conv.conv2d(input, filters,
input_shape=(3, 5, 7, 11),
filter_shape=(7, 5, 3, 3))
f = theano.function([input, filters], out)
# mismatched input_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 5, 9, 11), dtype='float32'),
numpy.zeros((7, 5, 3, 3), dtype='float32'))
# mismatched filter_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 5, 7, 11), dtype='float32'),
numpy.zeros((7, 5, 2, 2), dtype='float32'))
def test_shape_check_conv3d(self):
input = tensor.tensor5()
filters = tensor.tensor5()
out = conv.conv3d(input, filters,
input_shape=(3, 5, 7, 11, 13),
filter_shape=(7, 5, 3, 3, 3))
f = theano.function([input, filters], out)
# mismatched input_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 5, 9, 11, 13), dtype='float32'),
numpy.zeros((7, 5, 3, 3, 3), dtype='float32'))
# mismatched filter_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 5, 7, 11, 13), dtype='float32'),
numpy.zeros((7, 5, 2, 2, 2), dtype='float32'))
def test_shape_check_conv2d_grad_wrt_inputs(self):
output_grad = tensor.tensor4()
filters = tensor.tensor4()
out = conv.conv2d_grad_wrt_inputs(output_grad, filters,
input_shape=(None, None, 7, 11),
filter_shape=(7, 5, 3, 3))
f = theano.function([output_grad, filters], out)
# mismatched filter_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 6, 5, 9), dtype='float32'),
numpy.zeros((7, 6, 3, 3), dtype='float32'))
def test_shape_check_conv3d_grad_wrt_inputs(self):
output_grad = tensor.tensor5()
filters = tensor.tensor5()
out = conv.conv3d_grad_wrt_inputs(output_grad, filters,
input_shape=(None, None, 7, 11, 13),
filter_shape=(7, 5, 3, 3, 3))
f = theano.function([output_grad, filters], out)
# mismatched filter_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 6, 5, 9, 11), dtype='float32'),
numpy.zeros((7, 6, 3, 3, 3), dtype='float32'))
def test_shape_check_conv2d_grad_wrt_weights(self):
input = tensor.tensor4()
output_grad = tensor.tensor4()
out = conv.conv2d_grad_wrt_weights(input, output_grad,
filter_shape=(None, None, 3, 3),
input_shape=(3, 5, 7, 11))
f = theano.function([input, output_grad], out)
# mismatched filter_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 6, 7, 11), dtype='float32'),
numpy.zeros((3, 7, 5, 9), dtype='float32'))
def test_shape_check_conv3d_grad_wrt_weights(self):
input = tensor.tensor5()
output_grad = tensor.tensor5()
out = conv.conv3d_grad_wrt_weights(input, output_grad,
filter_shape=(None, None, 3, 3, 3),
input_shape=(3, 5, 7, 11, 13))
f = theano.function([input, output_grad], out)
# mismatched filter_shape
assert_raises(AssertionError, f,
numpy.zeros((3, 6, 7, 11, 13), dtype='float32'),
numpy.zeros((3, 7, 5, 9, 11), dtype='float32'))
class BaseTestConv(object):
def get_output_shape(self, inputs_shape, filters_shape,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论