提交 4d879ee9 authored 作者: Frederic's avatar Frederic

pep8

上级 1415a5e2
...@@ -16,12 +16,13 @@ from theano.sandbox.cuda import CudaNdarrayType, cuda_available ...@@ -16,12 +16,13 @@ from theano.sandbox.cuda import CudaNdarrayType, cuda_available
if cuda_available == False: if cuda_available == False:
raise SkipTest('Optional package cuda disabled') raise SkipTest('Optional package cuda disabled')
def test_float32_shared_constructor(): def test_float32_shared_constructor():
npy_row = numpy.zeros((1,10), dtype='float32') npy_row = numpy.zeros((1, 10), dtype='float32')
def eq(a,b): def eq(a, b):
return a==b return a == b
# test that we can create a CudaNdarray # test that we can create a CudaNdarray
assert (f32sc(npy_row).type == CudaNdarrayType((False, False))) assert (f32sc(npy_row).type == CudaNdarrayType((False, False)))
...@@ -40,37 +41,41 @@ def test_float32_shared_constructor(): ...@@ -40,37 +41,41 @@ def test_float32_shared_constructor():
# test that we can make non-matrix shared vars # test that we can make non-matrix shared vars
assert eq( assert eq(
f32sc(numpy.zeros((2,3,4,5), dtype='float32')).type, f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type,
CudaNdarrayType((False,)*4)) CudaNdarrayType((False,) * 4))
def test_givens(): def test_givens():
# Test that you can use a TensorType expression to replace a # Test that you can use a TensorType expression to replace a
# CudaNdarrayType in the givens dictionary. # CudaNdarrayType in the givens dictionary.
# This test case uses code mentionned in #757 # This test case uses code mentionned in #757
data = numpy.float32([1,2,3,4]) data = numpy.float32([1, 2, 3, 4])
x = f32sc(data) x = f32sc(data)
y = x**2 y = x ** 2
f = theano.function([], y, givens={x:x+1}) f = theano.function([], y, givens={x: x + 1})
f()
class T_updates(unittest.TestCase): class T_updates(unittest.TestCase):
# Test that you can use a TensorType expression to update a # Test that you can use a TensorType expression to update a
# CudaNdarrayType in the updates dictionary. # CudaNdarrayType in the updates dictionary.
def test_1(self): def test_1(self):
data = numpy.float32([1,2,3,4]) data = numpy.float32([1, 2, 3, 4])
x = f32sc(data) x = f32sc(data)
y = x**2 y = x ** 2
f = theano.function([], y, updates={x:x+1}) f = theano.function([], y, updates={x: x + 1})
f()
def test_2(self): def test_2(self):
# This test case uses code mentionned in #698 # This test case uses code mentionned in #698
data = numpy.random.rand(10,10).astype('float32') data = numpy.random.rand(10, 10).astype('float32')
output_var = f32sc(name="output", output_var = f32sc(name="output",
value=numpy.zeros((10,10), 'float32')) value=numpy.zeros((10, 10), 'float32'))
x = tensor.fmatrix('x') x = tensor.fmatrix('x')
output_updates = {output_var:x**2} output_updates = {output_var: x ** 2}
output_givens = {x:data} output_givens = {x: data}
output_func = theano.function(inputs=[], outputs=[], output_func = theano.function(inputs=[], outputs=[],
updates=output_updates, givens=output_givens) updates=output_updates, givens=output_givens)
output_func() output_func()
...@@ -78,16 +83,16 @@ class T_updates(unittest.TestCase): ...@@ -78,16 +83,16 @@ class T_updates(unittest.TestCase):
def test_3(self): def test_3(self):
# Test that broadcastable dimensions don't screw up # Test that broadcastable dimensions don't screw up
# update expressions. # update expressions.
data = numpy.random.rand(10,10).astype('float32') data = numpy.random.rand(10, 10).astype('float32')
output_var = f32sc(name="output", output_var = f32sc(name="output", value=data)
value=numpy.zeros((10,10), 'float32'))
# the update_var has type matrix, and the update expression # the update_var has type matrix, and the update expression
# is a broadcasted scalar, and that should be allowed. # is a broadcasted scalar, and that should be allowed.
output_func = theano.function(inputs=[], outputs=[], output_func = theano.function(inputs=[], outputs=[],
updates={output_var:output_var.sum().dimshuffle('x', 'x')}) updates={output_var: output_var.sum().dimshuffle('x', 'x')})
output_func() output_func()
class T_ifelse(unittest.TestCase): class T_ifelse(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
...@@ -111,10 +116,10 @@ class T_ifelse(unittest.TestCase): ...@@ -111,10 +116,10 @@ class T_ifelse(unittest.TestCase):
f = theano.function([cond], out1) f = theano.function([cond], out1)
g = theano.function([cond], out2) g = theano.function([cond], out2)
assert numpy.all(f(0) == data+1) assert numpy.all(f(0) == data + 1)
assert numpy.all(f(1) == data) assert numpy.all(f(1) == data)
assert numpy.all(g(0) == data) assert numpy.all(g(0) == data)
assert numpy.all(g(1) == data+1) assert numpy.all(g(1) == data + 1)
def test_dtype_mismatch(self): def test_dtype_mismatch(self):
data = self.rng.rand(5).astype('float32') data = self.rng.rand(5).astype('float32')
...@@ -135,7 +140,7 @@ class T_ifelse(unittest.TestCase): ...@@ -135,7 +140,7 @@ class T_ifelse(unittest.TestCase):
self.assertRaises(TypeError, ifelse, cond, y, x) self.assertRaises(TypeError, ifelse, cond, y, x)
def test_broadcast_mismatch(self): def test_broadcast_mismatch(self):
data = self.rng.rand(2,3).astype('float32') data = self.rng.rand(2, 3).astype('float32')
x = f32sc(data) x = f32sc(data)
print x.broadcastable print x.broadcastable
y = tensor.frow('y') y = tensor.frow('y')
...@@ -146,7 +151,7 @@ class T_ifelse(unittest.TestCase): ...@@ -146,7 +151,7 @@ class T_ifelse(unittest.TestCase):
self.assertRaises(TypeError, ifelse, cond, y, x) self.assertRaises(TypeError, ifelse, cond, y, x)
def test_sparse_tensor_error(self): def test_sparse_tensor_error(self):
data = self.rng.rand(2,3).astype('float32') data = self.rng.rand(2, 3).astype('float32')
x = f32sc(data) x = f32sc(data)
y = sparse.matrix('csc', dtype='float32', name='y') y = sparse.matrix('csc', dtype='float32', name='y')
z = sparse.matrix('csr', dtype='float32', name='z') z = sparse.matrix('csr', dtype='float32', name='z')
...@@ -160,5 +165,3 @@ class T_ifelse(unittest.TestCase): ...@@ -160,5 +165,3 @@ class T_ifelse(unittest.TestCase):
self.assertRaises((TypeError, ValueError), ifelse, cond, z, x) self.assertRaises((TypeError, ValueError), ifelse, cond, z, x)
self.assertRaises((TypeError, ValueError), ifelse, cond, y, z) self.assertRaises((TypeError, ValueError), ifelse, cond, y, z)
self.assertRaises((TypeError, ValueError), ifelse, cond, z, y) self.assertRaises((TypeError, ValueError), ifelse, cond, z, y)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论