提交 f3fd620a authored 作者: Frederic Bastien's avatar Frederic Bastien

make convolution op accept Constant variable as shape.

上级 d2630156
......@@ -69,6 +69,23 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
:return: set of feature maps generated by convolutional layer. Tensor is of shape
(batch size, nb filters, output row, output col)
"""
#accept Constant value for image_shape and filter_shape.
if image_shape is not None:
image_shape = list(image_shape)
for i in range(len(image_shape)):
if image_shape[i] is not None:
image_shape[i] = tensor.get_constant_value(tensor.as_tensor_variable(image_shape[i]))
assert str(image_shape[i].dtype).startswith('int')
image_shape[i] = int(image_shape[i])
if filter_shape is not None:
filter_shape = list(filter_shape)
for i in range(len(filter_shape)):
if filter_shape[i] is not None:
filter_shape[i] = tensor.get_constant_value(tensor.as_tensor_variable(filter_shape[i]))
assert str(filter_shape[i].dtype).startswith('int')
filter_shape[i] = int(filter_shape[i])
if image_shape and filter_shape:
try:
assert image_shape[1]==filter_shape[1]
......
......@@ -25,9 +25,9 @@ class TestConv2D(unittest.TestCase):
verify_grad=True, should_raise=False):
if N_image_shape is None:
N_image_shape = image_shape
N_image_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None:
N_filter_shape = filter_shape
N_filter_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in filter_shape]
if not input:
input = self.input
......@@ -203,6 +203,19 @@ class TestConv2D(unittest.TestCase):
self.validate((3,2,7,5), (5,2,2,3), 'full', subsample=(2,2))
self.validate((3,2,7,5), (5,2,2,3), 'valid', subsample=(2,1))
def test_shape_Constant_tensor(self):
"""
Tests convolution where the {image,filter}_shape is a Constant tensor.
"""
as_t=T.as_tensor_variable
self.validate((as_t(3),as_t(2),as_t(7),as_t(5)), (5,2,2,3), 'valid')
self.validate(as_t([3,2,7,5]), (5,2,2,3), 'valid')
self.validate(as_t((3,2,7,5)), (5,2,2,3), 'valid')
self.validate((3,2,7,5), (as_t(5),as_t(2),as_t(2),as_t(3)), 'valid')
self.validate((3,2,7,5), as_t([5,2,2,3]), 'valid')
self.validate((3,2,7,5), as_t((5,2,2,3)), 'valid')
self.validate(as_t([3,2,7,5]), as_t([5,2,2,3]), 'full')
def test_invalid_filter_shape(self):
"""
Tests scenario where filter_shape[1] != input_shape[1]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论