提交 5e11745f authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

test_var.py has been modified in order to respect the flake8 style.

上级 ca56cb80
......@@ -10,7 +10,7 @@ from theano.sandbox.cuda.var import float32_shared_constructor as f32sc
from theano.sandbox.cuda import CudaNdarrayType, cuda_available
import theano.sandbox.cuda as cuda
# Skip test if cuda_ndarray is not available.
if cuda_available == False:
if cuda_available is False:
raise SkipTest('Optional package cuda disabled')
......@@ -26,20 +26,19 @@ def test_float32_shared_constructor():
# test that broadcastable arg is accepted, and that they
# don't strictly have to be tuples
assert eq(
f32sc(npy_row, broadcastable=(True, False)).type,
CudaNdarrayType((True, False)))
assert eq(
f32sc(npy_row, broadcastable=[True, False]).type,
CudaNdarrayType((True, False)))
assert eq(
f32sc(npy_row, broadcastable=numpy.array([True, False])).type,
CudaNdarrayType([True, False]))
assert eq(f32sc(npy_row,
broadcastable=(True, False)).type,
CudaNdarrayType((True, False)))
assert eq(f32sc(npy_row,
broadcastable=[True, False]).type,
CudaNdarrayType((True, False)))
assert eq(f32sc(npy_row,
broadcastable=numpy.array([True, False])).type,
CudaNdarrayType([True, False]))
# test that we can make non-matrix shared vars
assert eq(
f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type,
CudaNdarrayType((False,) * 4))
assert eq(f32sc(numpy.zeros((2, 3, 4, 5), dtype='float32')).type,
CudaNdarrayType((False,) * 4))
def test_givens():
......@@ -72,13 +71,14 @@ class T_updates(unittest.TestCase):
# This test case uses code mentionned in #698
data = numpy.random.rand(10, 10).astype('float32')
output_var = f32sc(name="output",
value=numpy.zeros((10, 10), 'float32'))
value=numpy.zeros((10, 10), 'float32'))
x = tensor.fmatrix('x')
output_updates = [(output_var, x ** 2)]
output_givens = {x: data}
output_func = theano.function(inputs=[], outputs=[],
updates=output_updates, givens=output_givens)
output_func = theano.function(
inputs=[], outputs=[],
updates=output_updates, givens=output_givens)
output_func()
def test_err_ndim(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论