提交 14725ace authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

testing infer_shape: op ConvOp

上级 fe026b97
...@@ -13,10 +13,10 @@ from theano.tensor.nnet import conv ...@@ -13,10 +13,10 @@ from theano.tensor.nnet import conv
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
class TestConv2D(unittest.TestCase): class TestConv2D(utt.InferShapeTester):
def setUp(self): def setUp(self):
utt.seed_rng() super (TestConv2D, self).setUp()
self.input = T.dtensor4('input') self.input = T.dtensor4('input')
self.filters = T.dtensor4('filters') self.filters = T.dtensor4('filters')
...@@ -368,8 +368,7 @@ class TestConv2D(unittest.TestCase): ...@@ -368,8 +368,7 @@ class TestConv2D(unittest.TestCase):
gcc bug. So it should not crash anymore gcc bug. So it should not crash anymore
""" """
self.validate((1, 10, 213, 129), (46, 10, 212, 1), 'valid', self.validate((1, 10, 213, 129), (46, 10, 212, 1), 'valid',
verify_grad=False) verify_grad=False)
self.validate((1, 10, 213, 129), (46, 10, 212, 1), 'valid', verify_grad=False)
def speed(self): def speed(self):
n_calls = 20000 n_calls = 20000
...@@ -407,3 +406,100 @@ class TestConv2D(unittest.TestCase): ...@@ -407,3 +406,100 @@ class TestConv2D(unittest.TestCase):
t2 = time.time() t2 = time.time()
print t2 - t1, print t2 - t1,
print print
def test_infer_shape(self):
# Note: infer_shape is incomplete and thus input and filter shapes
# must be provided explicitly
def rand(*shape):
r = numpy.asarray(numpy.random.rand(*shape), dtype='float64')
return r * 2 - 1
adtens = T.dtensor4()
bdtens = T.dtensor4()
aivec_val = [2, 2, 3, 3]
bivec_val = [2, 2, 2, 2]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='valid')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [2, 2, 3, 3]
bivec_val = [2, 2, 2, 2]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='full')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 8, 8]
bivec_val = [4, 2, 5, 5]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='valid')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 8, 8]
bivec_val = [4, 2, 5, 5]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='full')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 7, 5]
bivec_val = [5, 2, 3, 2]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='valid')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 7, 5]
bivec_val = [5, 2, 3, 2]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='full')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 7, 5]
bivec_val = [5, 2, 2, 3]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='valid')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 7, 5]
bivec_val = [5, 2, 2, 3]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='full')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 3, 3]
bivec_val = [4, 2, 3, 3]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='valid')], [adtens_val, bdtens_val], conv.ConvOp)
aivec_val = [3, 2, 3, 3]
bivec_val = [4, 2, 3, 3]
adtens_val = rand(*aivec_val)
bdtens_val = rand(*bivec_val)
self._compile_and_check([adtens, bdtens],
[conv.conv2d(adtens, bdtens, aivec_val, bivec_val,
border_mode='full')], [adtens_val, bdtens_val], conv.ConvOp)
if __name__ == '__main__':
t = TestConv2D('setUp')
t.setUp()
t.test_infer_shape()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论