提交 c3f6b9fb authored 作者: Frederic's avatar Frederic

pep8

上级 7f255792
import sys, time, unittest
import sys
import time
import unittest
import numpy
import theano
......@@ -10,6 +12,7 @@ from theano.tensor.nnet import conv
from theano.tensor.basic import _allclose
class TestConv2D(unittest.TestCase):
def setUp(self):
......@@ -18,16 +21,18 @@ class TestConv2D(unittest.TestCase):
self.filters = T.dtensor4('filters')
def validate(self, image_shape, filter_shape,
border_mode='valid', subsample=(1,1),
border_mode='valid', subsample=(1, 1),
N_image_shape=None, N_filter_shape=None,
input=None, filters=None,
unroll_batch=None, unroll_kern=None, unroll_patch=None,
verify_grad=True, should_raise=False):
if N_image_shape is None:
N_image_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in image_shape]
N_image_shape = [T.get_constant_value(T.
as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None:
N_filter_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in filter_shape]
N_filter_shape = [T.get_constant_value(T.
as_tensor_variable(x)) for x in filter_shape]
if not input:
input = self.input
......@@ -47,48 +52,53 @@ class TestConv2D(unittest.TestCase):
theano_conv = theano.function([input, filters], output)
# initialize input and compute result
image_data = numpy.random.random(N_image_shape)
image_data = numpy.random.random(N_image_shape)
filter_data = numpy.random.random(N_filter_shape)
try:
theano_output = theano_conv(image_data, filter_data)
except ValueError:
if not should_raise: raise
if not should_raise:
raise
return
else:
if should_raise: raise Exception("ConvOp should have generated an error")
if should_raise:
raise Exception(
"ConvOp should have generated an error")
############# REFERENCE IMPLEMENTATION ############
s = 1.
orig_image_data = image_data
if border_mode is not 'full': s = -1.
if border_mode is not 'full':
s = -1.
out_shape2d = numpy.array(N_image_shape[-2:]) +\
s*numpy.array(N_filter_shape[-2:]) - s
s * numpy.array(N_filter_shape[-2:]) - s
out_shape2d = numpy.ceil(out_shape2d / numpy.array(subsample))
out_shape = (N_image_shape[0],N_filter_shape[0]) + tuple(out_shape2d)
out_shape = (N_image_shape[0], N_filter_shape[0]) + tuple(out_shape2d)
ref_output = numpy.zeros(out_shape)
# loop over output feature maps
ref_output.fill(0)
if border_mode=='full':
image_data2 = numpy.zeros((N_image_shape[0],N_image_shape[1],
N_image_shape[2]+2*N_filter_shape[2]-2,
N_image_shape[3]+2*N_filter_shape[3]-2))
image_data2[:,:,N_filter_shape[2]-1:N_filter_shape[2]-1+N_image_shape[2],
N_filter_shape[3]-1:N_filter_shape[3]-1+N_image_shape[3]] = image_data
if border_mode == 'full':
image_data2 = numpy.zeros((N_image_shape[0], N_image_shape[1],
N_image_shape[2] + 2 * N_filter_shape[2] - 2,
N_image_shape[3] + 2 * N_filter_shape[3] - 2))
image_data2[:, :, N_filter_shape[2] - 1:N_filter_shape[2] - 1 + N_image_shape[2],
N_filter_shape[3] - 1:N_filter_shape[3] - 1 + N_image_shape[3]] = image_data
image_data = image_data2
N_image_shape = image_data.shape
for bb in range(N_image_shape[0]):
for nn in range(N_filter_shape[0]):
for im0 in range(N_image_shape[1]):
filter2d = filter_data[nn,im0,:,:]
image2d = image_data[bb,im0,:,:]
filter2d = filter_data[nn, im0, :, :]
image2d = image_data[bb, im0, :, :]
for row in range(ref_output.shape[2]):
irow = row * subsample[0]#image row
irow = row * subsample[0] # image row
for col in range(ref_output.shape[3]):
icol = col * subsample[1]#image col
ref_output[bb,nn,row,col] += (image2d[irow:irow+N_filter_shape[2],
icol:icol+N_filter_shape[3]]*filter2d[::-1,::-1]
).sum()
icol = col * subsample[1] # image col
ref_output[bb, nn, row, col] += (image2d[
irow:irow + N_filter_shape[2],
icol:icol + N_filter_shape[3]] * filter2d[::-1,::-1]
).sum()
self.assertTrue(_allclose(theano_output, ref_output))
......@@ -96,136 +106,162 @@ class TestConv2D(unittest.TestCase):
if verify_grad:
utt.verify_grad(sym_conv2d, [orig_image_data, filter_data])
def test_basic1(self):
"""Tests that basic convolutions work for odd and even
dimensions of image and filter shapes, as well as rectangular
images and filters.
"""
Tests that basic convolutions work for odd and even dimensions of image and filter
shapes, as well as rectangular images and filters.
"""
self.validate((2,2,3,3), (2,2,2,2), 'valid', verify_grad=False)
self.validate((2, 2, 3, 3), (2, 2, 2, 2), 'valid', verify_grad=False)
def test_basic(self):
"""Tests that basic convolutions work for odd and even
dimensions of image and filter shapes, as well as rectangular
images and filters.
"""
Tests that basic convolutions work for odd and even dimensions of image and filter
shapes, as well as rectangular images and filters.
"""
self.validate((3,2,8,8), (4,2,5,5), 'valid', verify_grad=False)
self.validate((3,2,7,5), (5,2,2,3), 'valid')
self.validate((3,2,7,5), (5,2,3,2), 'valid', verify_grad=False)
self.validate((3,2,8,8), (4,2,5,5), 'full', verify_grad=False)
self.validate((3,2,7,5), (5,2,2,3), 'full')
self.validate((3, 2, 8, 8), (4, 2, 5, 5), 'valid', verify_grad=False)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid')
self.validate((3, 2, 7, 5), (5, 2, 3, 2), 'valid', verify_grad=False)
self.validate((3, 2, 8, 8), (4, 2, 5, 5), 'full', verify_grad=False)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full')
# test filter same size as input
def test_img_kernel_same_shape(self):
self.validate((3,2,3,3), (4,2,3,3), 'full')
self.validate((3,2,3,3), (4,2,3,3), 'valid')
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'full')
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid')
def test_unroll_patch_true(self):
"""
Test basic convs with True.
"""
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True)
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True)
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True, verify_grad=False)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=True)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=True)
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid',
unroll_patch=True, verify_grad=False)
def test_unroll_patch_false(self):
"""
Test basic convs with False.
"""
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=False)
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=False)
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=False, verify_grad=False)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=False)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=False)
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid',
unroll_patch=False, verify_grad=False)
def test_unroll_patch_true_fail(self):
"""
Test basic convs with True.
"""
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', unroll_patch=True,
N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
should_raise=True)
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', unroll_patch=True,
N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
should_raise=True)
self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'valid', unroll_patch=True,
N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
should_raise=True)
def test_unroll_special(self):
"""
(unroll_kern, unroll_batch) in (0,1),(1,0) is special case.
"""
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=1)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid', unroll_batch=1)
def test_unroll_batch(self):
"""
Test mini-batch unrolling for various legal values.
"""
# mini-batch of size 6 is multiple of 2 and 3. Should work.
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=2, verify_grad=False)
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=3, verify_grad=False)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=2, verify_grad=False)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=3, verify_grad=False)
def test_unroll_kern(self):
"""
Test kernel unrolling for various legal values.
"""
# 6 filters is a multiple of 2 and 3. Should work.
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_kern=2, verify_grad=False)
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_kern=3, verify_grad=False)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid', unroll_kern=2,
verify_grad=False)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid', unroll_kern=3,
verify_grad=False)
def test_unroll_batch_kern(self):
"""
Test mini-batch unrolling with kernel unrolling for various legal values.
"""Test mini-batch unrolling with kernel unrolling for various
legal values.
"""
# mini-batch of size 6 is multiple of 2 and 3. Should work.
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=2, unroll_kern=3, verify_grad=False)
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=3, unroll_kern=3, verify_grad=False)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=2, unroll_kern=3, verify_grad=False)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=3, unroll_kern=3, verify_grad=False)
# 6 filters is a multiple of 2 and 3. Should work.
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=2, verify_grad=False)
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=3, verify_grad=False)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
unroll_batch=2, unroll_kern=2, verify_grad=False)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
unroll_batch=2, unroll_kern=3, verify_grad=False)
def test_unroll_batch_kern_fail(self):
"""
Test mini-batch unrolling with kernel unrolling for various legal values, but pass bad input.
All those test must generate errors
"""Test mini-batch unrolling with kernel unrolling for various
legal values, but pass bad input. All those test must
generate errors
"""
# mini-batch of size 6 is multiple of 2 and 3. Should work.
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=2, unroll_kern=3,
N_image_shape=(7,2,3,3), N_filter_shape=(3,2,2,2), should_raise=True)
self.validate((6,2,3,3), (3,2,2,2), 'valid', unroll_batch=3, unroll_kern=3,
N_image_shape=(6,2,3,3), N_filter_shape=(4,2,2,2), should_raise=True)
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=2,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((2,3,3,3), (6,3,2,2), 'valid', unroll_batch=2, unroll_kern=3,
N_image_shape=(2,3,3,3), N_filter_shape=(5,3,2,2), should_raise=True)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=2, unroll_kern=3,
N_image_shape=(7, 2, 3, 3), N_filter_shape=(3, 2, 2, 2),
should_raise=True)
self.validate((6, 2, 3, 3), (3, 2, 2, 2), 'valid',
unroll_batch=3, unroll_kern=3,
N_image_shape=(6, 2, 3, 3), N_filter_shape=(4, 2, 2, 2),
should_raise=True)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
unroll_batch=2, unroll_kern=2,
N_image_shape=(1, 3, 3, 3), N_filter_shape=(6, 3, 2, 2),
should_raise=True)
self.validate((2, 3, 3, 3), (6, 3, 2, 2), 'valid',
unroll_batch=2, unroll_kern=3,
N_image_shape=(2, 3, 3, 3), N_filter_shape=(5, 3, 2, 2),
should_raise=True)
def test_subsample(self):
"""
Tests convolution where subsampling != (1,1)
"""
self.validate((3,2,7,5), (5,2,2,3), 'valid', subsample=(2,2))
self.validate((3,2,7,5), (5,2,2,3), 'full', subsample=(2,2))
self.validate((3,2,7,5), (5,2,2,3), 'valid', subsample=(2,1))
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 2))
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'full', subsample=(2, 2))
self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 1))
# Fails as of 2012-04-12
self.assertRaises(NotImplementedError, self.validate, (1, 1, 6, 6),
(1, 1, 3, 3), 'valid', subsample=(3, 3))
def test_shape_Constant_tensor(self):
"""
Tests convolution where the {image,filter}_shape is a Constant tensor.
"""
as_t=T.as_tensor_variable
self.validate((as_t(3),as_t(2),as_t(7),as_t(5)), (5,2,2,3), 'valid')
self.validate(as_t([3,2,7,5]), (5,2,2,3), 'valid')
self.validate(as_t((3,2,7,5)), (5,2,2,3), 'valid')
self.validate((3,2,7,5), (as_t(5),as_t(2),as_t(2),as_t(3)), 'valid')
self.validate((3,2,7,5), as_t([5,2,2,3]), 'valid')
self.validate((3,2,7,5), as_t((5,2,2,3)), 'valid')
self.validate(as_t([3,2,7,5]), as_t([5,2,2,3]), 'full')
as_t = T.as_tensor_variable
self.validate((as_t(3), as_t(2), as_t(7), as_t(5)), (5, 2,
2, 3), 'valid')
self.validate(as_t([3, 2, 7, 5]), (5, 2, 2, 3), 'valid')
self.validate(as_t((3, 2, 7, 5)), (5, 2, 2, 3), 'valid')
self.validate((3, 2, 7, 5), (as_t(5), as_t(2), as_t(2),
as_t(3)), 'valid')
self.validate((3, 2, 7, 5), as_t([5, 2, 2, 3]), 'valid')
self.validate((3, 2, 7, 5), as_t((5, 2, 2, 3)), 'valid')
self.validate(as_t([3, 2, 7, 5]), as_t([5, 2, 2, 3]), 'full')
def test_invalid_filter_shape(self):
"""
Tests scenario where filter_shape[1] != input_shape[1]
"""
self.assertRaises(AssertionError, self.validate, (3,2,8,8), (4,3,5,5),
self.assertRaises(AssertionError, self.validate,
(3, 2, 8, 8), (4, 3, 5, 5),
'valid')
def test_invalid_input_shape(self):
......@@ -292,23 +328,24 @@ class TestConv2D(unittest.TestCase):
Test convolutions for various pieces of missing info.
"""
self.validate(None, None,
N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5))
self.validate((3,2,None,None), None,
N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5))
self.validate((None,2,None,None), (None,2,5,5),
N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5))
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
self.validate((3, 2, None, None), None,
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
self.validate((None, 2, None, None), (None, 2, 5, 5),
N_image_shape=(3, 2, 8, 8),
N_filter_shape=(4, 2, 5, 5))
def test_full_mode(self):
"""
Tests basic convolution in full mode and case where filter
is larger than the input image.
"""
self.validate((3,2,5,5), (4,2,8,8), 'full')
self.validate((3, 2, 5, 5), (4, 2, 8, 8), 'full')
def f():
self.validate((3,2,5,5), (4,2,8,8), 'valid')
self.validate((3, 2, 5, 5), (4, 2, 8, 8), 'valid')
self.assertRaises(Exception, f)
def test_wrong_input(self):
......@@ -329,4 +366,5 @@ class TestConv2D(unittest.TestCase):
crashed in this following case. I changed the c code to don't hit
gcc bug. So it should not crash anymore
"""
self.validate((1,10,213,129), (46,10,212,1), 'valid', verify_grad=False)
self.validate((1, 10, 213, 129), (46, 10, 212, 1), 'valid',
verify_grad=False)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论