提交 5e11745f authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

test_var.py has been modified in order to respect the flake8 style.

上级 ca56cb80
...@@ -10,7 +10,7 @@ from theano.sandbox.cuda.var import float32_shared_constructor as f32sc ...@@ -10,7 +10,7 @@ from theano.sandbox.cuda.var import float32_shared_constructor as f32sc
from theano.sandbox.cuda import CudaNdarrayType, cuda_available from theano.sandbox.cuda import CudaNdarrayType, cuda_available
import theano.sandbox.cuda as cuda import theano.sandbox.cuda as cuda
# Skip test if cuda_ndarray is not available. # Skip test if cuda_ndarray is not available.
if cuda_available == False: if cuda_available is False:
raise SkipTest('Optional package cuda disabled') raise SkipTest('Optional package cuda disabled')
...@@ -26,20 +26,19 @@ def test_float32_shared_constructor(): ...@@ -26,20 +26,19 @@ def test_float32_shared_constructor():
# test that broadcastable arg is accepted, and that they # test that broadcastable arg is accepted, and that they
# don't strictly have to be tuples # don't strictly have to be tuples
assert eq( assert eq(f32sc(npy_row,
f32sc(npy_row, broadcastable=(True, False)).type, broadcastable=(True, False)).type,
CudaNdarrayType((True, False))) CudaNdarrayType((True, False)))
assert eq( assert eq(f32sc(npy_row,
f32sc(npy_row, broadcastable=[True, False]).type, broadcastable=[True, False]).type,
CudaNdarrayType((True, False))) CudaNdarrayType((True, False)))
assert eq( assert eq(f32sc(npy_row,
f32sc(npy_row, broadcastable=numpy.array([True, False])).type, broadcastable=numpy.array([True, False])).type,
CudaNdarrayType([True, False])) CudaNdarrayType([True, False]))
# test that we can make non-matrix shared vars # test that we can make non-matrix shared vars
assert eq( assert eq(f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type,
f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type, CudaNdarrayType((False,) * 4))
CudaNdarrayType((False,) * 4))
def test_givens(): def test_givens():
...@@ -72,13 +71,14 @@ class T_updates(unittest.TestCase): ...@@ -72,13 +71,14 @@ class T_updates(unittest.TestCase):
# This test case uses code mentionned in #698 # This test case uses code mentionned in #698
data = numpy.random.rand(10, 10).astype('float32') data = numpy.random.rand(10, 10).astype('float32')
output_var = f32sc(name="output", output_var = f32sc(name="output",
value=numpy.zeros((10, 10), 'float32')) value=numpy.zeros((10, 10), 'float32'))
x = tensor.fmatrix('x') x = tensor.fmatrix('x')
output_updates = [(output_var, x ** 2)] output_updates = [(output_var, x ** 2)]
output_givens = {x: data} output_givens = {x: data}
output_func = theano.function(inputs=[], outputs=[], output_func = theano.function(
updates=output_updates, givens=output_givens) inputs=[], outputs=[],
updates=output_updates, givens=output_givens)
output_func() output_func()
def test_err_ndim(self): def test_err_ndim(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论