提交 80da6e4c authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for theano/gof/tests/test_sched.py

上级 1ace1445
from theano.gof.sched import (make_dependence_cmp, sort_apply_nodes, from theano.gof.sched import (make_dependence_cmp, sort_apply_nodes,
reverse_dict, _toposort, posort) reverse_dict, _toposort, posort)
import theano
from theano import tensor from theano import tensor
from theano.gof.graph import io_toposort from theano.gof.graph import io_toposort
from theano.compat import cmp from theano.compat import cmp
from six.moves import xrange
def test_dependence(): def test_dependence():
...@@ -22,7 +20,10 @@ def test_dependence(): ...@@ -22,7 +20,10 @@ def test_dependence():
def test_sort_apply_nodes(): def test_sort_apply_nodes():
x = tensor.matrix('x') x = tensor.matrix('x')
y = tensor.dot(x * 2, x + 1) y = tensor.dot(x * 2, x + 1)
str_cmp = lambda a, b: cmp(str(a), str(b)) # lexicographical sort
def str_cmp(a, b):
return cmp(str(a), str(b)) # lexicographical sort
nodes = sort_apply_nodes([x], [y], cmps=[str_cmp]) nodes = sort_apply_nodes([x], [y], cmps=[str_cmp])
for a, b in zip(nodes[:-1], nodes[1:]): for a, b in zip(nodes[:-1], nodes[1:]):
...@@ -39,10 +40,10 @@ def test_reverse_dict(): ...@@ -39,10 +40,10 @@ def test_reverse_dict():
def test__toposort(): def test__toposort():
edges = {1: set((4, 6, 7)), 2: set((4, 6, 7)), edges = {1: set((4, 6, 7)), 2: set((4, 6, 7)),
3: set((5, 7)), 4: set((6, 7)), 5: set((7,))} 3: set((5, 7)), 4: set((6, 7)), 5: set((7,))}
order = _toposort(edges) order = _toposort(edges)
assert not any(a in edges.get(b, ()) for i, a in enumerate(order) assert not any(a in edges.get(b, ()) for i, a in enumerate(order)
for b in order[i:]) for b in order[i:])
def test_posort_easy(): def test_posort_easy():
...@@ -64,5 +65,5 @@ def test_posort(): ...@@ -64,5 +65,5 @@ def test_posort():
cmps = [lambda a, b: a % 10 - b % 10, cmps = [lambda a, b: a % 10 - b % 10,
lambda a, b: (a / 10) % 2 - (b / 10) % 2, lambda a, b: (a / 10) % 2 - (b / 10) % 2,
lambda a, b: a - b] lambda a, b: a - b]
assert posort(l, *cmps) == \ assert (posort(l, *cmps) ==
[10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19] [10, 1, 11, 2, 12, 3, 13, 4, 14, 5, 15, 6, 16, 7, 17, 8, 18, 9, 19])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论