提交 8d9e0253 authored 作者: Vincent Dumoulin's avatar Vincent Dumoulin

Add unit test

上级 335f708f
...@@ -9,7 +9,8 @@ import theano ...@@ -9,7 +9,8 @@ import theano
from theano import tensor from theano import tensor
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import corr, corr3d, abstract_conv as conv from theano.tensor.nnet import (corr, corr3d, conv2d_transpose,
abstract_conv as conv)
from theano.tensor.nnet.abstract_conv import (get_conv_output_shape, from theano.tensor.nnet.abstract_conv import (get_conv_output_shape,
get_conv_gradweights_shape, get_conv_gradweights_shape,
get_conv_gradinputs_shape, get_conv_gradinputs_shape,
...@@ -1548,3 +1549,32 @@ class TestBilinearUpsampling(unittest.TestCase): ...@@ -1548,3 +1549,32 @@ class TestBilinearUpsampling(unittest.TestCase):
f_1D = theano.function([], mat_1D, mode=self.compile_mode) f_1D = theano.function([], mat_1D, mode=self.compile_mode)
f_2D = theano.function([], mat_2D, mode=self.compile_mode) f_2D = theano.function([], mat_2D, mode=self.compile_mode)
utt.assert_allclose(f_1D(), f_2D(), rtol=1e-06) utt.assert_allclose(f_1D(), f_2D(), rtol=1e-06)
class TestConv2dTranspose(unittest.TestCase):
def test_interface(self):
"""Test conv2d_transpose wrapper.
This method tests that the order of the filter's
axes expected by the function produces the correct
output shape.
"""
output = theano.function(
inputs=[],
outputs=conv2d_transpose(input=tensor.ones((2, 2, 4, 4)),
filters=tensor.ones((2, 1, 4, 4)),
output_shape=(2, 1, 10, 10),
input_dilation=(2, 2)))()
expected_output = numpy.array(
[[[[2, 2, 4, 4, 4, 4, 4, 4, 2, 2],
[2, 2, 4, 4, 4, 4, 4, 4, 2, 2],
[4, 4, 8, 8, 8, 8, 8, 8, 4, 4],
[4, 4, 8, 8, 8, 8, 8, 8, 4, 4],
[4, 4, 8, 8, 8, 8, 8, 8, 4, 4],
[4, 4, 8, 8, 8, 8, 8, 8, 4, 4],
[4, 4, 8, 8, 8, 8, 8, 8, 4, 4],
[4, 4, 8, 8, 8, 8, 8, 8, 4, 4],
[2, 2, 4, 4, 4, 4, 4, 4, 2, 2],
[2, 2, 4, 4, 4, 4, 4, 4, 2, 2]]]] * 2)
numpy.testing.assert_equal(output, expected_output)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论