提交 9abca4b8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert theano.gof.graph graph walking functions to generators

上级 88dfd88f
......@@ -8,22 +8,24 @@ from theano import shared, tensor
from theano.gof.graph import (
Apply,
Variable,
ancestors,
as_string,
clone,
equal_computations,
general_toposort,
inputs,
io_toposort,
is_in_ancestors,
list_of_nodes,
ops,
orphans,
stack_search,
variables,
)
from theano.gof.op import Op
from theano.gof.type import Type
def as_variable(x):
assert isinstance(x, Variable)
return x
class MyType(Type):
def __init__(self, thingy):
self.thingy = thingy
......@@ -47,32 +49,16 @@ class MyOp(Op):
__props__ = ()
def make_node(self, *inputs):
inputs = list(map(as_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
print(input, input.type, type(input), type(input.type))
raise Exception("Error 1")
outputs = [MyVariable(sum([input.type.thingy for input in inputs]))]
return Apply(self, inputs, outputs)
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
return Apply(self, list(inputs), outputs)
MyOp = MyOp()
class TestInputs:
def test_inputs(self):
r1, r2 = MyVariable(1), MyVariable(2)
node = MyOp.make_node(r1, r2)
assert inputs(node.outputs) == [r1, r2]
def test_inputs_deep(self):
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5)
node = MyOp.make_node(r1, r2)
node2 = MyOp.make_node(node.outputs[0], r5)
i = inputs(node2.outputs)
assert i == [r1, r2, r5], i
class X:
def leaf_formatter(self, leaf):
return str(leaf.type)
......@@ -145,7 +131,7 @@ class TestClone(X):
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8)
new_node.inputs = [MyVariable(7), MyVariable(8)]
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == [
"MyOp(MyOp(R1, R2), R5)"
......@@ -156,7 +142,7 @@ class TestClone(X):
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8)
new_node.inputs = [MyVariable(7), MyVariable(8)]
c1 = tensor.constant(1.5)
i, o = clone([c1], [c1])
......@@ -181,19 +167,36 @@ def prenode(obj):
class TestToposort:
def test_0(self):
def test_simple(self):
# Test a simple graph
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5)
o = MyOp.make_node(r1, r2)
o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
assert all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]]
all = io_toposort([r5], o2.outputs)
assert all == [o, o2]
def test_1(self):
o = MyOp(r1, r2)
o.name = "o1"
o2 = MyOp(o, r5)
o2.name = "o2"
clients = {}
res = general_toposort([o2], prenode, clients=clients)
assert clients == {
o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError):
general_toposort(
[o2], prenode, compute_deps_cache=lambda x: None, deps_cache=None
)
res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner]
def test_double_dependencies(self):
# Test a graph with double dependencies
r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1)
......@@ -201,7 +204,7 @@ class TestToposort:
all = general_toposort(o2.outputs, prenode)
assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]]
def test_2(self):
def test_inputs_owners(self):
# Test a graph where the inputs have owners
r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1)
......@@ -214,7 +217,7 @@ class TestToposort:
all = io_toposort([r2b], o2.outputs)
assert all == [o2]
def test_3(self):
def test_not_connected(self):
# Test a graph which is not connected
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4)
o0 = MyOp.make_node(r1, r2)
......@@ -222,7 +225,7 @@ class TestToposort:
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
assert all == [o1, o0] or all == [o0, o1]
def test_4(self):
def test_io_chain(self):
# Test inputs and outputs mixed together in a chain graph
r1, r2 = MyVariable(1), MyVariable(2)
o0 = MyOp.make_node(r1, r2)
......@@ -230,7 +233,7 @@ class TestToposort:
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
assert all == [o1]
def test_5(self):
def test_outputs_clients(self):
# Test when outputs have clients
r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4)
o0 = MyOp.make_node(r1, r2)
......@@ -326,3 +329,134 @@ def test_equal_computations():
max_argmax1 = tensor.max_and_argmax(m)
max_argmax2 = tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2)
def test_stack_search():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
def expand(r):
if r.owner:
return r.owner.inputs
res = stack_search([o2], expand, bfs=True, return_children=False)
res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2]
res = stack_search([o2], expand, bfs=False, return_children=False)
res_list = list(res)
assert res_list == [o2, o1, r2, r1, r3]
res = stack_search([o2], expand, bfs=True, return_children=True)
res_list = list(res)
assert res_list == [
(o2, [r3, o1]),
(r3, None),
(o1, [r1, r2]),
(r1, None),
(r2, None),
]
def test_ancestors():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
res = ancestors([o2], blockers=None)
res_list = list(res)
assert res_list == [o2, r3, o1, r1, r2]
res = ancestors([o2], blockers=None)
assert r3 in res
res_list = list(res)
assert res_list == [o1, r1, r2]
res = ancestors([o2], blockers=[o1])
res_list = list(res)
assert res_list == [o2, r3, o1]
def test_inputs():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
res = inputs([o2], blockers=None)
res_list = list(res)
assert res_list == [r3, r1, r2]
def test_variables_and_orphans():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
vars_res = variables([r1, r2], [o2])
orphans_res = orphans([r1, r2], [o2])
vars_res_list = list(vars_res)
orphans_res_list = list(orphans_res)
assert vars_res_list == [o2, o1, r3, r2, r1]
assert orphans_res_list == [r3]
def test_ops():
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, r4)
o2.name = "o2"
o3 = MyOp(r3, o1, o2)
o3.name = "o3"
res = ops([r1, r2], [o3])
res_list = list(res)
assert res_list == [o3.owner, o2.owner, o1.owner]
def test_list_of_nodes():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
res = list_of_nodes([r1, r2], [o2])
assert res == [o2.owner, o1.owner]
def test_is_in_ancestors():
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
o1 = MyOp(r1, r2)
o1.name = "o1"
o2 = MyOp(r3, o1)
o2.name = "o2"
assert is_in_ancestors(o2.owner, o1.owner)
@pytest.mark.xfail(reason="Not implemented")
def test_io_connection_pattern():
raise AssertionError()
@pytest.mark.xfail(reason="Not implemented")
def test_view_roots():
raise AssertionError()
......@@ -1336,7 +1336,6 @@ def test_grad_useless_sum():
x = TensorType(theano.config.floatX, (True,))("x")
l = tt.log(1.0 - sigmoid(x))[0]
g = tt.grad(l, x)
nodes = theano.gof.graph.ops([x], [g])
f = theano.function([x], g, mode=mode)
test_values = [-100, -1, 0, 1, 100]
......@@ -1349,7 +1348,9 @@ def test_grad_useless_sum():
finally:
TensorType.values_eq_approx = old_values_eq_approx
assert not any([isinstance(node.op, Sum) for node in nodes])
assert not any(
[isinstance(node.op, Sum) for node in theano.gof.graph.ops([x], [g])]
)
assert np.allclose(
outputs, [[-3.72007598e-44], [-0.26894142], [-0.5], [-0.73105858], [-1.0]]
)
......
......@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
the new interface so the tests don't need to be rewritten.
"""
if inputs is None:
inputs = theano.gof.graph.inputs([source[0] for source in sources])
inputs = list(theano.gof.graph.inputs([source[0] for source in sources]))
return dict(
zip(
inputs,
......
......@@ -2415,9 +2415,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs]
_inputs = gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
_inputs = list(
gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
)
# Check if some input variables are unused
......
......@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
}
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs)
all_graph_inputs = list(gof.graph.inputs(fgraph.outputs))
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
for i in range(len(fgraph.outputs)):
......@@ -1553,9 +1553,11 @@ class FunctionMaker:
# Wrap them in In or Out instances if needed.
inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs]
_inputs = gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
_inputs = list(
gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
)
# Check if some input variables are unused
......@@ -1697,12 +1699,14 @@ class FunctionMaker:
# There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs = gof.graph.ancestors(
(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
),
blockers=[i.variable for i in inputs],
used_inputs = list(
gof.graph.ancestors(
(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
),
blockers=[i.variable for i in inputs],
)
)
msg = (
......
......@@ -710,7 +710,7 @@ class FunctionGraph(utils.MetaObject):
Call this for a diagnosis if things go awry.
"""
nodes = ops_between(self.inputs, self.outputs)
nodes = set(ops_between(self.inputs, self.outputs))
if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes)
......
差异被折叠。
......@@ -35,10 +35,6 @@ _logger = logging.getLogger("theano.gof.opt")
_optimizer_idx = [0]
def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
class LocalMetaOptimizerSkipAssertionError(AssertionError):
"""This is an AssertionError, but instead of having the
LocalMetaOptimizer print the error, it just skip that
......@@ -1344,7 +1340,9 @@ class LocalOptGroup(LocalOptimizer):
else: # It must be a dict
new_vars = list(new_repl.values())
if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars))
self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_vars))
)
self.applied_true[opt] += 1
break # break from the for loop over optimization.
if not new_repl: # No optimization applied in the last iteration
......@@ -1454,7 +1452,9 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
if not new_repl:
continue
if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_repl))
)
self.applied_true[opt] += 1
return new_repl
......
......@@ -807,7 +807,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
graph_inputs = inputs(vars)
graph_inputs = list(inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(graph_inputs, vars, clone=False)
......
......@@ -637,7 +637,7 @@ class CLinker(Linker):
# We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(fgraph.clients[var])]
self.variables += get_variables(self.inputs, self.outputs)
self.variables += list(get_variables(self.inputs, self.outputs))
# This adds a hidden input which is the params for each node
# that needs it
......
......@@ -820,7 +820,7 @@ def pydotprint(
fct = fct.outputs
assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct)
fct = gof.FunctionGraph(inputs=gof.graph.inputs(fct), outputs=fct)
fct = gof.FunctionGraph(inputs=list(gof.graph.inputs(fct)), outputs=fct)
profile = None
outputs = fct.outputs
topo = fct.toposort()
......
......@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
# Same for the outer graph, initialized w/ number of steps
nw_outer = [node.inputs[0]]
all_ins = gof.graph.inputs(op_outs)
all_ins = list(gof.graph.inputs(op_outs))
for idx in range(op.n_seqs):
node_inp = node.inputs[idx + 1]
if (
......
......@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
return new_graph
graphs = list(graphs)
inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs)))
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
......@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
if new_input is not input_
]
graphs = clone(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs)))
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False)
......@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
required_inputs = gof.graph.inputs(non_removable)
required_inputs = list(gof.graph.inputs(non_removable))
out_ins = []
offset = op.n_seqs
......@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
if out_idxs_mask[pos] and any([x in required_inputs for x in out_ins[idx]]):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]])
required_inputs += list(gof.graph.inputs([op.outputs[idx]]))
added = True
required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0]
......@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
givens = OrderedDict()
for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x
allinputs = theano.gof.graph.inputs(outputs)
allinputs = list(theano.gof.graph.inputs(outputs))
for inp in allinputs:
if isinstance(inp, theano.Constant):
givens[inp] = inp.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论