提交 adaea20e authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 for theano/gof/tests/test_graph.py

上级 3dc12504
...@@ -6,17 +6,16 @@ from itertools import count ...@@ -6,17 +6,16 @@ from itertools import count
from theano import ( from theano import (
clone, sparse, sparse,
shared, tensor) shared, tensor)
from theano.gof.graph import ( from theano.gof.graph import (
Node, Apply, Constant, Apply,
as_string, clone, general_toposort, inputs, io_toposort, as_string, clone, general_toposort, inputs, io_toposort,
is_same_graph, Variable) is_same_graph, Variable)
from theano.gof.op import Op from theano.gof.op import Op
from theano.gof.type import Type from theano.gof.type import Type
from theano.tensor.var import TensorVariable
from theano.sandbox.cuda.var import ( from theano.sandbox.cuda.var import (
CudaNdarrayVariable, CudaNdarrayConstant, CudaNdarraySharedVariable) CudaNdarrayVariable, CudaNdarrayConstant, CudaNdarraySharedVariable)
def as_variable(x): def as_variable(x):
...@@ -46,7 +45,7 @@ def MyVariable(thingy): ...@@ -46,7 +45,7 @@ def MyVariable(thingy):
class MyOp(Op): class MyOp(Op):
__props__ = () __props__ = ()
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = list(map(as_variable, inputs)) inputs = list(map(as_variable, inputs))
for input in inputs: for input in inputs:
...@@ -85,9 +84,11 @@ class TestInputs: ...@@ -85,9 +84,11 @@ class TestInputs:
class X: class X:
leaf_formatter = lambda self, leaf: str(leaf.type) def leaf_formatter(self, leaf):
node_formatter = lambda self, node, argstrings: "%s(%s)" % (node.op, return str(leaf.type)
", ".join(argstrings))
def node_formatter(self, node, argstrings):
return "%s(%s)" % (node.op, ", ".join(argstrings))
def str(self, inputs, outputs): def str(self, inputs, outputs):
return as_string(inputs, outputs, return as_string(inputs, outputs,
...@@ -117,7 +118,7 @@ class TestStr(X): ...@@ -117,7 +118,7 @@ class TestStr(X):
assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"] assert self.str([r1, r2, r5], node2.outputs) == ["MyOp(*1 -> MyOp(R1, R2), *1)"]
def test_cutoff(self): def test_cutoff(self):
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) r1, r2 = MyVariable(1), MyVariable(2)
node = MyOp.make_node(r1, r2) node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], node.outputs[0]) node2 = MyOp.make_node(node.outputs[0], node.outputs[0])
assert self.str(node.outputs, node2.outputs) == ["MyOp(R3, R3)"] assert self.str(node.outputs, node2.outputs) == ["MyOp(R3, R3)"]
...@@ -185,7 +186,7 @@ class TestToposort: ...@@ -185,7 +186,7 @@ class TestToposort:
def test_1(self): def test_1(self):
"""Test a graph with double dependencies""" """Test a graph with double dependencies"""
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
o2 = MyOp.make_node(o.outputs[0], r5) o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode) all = general_toposort(o2.outputs, prenode)
...@@ -193,7 +194,7 @@ class TestToposort: ...@@ -193,7 +194,7 @@ class TestToposort:
def test_2(self): def test_2(self):
"""Test a graph where the inputs have owners""" """Test a graph where the inputs have owners"""
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5) r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1) o = MyOp.make_node(r1, r1)
r2b = o.outputs[0] r2b = o.outputs[0]
o2 = MyOp.make_node(r2b, r2b) o2 = MyOp.make_node(r2b, r2b)
...@@ -214,7 +215,7 @@ class TestToposort: ...@@ -214,7 +215,7 @@ class TestToposort:
def test_4(self): def test_4(self):
"""Test inputs and outputs mixed together in a chain graph""" """Test inputs and outputs mixed together in a chain graph"""
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) r1, r2 = MyVariable(1), MyVariable(2)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r1) o1 = MyOp.make_node(o0.outputs[0], r1)
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]]) all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
...@@ -222,9 +223,9 @@ class TestToposort: ...@@ -222,9 +223,9 @@ class TestToposort:
def test_5(self): def test_5(self):
"""Test when outputs have clients""" """Test when outputs have clients"""
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4) r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4)
o0 = MyOp.make_node(r1, r2) o0 = MyOp.make_node(r1, r2)
o1 = MyOp.make_node(o0.outputs[0], r4) MyOp.make_node(o0.outputs[0], r4)
all = io_toposort([], o0.outputs) all = io_toposort([], o0.outputs)
assert all == [o0] assert all == [o0]
...@@ -264,11 +265,11 @@ class TestIsSameGraph(unittest.TestCase): ...@@ -264,11 +265,11 @@ class TestIsSameGraph(unittest.TestCase):
""" """
x, y, z = tensor.vectors('x', 'y', 'z') x, y, z = tensor.vectors('x', 'y', 'z')
self.check([ self.check([
(x, x, (({}, True), )), (x, x, (({}, True), )),
(x, y, (({}, False), ({y: x}, True), )), (x, y, (({}, False), ({y: x}, True), )),
(x, tensor.neg(x), (({}, False), )), (x, tensor.neg(x), (({}, False), )),
(x, tensor.neg(y), (({}, False), )), (x, tensor.neg(y), (({}, False), )),
]) ])
def test_full_graph(self): def test_full_graph(self):
""" """
...@@ -277,14 +278,14 @@ class TestIsSameGraph(unittest.TestCase): ...@@ -277,14 +278,14 @@ class TestIsSameGraph(unittest.TestCase):
x, y, z = tensor.vectors('x', 'y', 'z') x, y, z = tensor.vectors('x', 'y', 'z')
t = x * y t = x * y
self.check([ self.check([
(x * 2, x * 2, (({}, True), )), (x * 2, x * 2, (({}, True), )),
(x * 2, y * 2, (({}, False), ({y: x}, True), )), (x * 2, y * 2, (({}, False), ({y: x}, True), )),
(x * 2, y * 2, (({}, False), ({x: y}, True), )), (x * 2, y * 2, (({}, False), ({x: y}, True), )),
(x * 2, y * 3, (({}, False), ({y: x}, False), )), (x * 2, y * 3, (({}, False), ({y: x}, False), )),
(t * 2, z * 2, (({}, False), ({t: z}, True), )), (t * 2, z * 2, (({}, False), ({t: z}, True), )),
(t * 2, z * 2, (({}, False), ({z: t}, True), )), (t * 2, z * 2, (({}, False), ({z: t}, True), )),
(x * (y * z), (x * y) * z, (({}, False), )), (x * (y * z), (x * y) * z, (({}, False), )),
]) ])
def test_merge_only(self): def test_merge_only(self):
""" """
...@@ -293,15 +294,15 @@ class TestIsSameGraph(unittest.TestCase): ...@@ -293,15 +294,15 @@ class TestIsSameGraph(unittest.TestCase):
x, y, z = tensor.vectors('x', 'y', 'z') x, y, z = tensor.vectors('x', 'y', 'z')
t = x * y t = x * y
self.check([ self.check([
(x, t, (({}, False), ({t: x}, True))), (x, t, (({}, False), ({t: x}, True))),
(t * 2, x * 2, (({}, False), ({t: x}, True), )), (t * 2, x * 2, (({}, False), ({t: x}, True), )),
(x * x, x * y, (({}, False), ({y: x}, True), )), (x * x, x * y, (({}, False), ({y: x}, True), )),
(x * x, x * y, (({}, False), ({y: x}, True), )), (x * x, x * y, (({}, False), ({y: x}, True), )),
(x * x + z, x * y + t, (({}, False), (x * x + z, x * y + t, (({}, False),
({y: x}, False), ({y: x}, False),
({y: x, t: z}, True))), ({y: x, t: z}, True))),
], ],
debug=False) debug=False)
################ ################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论