提交 516bb822 authored 作者: Frederic's avatar Frederic

Make theano.tensor.signal.conv2d(2d,2d) output 2d answer.

上级 9625521b
...@@ -82,7 +82,9 @@ def conv2d(input, filters, image_shape=None, filter_shape=None, ...@@ -82,7 +82,9 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
output = op(input4D, filters4D) output = op(input4D, filters4D)
# flatten to 3D tensor if convolving with single filter or single image # flatten to 3D tensor if convolving with single filter or single image
if input.ndim==2 or filters.ndim==2: if input.ndim == 2 and filters.ndim == 2:
output = tensor.flatten(output.T, outdim=2).T
elif input.ndim == 2 or filters.ndim == 2:
output = tensor.flatten(output.T, outdim=3).T output = tensor.flatten(output.T, outdim=3).T
return output return output
...@@ -17,7 +17,7 @@ class TestSignalConv2D(unittest.TestCase): ...@@ -17,7 +17,7 @@ class TestSignalConv2D(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
def validate(self, image_shape, filter_shape, verify_grad=True): def validate(self, image_shape, filter_shape, out_dim, verify_grad=True):
image_dim = len(image_shape) image_dim = len(image_shape)
filter_dim = len(filter_shape) filter_dim = len(filter_shape)
...@@ -36,6 +36,7 @@ class TestSignalConv2D(unittest.TestCase): ...@@ -36,6 +36,7 @@ class TestSignalConv2D(unittest.TestCase):
def sym_conv2d(input, filters): def sym_conv2d(input, filters):
return conv.conv2d(input, filters) return conv.conv2d(input, filters)
output = sym_conv2d(input, filters) output = sym_conv2d(input, filters)
assert output.ndim == out_dim
theano_conv = theano.function([input, filters], output) theano_conv = theano.function([input, filters], output)
# initialize input and compute result # initialize input and compute result
...@@ -90,10 +91,10 @@ class TestSignalConv2D(unittest.TestCase): ...@@ -90,10 +91,10 @@ class TestSignalConv2D(unittest.TestCase):
theano.config.cxx == ""): theano.config.cxx == ""):
raise SkipTest("conv2d tests need SciPy or a c++ compiler") raise SkipTest("conv2d tests need SciPy or a c++ compiler")
self.validate((1, 4, 5), (2, 2, 3), verify_grad=True) self.validate((1, 4, 5), (2, 2, 3), out_dim=4, verify_grad=True)
self.validate((7, 5), (5, 2, 3), verify_grad=False) self.validate((7, 5), (5, 2, 3), out_dim=3, verify_grad=False)
self.validate((3, 7, 5), (2, 3), verify_grad=False) self.validate((3, 7, 5), (2, 3), out_dim=3, verify_grad=False)
self.validate((7, 5), (2, 3), verify_grad=False) self.validate((7, 5), (2, 3), out_dim=2, verify_grad=False)
def test_fail(self): def test_fail(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论