提交 e57bfe5f authored 作者: Frederic Bastien's avatar Frederic Bastien

white space fix.

上级 49c2fb0b
......@@ -20,7 +20,7 @@ class TestConv2D(unittest.TestCase):
def validate(self, image_shape, filter_shape,
border_mode='valid', subsample=(1,1),
N_image_shape=None, N_filter_shape=None,
input=None, filters=None,
input=None, filters=None,
unroll_batch=None, unroll_kern=None, unroll_patch=None,
verify_grad=True, should_raise=False):
......@@ -28,14 +28,14 @@ class TestConv2D(unittest.TestCase):
N_image_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in image_shape]
if N_filter_shape is None:
N_filter_shape = [T.get_constant_value(T.as_tensor_variable(x)) for x in filter_shape]
if not input:
input = self.input
if not filters:
filters = self.filters
############# THEANO IMPLEMENTATION ############
# we create a symbolic function so that verify_grad can work
def sym_conv2d(input, filters):
# define theano graph and function
......@@ -45,7 +45,7 @@ class TestConv2D(unittest.TestCase):
output = sym_conv2d(input, filters)
theano_conv = theano.function([input, filters], output)
# initialize input and compute result
image_data = numpy.random.random(N_image_shape)
filter_data = numpy.random.random(N_filter_shape)
......@@ -140,11 +140,11 @@ class TestConv2D(unittest.TestCase):
"""
Test basic convs with True.
"""
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True,
self.validate((3,2,7,5), (5,2,2,3), 'valid', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True,
self.validate((3,2,7,5), (5,2,2,3), 'full', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True,
self.validate((3,2,3,3), (4,2,3,3), 'valid', unroll_patch=True,
N_image_shape=(1,3,3,3), N_filter_shape=(6,3,2,2), should_raise=True)
def test_unroll_special(self):
......@@ -228,19 +228,19 @@ class TestConv2D(unittest.TestCase):
"""
Test convolutions for various pieces of missing info.
"""
self.validate(None, None,
N_image_shape=(3,2,8,8),
self.validate(None, None,
N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5))
self.validate((3,2,None,None), None,
N_image_shape=(3,2,8,8),
N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5))
self.validate((None,2,None,None), (None,2,5,5),
N_image_shape=(3,2,8,8),
N_image_shape=(3,2,8,8),
N_filter_shape=(4,2,5,5))
def test_full_mode(self):
"""
Tests basic convolution in full mode and case where filter
Tests basic convolution in full mode and case where filter
is larger than the input image.
"""
self.validate((3,2,5,5), (4,2,8,8), 'full')
......@@ -256,26 +256,26 @@ class TestConv2D(unittest.TestCase):
self.validate((3,2,8,8), (4,2,5,5), 'valid', input = T.dmatrix())
# should never reach here
self.fail()
except:
except:
pass
try:
self.validate((3,2,8,8), (4,2,5,5), 'valid', filters = T.dvector())
# should never reach here
self.fail()
except:
except:
pass
try:
self.validate((3,2,8,8), (4,2,5,5), 'valid', input = T.dtensor3())
# should never reach here
self.fail()
except:
except:
pass
def test_gcc_crash(self):
"""
gcc 4.3.0 20080428 (Red Hat 4.3.0-8)
crashed in this following case. I changed the c code to don't hit
gcc bug. So it should not crash anymore
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论