提交 44a54a08 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8

上级 2ea2bc86
...@@ -5110,7 +5110,7 @@ class test_arithmetic_cast(unittest.TestCase): ...@@ -5110,7 +5110,7 @@ class test_arithmetic_cast(unittest.TestCase):
warnings.filterwarnings('ignore', message='Division of two integer', warnings.filterwarnings('ignore', message='Division of two integer',
category=DeprecationWarning) category=DeprecationWarning)
try: try:
for cfg in ('numpy+floatX', ): # Used to test 'numpy' as well. for cfg in ('numpy+floatX', ): # Used to test 'numpy' as well.
config.cast_policy = cfg config.cast_policy = cfg
for op in (operator.add, operator.sub, operator.mul, for op in (operator.add, operator.sub, operator.mul,
operator.div, operator.floordiv): operator.div, operator.floordiv):
...@@ -5237,7 +5237,7 @@ class test_broadcast(unittest.TestCase): ...@@ -5237,7 +5237,7 @@ class test_broadcast(unittest.TestCase):
def test_broadcast_bigdim(self): def test_broadcast_bigdim(self):
def f(): def f():
x = matrix() x = matrix()
addbroadcast(x,2) addbroadcast(x, 2)
self.assertRaises(ValueError, f) self.assertRaises(ValueError, f)
def test_unbroadcast_addbroadcast(self): def test_unbroadcast_addbroadcast(self):
...@@ -5246,41 +5246,41 @@ class test_broadcast(unittest.TestCase): ...@@ -5246,41 +5246,41 @@ class test_broadcast(unittest.TestCase):
and fuse consecutive Rebroadcast op and fuse consecutive Rebroadcast op
""" """
x=matrix() x = matrix()
assert unbroadcast(x,0) is x assert unbroadcast(x, 0) is x
assert unbroadcast(x,1) is x assert unbroadcast(x, 1) is x
assert unbroadcast(x,1,0) is x assert unbroadcast(x, 1, 0) is x
assert unbroadcast(x,0,1) is x assert unbroadcast(x, 0, 1) is x
assert addbroadcast(x,0) is not x assert addbroadcast(x, 0) is not x
assert addbroadcast(x,1) is not x assert addbroadcast(x, 1) is not x
assert addbroadcast(x,1,0).owner.inputs[0] is x assert addbroadcast(x, 1, 0).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,0),0) is x assert unbroadcast(addbroadcast(x, 0), 0) is x
assert addbroadcast(unbroadcast(x,0),0) is not x assert addbroadcast(unbroadcast(x, 0), 0) is not x
x=row() x = row()
assert unbroadcast(x,0) is not x assert unbroadcast(x, 0) is not x
assert unbroadcast(x,1) is x assert unbroadcast(x, 1) is x
assert unbroadcast(x,1,0) is not x assert unbroadcast(x, 1, 0) is not x
assert unbroadcast(x,0,1) is not x assert unbroadcast(x, 0, 1) is not x
assert addbroadcast(x,0) is x assert addbroadcast(x, 0) is x
assert addbroadcast(x,1).owner.inputs[0] is x assert addbroadcast(x, 1).owner.inputs[0] is x
assert addbroadcast(x,1,0).owner.inputs[0] is x assert addbroadcast(x, 1, 0).owner.inputs[0] is x
assert addbroadcast(x,0,1).owner.inputs[0] is x assert addbroadcast(x, 0, 1).owner.inputs[0] is x
assert unbroadcast(addbroadcast(x,1),1) is x assert unbroadcast(addbroadcast(x, 1), 1) is x
assert addbroadcast(unbroadcast(x,1),1) is not x assert addbroadcast(unbroadcast(x, 1), 1) is not x
# The first broadcast is remove the broadcast, so the second # The first broadcast is remove the broadcast, so the second
# should not make one # should not make one
assert unbroadcast(unbroadcast(x,0),0).owner.inputs[0] is x assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x
# Test that consecutive Rebroadcast op are fused # Test that consecutive Rebroadcast op are fused
x=TensorType(dtype = 'float64', broadcastable = (True,True))() x = TensorType(dtype='float64', broadcastable=(True, True))()
assert unbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x assert unbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,1),0).owner.inputs[0] is x assert addbroadcast(unbroadcast(x, 1), 0).owner.inputs[0] is x
assert addbroadcast(unbroadcast(x,0),0) is x assert addbroadcast(unbroadcast(x, 0), 0) is x
def test_patternbroadcast(self): def test_patternbroadcast(self):
# Test that patternbroadcast with an empty broadcasting pattern works # Test that patternbroadcast with an empty broadcasting pattern works
...@@ -5295,7 +5295,7 @@ class test_broadcast(unittest.TestCase): ...@@ -5295,7 +5295,7 @@ class test_broadcast(unittest.TestCase):
x = matrix() x = matrix()
y = addbroadcast(x, 0) y = addbroadcast(x, 0)
f = theano.function([x], y.shape) f = theano.function([x], y.shape)
assert (f(numpy.zeros((1,5), dtype=config.floatX)) == [1,5]).all() assert (f(numpy.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE': if theano.config.mode != 'FAST_COMPILE':
assert len(topo) == 2 assert len(topo) == 2
...@@ -5305,7 +5305,7 @@ class test_broadcast(unittest.TestCase): ...@@ -5305,7 +5305,7 @@ class test_broadcast(unittest.TestCase):
x = matrix() x = matrix()
y = unbroadcast(x, 0) y = unbroadcast(x, 0)
f = theano.function([x], y.shape) f = theano.function([x], y.shape)
assert (f(numpy.zeros((2,5), dtype=config.floatX)) == [2,5]).all() assert (f(numpy.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE': if theano.config.mode != 'FAST_COMPILE':
assert len(topo) == 3 assert len(topo) == 3
...@@ -5316,7 +5316,7 @@ class test_broadcast(unittest.TestCase): ...@@ -5316,7 +5316,7 @@ class test_broadcast(unittest.TestCase):
x = row() x = row()
y = unbroadcast(x, 0) y = unbroadcast(x, 0)
f = theano.function([x], y.shape) f = theano.function([x], y.shape)
assert (f(numpy.zeros((1,5), dtype=config.floatX)) == [1,5]).all() assert (f(numpy.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if theano.config.mode != 'FAST_COMPILE': if theano.config.mode != 'FAST_COMPILE':
assert len(topo) == 2 assert len(topo) == 2
...@@ -5588,7 +5588,7 @@ class test_sort(unittest.TestCase): ...@@ -5588,7 +5588,7 @@ class test_sort(unittest.TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(seed=utt.fetch_seed()) self.rng = numpy.random.RandomState(seed=utt.fetch_seed())
self.m_val = self.rng.rand(3,2) self.m_val = self.rng.rand(3, 2)
self.v_val = self.rng.rand(4) self.v_val = self.rng.rand(4)
def test1(self): def test1(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论