提交 516bb822 authored 作者: Frederic's avatar Frederic

Make theano.tensor.signal.conv2d(2d,2d) output 2d answer.

上级 9625521b
......@@ -82,7 +82,9 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
output = op(input4D, filters4D)
# flatten to 3D tensor if convolving with single filter or single image
if input.ndim==2 or filters.ndim==2:
if input.ndim == 2 and filters.ndim == 2:
output = tensor.flatten(output.T, outdim=2).T
elif input.ndim == 2 or filters.ndim == 2:
output = tensor.flatten(output.T, outdim=3).T
return output
......@@ -17,7 +17,7 @@ class TestSignalConv2D(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def validate(self, image_shape, filter_shape, verify_grad=True):
def validate(self, image_shape, filter_shape, out_dim, verify_grad=True):
image_dim = len(image_shape)
filter_dim = len(filter_shape)
......@@ -36,6 +36,7 @@ class TestSignalConv2D(unittest.TestCase):
def sym_conv2d(input, filters):
return conv.conv2d(input, filters)
output = sym_conv2d(input, filters)
assert output.ndim == out_dim
theano_conv = theano.function([input, filters], output)
# initialize input and compute result
......@@ -90,10 +91,10 @@ class TestSignalConv2D(unittest.TestCase):
theano.config.cxx == ""):
raise SkipTest("conv2d tests need SciPy or a c++ compiler")
self.validate((1, 4, 5), (2, 2, 3), verify_grad=True)
self.validate((7, 5), (5, 2, 3), verify_grad=False)
self.validate((3, 7, 5), (2, 3), verify_grad=False)
self.validate((7, 5), (2, 3), verify_grad=False)
self.validate((1, 4, 5), (2, 2, 3), out_dim=4, verify_grad=True)
self.validate((7, 5), (5, 2, 3), out_dim=3, verify_grad=False)
self.validate((3, 7, 5), (2, 3), out_dim=3, verify_grad=False)
self.validate((7, 5), (2, 3), out_dim=2, verify_grad=False)
def test_fail(self):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论