提交 4d879ee9 authored 作者: Frederic's avatar Frederic

pep8

上级 1415a5e2
......@@ -16,12 +16,13 @@ from theano.sandbox.cuda import CudaNdarrayType, cuda_available
if cuda_available == False:
raise SkipTest('Optional package cuda disabled')
def test_float32_shared_constructor():
npy_row = numpy.zeros((1,10), dtype='float32')
npy_row = numpy.zeros((1, 10), dtype='float32')
def eq(a,b):
return a==b
def eq(a, b):
return a == b
# test that we can create a CudaNdarray
assert (f32sc(npy_row).type == CudaNdarrayType((False, False)))
......@@ -40,37 +41,41 @@ def test_float32_shared_constructor():
# test that we can make non-matrix shared vars
assert eq(
f32sc(numpy.zeros((2,3,4,5), dtype='float32')).type,
CudaNdarrayType((False,)*4))
f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type,
CudaNdarrayType((False,) * 4))
def test_givens():
# Test that you can use a TensorType expression to replace a
# CudaNdarrayType in the givens dictionary.
# This test case uses code mentionned in #757
data = numpy.float32([1,2,3,4])
data = numpy.float32([1, 2, 3, 4])
x = f32sc(data)
y = x**2
f = theano.function([], y, givens={x:x+1})
y = x ** 2
f = theano.function([], y, givens={x: x + 1})
f()
class T_updates(unittest.TestCase):
# Test that you can use a TensorType expression to update a
# CudaNdarrayType in the updates dictionary.
def test_1(self):
data = numpy.float32([1,2,3,4])
data = numpy.float32([1, 2, 3, 4])
x = f32sc(data)
y = x**2
f = theano.function([], y, updates={x:x+1})
y = x ** 2
f = theano.function([], y, updates={x: x + 1})
f()
def test_2(self):
# This test case uses code mentionned in #698
data = numpy.random.rand(10,10).astype('float32')
data = numpy.random.rand(10, 10).astype('float32')
output_var = f32sc(name="output",
value=numpy.zeros((10,10), 'float32'))
value=numpy.zeros((10, 10), 'float32'))
x = tensor.fmatrix('x')
output_updates = {output_var:x**2}
output_givens = {x:data}
output_updates = {output_var: x ** 2}
output_givens = {x: data}
output_func = theano.function(inputs=[], outputs=[],
updates=output_updates, givens=output_givens)
output_func()
......@@ -78,16 +83,16 @@ class T_updates(unittest.TestCase):
def test_3(self):
# Test that broadcastable dimensions don't screw up
# update expressions.
data = numpy.random.rand(10,10).astype('float32')
output_var = f32sc(name="output",
value=numpy.zeros((10,10), 'float32'))
data = numpy.random.rand(10, 10).astype('float32')
output_var = f32sc(name="output", value=data)
# the update_var has type matrix, and the update expression
# is a broadcasted scalar, and that should be allowed.
output_func = theano.function(inputs=[], outputs=[],
updates={output_var:output_var.sum().dimshuffle('x', 'x')})
updates={output_var: output_var.sum().dimshuffle('x', 'x')})
output_func()
class T_ifelse(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......@@ -111,10 +116,10 @@ class T_ifelse(unittest.TestCase):
f = theano.function([cond], out1)
g = theano.function([cond], out2)
assert numpy.all(f(0) == data+1)
assert numpy.all(f(0) == data + 1)
assert numpy.all(f(1) == data)
assert numpy.all(g(0) == data)
assert numpy.all(g(1) == data+1)
assert numpy.all(g(1) == data + 1)
def test_dtype_mismatch(self):
data = self.rng.rand(5).astype('float32')
......@@ -135,7 +140,7 @@ class T_ifelse(unittest.TestCase):
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_broadcast_mismatch(self):
data = self.rng.rand(2,3).astype('float32')
data = self.rng.rand(2, 3).astype('float32')
x = f32sc(data)
print x.broadcastable
y = tensor.frow('y')
......@@ -146,7 +151,7 @@ class T_ifelse(unittest.TestCase):
self.assertRaises(TypeError, ifelse, cond, y, x)
def test_sparse_tensor_error(self):
data = self.rng.rand(2,3).astype('float32')
data = self.rng.rand(2, 3).astype('float32')
x = f32sc(data)
y = sparse.matrix('csc', dtype='float32', name='y')
z = sparse.matrix('csr', dtype='float32', name='z')
......@@ -160,5 +165,3 @@ class T_ifelse(unittest.TestCase):
self.assertRaises((TypeError, ValueError), ifelse, cond, z, x)
self.assertRaises((TypeError, ValueError), ifelse, cond, y, z)
self.assertRaises((TypeError, ValueError), ifelse, cond, z, y)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论