提交 5caafd99 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Test if the promised input/filter shapes match the actual shapes.

上级 193c9979
......@@ -471,6 +471,41 @@ def assert_conv_shape(shape):
return tuple(out_shape)
def assert_shape(x, expected_shape, msg='Unexpected shape.'):
"""Wraps `x` in an `Assert` to check its shape.
Parameters
----------
x : Tensor
x will be wrapped in an `Assert`.
expected_shape : tuple or list
The expected shape of `x`. The size of a dimension can be None,
which means it will not be checked.
msg : str
The error message of the `Assert`.
Returns
-------
Tensor
`x` wrapped in an `Assert`. At execution time, this will throw an
AssertionError if the shape of `x` does not match `expected_shape`.
If `expected_shape` is None or contains only Nones, the function
will return `x` directly.
"""
if expected_shape is None:
return x
shape = x.shape
tests = []
for i in range(x.ndim):
if expected_shape[i] is not None:
tests.append(theano.tensor.eq(shape[i], expected_shape[i]))
if tests:
return Assert(msg)(x, theano.tensor.all(tests))
else:
return x
def conv2d(input,
filters,
input_shape=None,
......@@ -488,6 +523,14 @@ def conv2d(input,
input = as_tensor_variable(input)
filters = as_tensor_variable(filters)
if input_shape is not None:
input = assert_shape(input, input_shape,
'conv2d shape mismatch: shape of '
'input does not match given input_shape.')
if filter_shape is not None:
filters = assert_shape(filters, filter_shape,
'conv2d shape mismatch: shape of '
'filters does not match given filter_shape.')
conv_op = AbstractConv2d(imshp=input_shape,
kshp=filter_shape,
border_mode=border_mode,
......@@ -587,6 +630,14 @@ def conv3d(input,
input = as_tensor_variable(input)
filters = as_tensor_variable(filters)
if input_shape is not None:
input = assert_shape(input, input_shape,
'conv3d shape mismatch: shape of '
'input does not match given input_shape.')
if filter_shape is not None:
filters = assert_shape(filters, filter_shape,
'conv3d shape mismatch: shape of '
'filters does not match given filter_shape.')
conv_op = AbstractConv3d(imshp=input_shape,
kshp=filter_shape,
border_mode=border_mode,
......@@ -713,6 +764,9 @@ def conv2d_grad_wrt_inputs(output_grad,
for dim in [0, 1, 2, 3]:
assert isinstance(filter_shape[dim], (theano.tensor.TensorConstant,
integer_types, type(None)))
filters = assert_shape(filters, filter_shape,
'conv2d_grad_wrt_inputs shape mismatch: shape of '
'filters does not match given filter_shape.')
# setting the last two dimensions of input_shape to None, if
# the type of these dimensions is TensorVariable.
......@@ -848,6 +902,9 @@ def conv3d_grad_wrt_inputs(output_grad,
for dim in [0, 1, 2, 3, 4]:
assert isinstance(filter_shape[dim], (theano.tensor.TensorConstant,
integer_types, type(None)))
filters = assert_shape(filters, filter_shape,
'conv3d_grad_wrt_inputs shape mismatch: shape of '
'filters does not match given filter_shape.')
# setting the last three dimensions of input_shape to None, if
# the type of these dimensions is TensorVariable.
......
......@@ -14,7 +14,8 @@ from theano.tensor.nnet.abstract_conv import (get_conv_output_shape,
get_conv_gradweights_shape,
get_conv_gradinputs_shape,
check_conv_gradinputs_shape,
assert_conv_shape)
assert_conv_shape,
assert_shape)
from theano.tensor.nnet.abstract_conv import AbstractConv2d
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradInputs
from theano.tensor.nnet.abstract_conv import AbstractConv2d_gradWeights
......@@ -226,6 +227,23 @@ class TestAssertConvShape(unittest.TestCase):
assert_raises(AssertionError, f, -1, 3, 3, 3)
class TestAssertShape(unittest.TestCase):
def test_basic(self):
x = tensor.tensor4()
s1 = tensor.iscalar()
s2 = tensor.iscalar()
expected_shape = [None, s1, s2, None]
f = theano.function([x, s1, s2], assert_shape(x, expected_shape))
v = numpy.zeros((3, 5, 7, 11), dtype='float32')
self.assertEqual(0, numpy.sum(f(v, 5, 7)))
assert_raises(AssertionError, f, v, 5, 0)
assert_raises(AssertionError, f, v, 5, 9)
assert_raises(AssertionError, f, v, 0, 7)
assert_raises(AssertionError, f, v, 7, 7)
class BaseTestConv(object):
def get_output_shape(self, inputs_shape, filters_shape,
subsample, border_mode, filter_dilation):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论