提交 adaea20e authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for theano/gof/tests/test_graph.py

上级 3dc12504
...@@ -6,15 +6,14 @@ from itertools import count ...@@ -6,15 +6,14 @@ from itertools import count
from theano import ( from theano import (
clone, sparse, sparse,
shared, tensor) shared, tensor)
from theano.gof.graph import ( from theano.gof.graph import (
Node, Apply, Constant, Apply,
as_string, clone, general_toposort, inputs, io_toposort, as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable) is_same_graph, Variable)
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from theano.tensor.var import TensorVariable
from theano.sandbox.cuda.var import ( from theano.sandbox.cuda.var import (
CudaNdarrayVariable, CudaNdarrayConstant, CudaNdarraySharedVariable) CudaNdarrayVariable, CudaNdarrayConstant, CudaNdarraySharedVariable)
...@@ -85,9 +84,11 @@ class TestInputs: ...@@ -85,9 +84,11 @@ class TestInputs:
class X: class X:
leaf_formatter = lambda self, leaf: str(leaf.type) def leaf_formatter(self, leaf):
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op, return str(leaf.type)
", ".join(argstrings))
def node_formatter(self, node, argstrings):
return "%s(%s)" % (node.op, ", ".join(argstrings))
def str(self, inputs, outputs): def str(self, inputs, outputs):
return as_string(inputs, outputs, return as_string(inputs, outputs,
...@@ -117,7 +118,7 @@ class TestStr(X): ...@@ -117,7 +118,7 @@ class TestStr(X):
assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"] assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) r1, r2 = MyVariable(1), MyVariable(2)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert self.str(node.outputs, node2.outputs) == ["MyOp(R3, R3)"] assert self.str(node.outputs, node2.outputs) == ["MyOp(R3, R3)"]
...@@ -185,7 +186,7 @@ class TestToposort: ...@@ -185,7 +186,7 @@ class TestToposort:
def test_1(self): def test_1(self):
"""Test a graph with double dependencies""" """Test a graph with double dependencies"""
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
o2 = MyOp.make_node(o.outputs[0], r5) o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode) all = general_toposort(o2.outputs, prenode)
...@@ -193,7 +194,7 @@ class TestToposort: ...@@ -193,7 +194,7 @@ class TestToposort:
def test_2(self): def test_2(self):
"""Test a graph where the inputs have owners""" """Test a graph where the inputs have owners"""
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
r2b = o.outputs[0] r2b = o.outputs[0]
o2 = MyOp.make_node(r2b, r2b) o2 = MyOp.make_node(r2b, r2b)
...@@ -214,7 +215,7 @@ class TestToposort: ...@@ -214,7 +215,7 @@ class TestToposort:
def test_4(self): def test_4(self):
"""Test inputs and outputs mixed together in a chain graph""" """Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) r1, r2 = MyVariable(1), MyVariable(2)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r1) o1 = MyOp.make_node(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
...@@ -222,9 +223,9 @@ class TestToposort: ...@@ -222,9 +223,9 @@ class TestToposort:
def test_5(self): def test_5(self):
"""Test when outputs have clients""" """Test when outputs have clients"""
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r4) MyOp.make_node(o0.outputs[0], r4)
all = io_toposort([], o0.outputs) all = io_toposort([], o0.outputs)
assert all == [o0] assert all == [o0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论