提交 e57bfe5f authored 作者: Frederic Bastien's avatar Frederic Bastien

white space fix.

上级 49c2fb0b
...@@ -20,7 +20,7 @@ class TestConv2D(unittest.TestCase): ...@@ -20,7 +20,7 @@ class TestConv2D(unittest.TestCase):
def validate(self, image_shape, filter_shape, def validate(self, image_shape, filter_shape,
border_mode='valid', subsample=(1,1), border_mode='valid', subsample=(1,1),
N_image_shape=None, N_filter_shape=None, N_image_shape=None, N_filter_shape=None,
input=None, filters=None, input=None, filters=None,
unroll_batch=None, unroll_kern=None, unroll_patch=None, unroll_batch=None, unroll_kern=None, unroll_patch=None,
verify_grad=True, should_raise=False): verify_grad=True, should_raise=False):
...@@ -28,14 +28,14 @@ class TestConv2D(unittest.TestCase): ...@@ -28,14 +28,14 @@ class TestConv2D(unittest.TestCase):
N_image_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in image_shape] N_image_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None: if N_filter_shape is None:
N_filter_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in filter_shape] N_filter_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in filter_shape]
if not input: if not input:
input = self.input input = self.input
if not filters: if not filters:
filters = self.filters filters = self.filters
############# THEANO IMPLEMENTATION ############ ############# THEANO IMPLEMENTATION ############
# we create a symbolic function so that verify_grad can work # we create a symbolic function so that verify_grad can work
def sym_conv2d(input, filters): def sym_conv2d(input, filters):
# define theano graph and function # define theano graph and function
...@@ -45,7 +45,7 @@ class TestConv2D(unittest.TestCase): ...@@ -45,7 +45,7 @@ class TestConv2D(unittest.TestCase):
output = sym_conv2d(input, filters) output = sym_conv2d(input, filters)
theano_conv = theano.function([input, filters], output) theano_conv = theano.function([input, filters], output)
# initialize input and compute result # initialize input and compute result
image_data = numpy.random.random(N_image_shape) image_data = numpy.random.random(N_image_shape)
filter_data = numpy.random.random(N_filter_shape) filter_data = numpy.random.random(N_filter_shape)
...@@ -140,11 +140,11 @@ class TestConv2D(unittest.TestCase): ...@@ -140,11 +140,11 @@ class TestConv2D(unittest.TestCase):
""" """
Test basic convs with True. Test basic convs with True.
""" """
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True, self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True) N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True, self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True) N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True, self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True) N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
def test_unroll_special(self): def test_unroll_special(self):
...@@ -228,19 +228,19 @@ class TestConv2D(unittest.TestCase): ...@@ -228,19 +228,19 @@ class TestConv2D(unittest.TestCase):
""" """
Test convolutions for various pieces of missing info. Test convolutions for various pieces of missing info.
""" """
self.validate(None, None, self.validate(None, None,
N_image_shape=(3,2,8,8), N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5)) N_filter_shape=(4,2,5,5))
self.validate((3,2,None,None), None, self.validate((3,2,None,None), None,
N_image_shape=(3,2,8,8), N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5)) N_filter_shape=(4,2,5,5))
self.validate((None,2,None,None), (None,2,5,5), self.validate((None,2,None,None), (None,2,5,5),
N_image_shape=(3,2,8,8), N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5)) N_filter_shape=(4,2,5,5))
def test_full_mode(self): def test_full_mode(self):
""" """
Tests basic convolution in full mode and case where filter Tests basic convolution in full mode and case where filter
is larger than the input image. is larger than the input image.
""" """
self.validate((3,2,5,5), (4,2,8,8), 'full') self.validate((3,2,5,5), (4,2,8,8), 'full')
...@@ -256,26 +256,26 @@ class TestConv2D(unittest.TestCase): ...@@ -256,26 +256,26 @@ class TestConv2D(unittest.TestCase):
self.validate((3,2,8,8), (4,2,5,5), 'valid', input = T.dmatrix()) self.validate((3,2,8,8), (4,2,5,5), 'valid', input = T.dmatrix())
# should never reach here # should never reach here
self.fail() self.fail()
except: except:
pass pass
try: try:
self.validate((3,2,8,8), (4,2,5,5), 'valid', filters = T.dvector()) self.validate((3,2,8,8), (4,2,5,5), 'valid', filters = T.dvector())
# should never reach here # should never reach here
self.fail() self.fail()
except: except:
pass pass
try: try:
self.validate((3,2,8,8), (4,2,5,5), 'valid', input = T.dtensor3()) self.validate((3,2,8,8), (4,2,5,5), 'valid', input = T.dtensor3())
# should never reach here # should never reach here
self.fail() self.fail()
except: except:
pass pass
def test_gcc_crash(self): def test_gcc_crash(self):
""" """
gcc 4.3.0 20080428 (Red Hat 4.3.0-8) gcc 4.3.0 20080428 (Red Hat 4.3.0-8)
crashed in this following case. I changed the c code to don't hit crashed in this following case. I changed the c code to don't hit
gcc bug. So it should not crash anymore gcc bug. So it should not crash anymore
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论