提交 9abca4b8 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert theano.gof.graph graph walking functions to generators

上级 88dfd88f
......@@ -8,22 +8,24 @@ from theano import shared, tensor
from theano.gof.graph import (
Apply,
Variable,
ancestors,
as_string,
clone,
equal_computations,
general_toposort,
inputs,
io_toposort,
is_in_ancestors,
list_of_nodes,
ops,
orphans,
stack_search,
variables,
)
from theano.gof.op import Op
from theano.gof.type import Type
def as_variable(x):
    """Assert that `x` is a `Variable` and return it unchanged."""
    assert isinstance(x, Variable)
    return x
class MyType(Type):
def __init__(self, thingy):
self.thingy = thingy
......@@ -47,32 +49,16 @@ class MyOp(Op):
__props__ = ()
def make_node(self, *inputs):
inputs = list(map(as_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
print(input, input.type, type(input), type(input.type))
raise Exception("Error 1")
outputs = [MyVariable(sum([input.type.thingy for input in inputs]))]
return Apply(self, inputs, outputs)
assert isinstance(input, Variable)
assert isinstance(input.type, MyType)
outputs = [MyVariable(sum(input.type.thingy for input in inputs))]
return Apply(self, list(inputs), outputs)
# Rebind the class name to a shared instance so tests can use `MyOp` directly
# (e.g. `MyOp(r1, r2)` or `MyOp.make_node(...)`).
MyOp = MyOp()
class TestInputs:
    """Tests for `theano.gof.graph.inputs`."""

    def test_inputs(self):
        # The inputs of a single apply node are its direct arguments.
        r1, r2 = MyVariable(1), MyVariable(2)
        node = MyOp.make_node(r1, r2)
        assert inputs(node.outputs) == [r1, r2]

    def test_inputs_deep(self):
        # Inputs are collected through nested apply nodes: the output of the
        # inner node is not itself an input, its leaves are.
        r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5)
        node = MyOp.make_node(r1, r2)
        node2 = MyOp.make_node(node.outputs[0], r5)
        i = inputs(node2.outputs)
        assert i == [r1, r2, r5], i
class X:
def leaf_formatter(self, leaf):
return str(leaf.type)
......@@ -145,7 +131,7 @@ class TestClone(X):
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8)
new_node.inputs = [MyVariable(7), MyVariable(8)]
assert self.str(inputs(new_node.outputs), new_node.outputs) == ["MyOp(R7, R8)"]
assert self.str(inputs(node.outputs), node.outputs) == [
"MyOp(MyOp(R1, R2), R5)"
......@@ -156,7 +142,7 @@ class TestClone(X):
node = MyOp.make_node(MyOp.make_node(r1, r2).outputs[0], r5)
_, new = clone([r1, r2, r5], node.outputs, False)
new_node = new[0].owner
new_node.inputs = MyVariable(7), MyVariable(8)
new_node.inputs = [MyVariable(7), MyVariable(8)]
c1 = tensor.constant(1.5)
i, o = clone([c1], [c1])
......@@ -181,19 +167,36 @@ def prenode(obj):
class TestToposort:
def test_0(self):
def test_simple(self):
# Test a simple graph
r1, r2, r5 = MyVariable(1), MyVariable(2), MyVariable(5)
o = MyOp.make_node(r1, r2)
o2 = MyOp.make_node(o.outputs[0], r5)
all = general_toposort(o2.outputs, prenode)
assert all == [r5, r2, r1, o, o.outputs[0], o2, o2.outputs[0]]
all = io_toposort([r5], o2.outputs)
assert all == [o, o2]
def test_1(self):
o = MyOp(r1, r2)
o.name = "o1"
o2 = MyOp(o, r5)
o2.name = "o2"
clients = {}
res = general_toposort([o2], prenode, clients=clients)
assert clients == {
o2.owner: [o2],
o: [o2.owner],
r5: [o2.owner],
o.owner: [o],
r1: [o.owner],
r2: [o.owner],
}
assert res == [r5, r2, r1, o.owner, o, o2.owner, o2]
with pytest.raises(ValueError):
general_toposort(
[o2], prenode, compute_deps_cache=lambda x: None, deps_cache=None
)
res = io_toposort([r5], [o2])
assert res == [o.owner, o2.owner]
def test_double_dependencies(self):
# Test a graph with double dependencies
r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1)
......@@ -201,7 +204,7 @@ class TestToposort:
all = general_toposort(o2.outputs, prenode)
assert all == [r5, r1, o, o.outputs[0], o2, o2.outputs[0]]
def test_2(self):
def test_inputs_owners(self):
# Test a graph where the inputs have owners
r1, r5 = MyVariable(1), MyVariable(5)
o = MyOp.make_node(r1, r1)
......@@ -214,7 +217,7 @@ class TestToposort:
all = io_toposort([r2b], o2.outputs)
assert all == [o2]
def test_3(self):
def test_not_connected(self):
# Test a graph which is not connected
r1, r2, r3, r4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4)
o0 = MyOp.make_node(r1, r2)
......@@ -222,7 +225,7 @@ class TestToposort:
all = io_toposort([r1, r2, r3, r4], o0.outputs + o1.outputs)
assert all == [o1, o0] or all == [o0, o1]
def test_4(self):
def test_io_chain(self):
# Test inputs and outputs mixed together in a chain graph
r1, r2 = MyVariable(1), MyVariable(2)
o0 = MyOp.make_node(r1, r2)
......@@ -230,7 +233,7 @@ class TestToposort:
all = io_toposort([r1, o0.outputs[0]], [o0.outputs[0], o1.outputs[0]])
assert all == [o1]
def test_5(self):
def test_outputs_clients(self):
# Test when outputs have clients
r1, r2, r4 = MyVariable(1), MyVariable(2), MyVariable(4)
o0 = MyOp.make_node(r1, r2)
......@@ -326,3 +329,134 @@ def test_equal_computations():
max_argmax1 = tensor.max_and_argmax(m)
max_argmax2 = tensor.max_and_argmax(m)
assert equal_computations(max_argmax1, max_argmax2)
def test_stack_search():
    """Exercise `stack_search` in BFS and DFS modes, with and without children."""
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"
    o2 = MyOp(r3, o1)
    o2.name = "o2"

    def expand(r):
        # A node's children are its owner's inputs; ownerless leaf variables
        # implicitly return None (nothing to expand).
        if r.owner:
            return r.owner.inputs

    # Breadth-first: siblings are visited before grandchildren.
    res = stack_search([o2], expand, bfs=True, return_children=False)
    res_list = list(res)
    assert res_list == [o2, r3, o1, r1, r2]

    # Depth-first: each branch is followed to its leaves first.
    res = stack_search([o2], expand, bfs=False, return_children=False)
    res_list = list(res)
    assert res_list == [o2, o1, r2, r1, r3]

    # With `return_children=True` each yielded node is paired with the
    # result of `expand` on it (None for leaves).
    res = stack_search([o2], expand, bfs=True, return_children=True)
    res_list = list(res)
    assert res_list == [
        (o2, [r3, o1]),
        (r3, None),
        (o1, [r1, r2]),
        (r1, None),
        (r2, None),
    ]
def test_ancestors():
    """`ancestors` lazily yields a graph's variables in DFS order."""
    r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
    o1 = MyOp(r1, r2)
    o1.name = "o1"
    o2 = MyOp(r3, o1)
    o2.name = "o2"

    res = ancestors([o2], blockers=None)
    res_list = list(res)
    assert res_list == [o2, r3, o1, r1, r2]

    # `ancestors` returns a generator: the membership test below consumes
    # elements up to and including `r3`, so the subsequent list() only
    # collects the remainder, starting at `o1`.
    res = ancestors([o2], blockers=None)
    assert r3 in res
    res_list = list(res)
    assert res_list == [o1, r1, r2]

    # Blockers stop the walk: `o1` is still yielded but not expanded.
    res = ancestors([o2], blockers=[o1])
    res_list = list(res)
    assert res_list == [o2, r3, o1]
def test_inputs():
    """`inputs` yields only the ownerless variables, in DFS discovery order."""
    var_a, var_b, var_c = MyVariable(1), MyVariable(2), MyVariable(3)
    inner = MyOp(var_a, var_b)
    inner.name = "o1"
    outer = MyOp(var_c, inner)
    outer.name = "o2"
    # Only the leaves appear; the intermediate output `inner` has an owner.
    result = list(inputs([outer], blockers=None))
    assert result == [var_c, var_a, var_b]
def test_variables_and_orphans():
    """`variables` yields the whole sub-graph; `orphans` only the non-input leaves."""
    in_a, in_b, leaf = MyVariable(1), MyVariable(2), MyVariable(3)
    mid = MyOp(in_a, in_b)
    mid.name = "o1"
    out = MyOp(leaf, mid)
    out.name = "o2"
    # DFS order over everything between the declared inputs and `out`.
    assert list(variables([in_a, in_b], [out])) == [out, mid, leaf, in_b, in_a]
    # `leaf` feeds `out` but is not a declared input, so it is an orphan.
    assert list(orphans([in_a, in_b], [out])) == [leaf]
def test_ops():
    """`ops` yields the apply nodes between the given inputs and outputs."""
    v1, v2, v3, v4 = MyVariable(1), MyVariable(2), MyVariable(3), MyVariable(4)
    left = MyOp(v1, v2)
    left.name = "o1"
    right = MyOp(v3, v4)
    right.name = "o2"
    top = MyOp(v3, left, right)
    top.name = "o3"
    # Apply nodes are reported outermost-first.
    expected = [top.owner, right.owner, left.owner]
    assert list(ops([v1, v2], [top])) == expected
def test_list_of_nodes():
    """`list_of_nodes` returns the apply nodes between inputs and outputs."""
    x, y, z = MyVariable(1), MyVariable(2), MyVariable(3)
    inner = MyOp(x, y)
    inner.name = "o1"
    outer = MyOp(z, inner)
    outer.name = "o2"
    expected = [outer.owner, inner.owner]
    assert list_of_nodes([x, y], [outer]) == expected
def test_is_in_ancestors():
    """An apply node that feeds another is found among its ancestors."""
    u, v, w = MyVariable(1), MyVariable(2), MyVariable(3)
    child = MyOp(u, v)
    child.name = "o1"
    parent = MyOp(w, child)
    parent.name = "o2"
    # `child.owner` is an input-producer of `parent.owner`.
    assert is_in_ancestors(parent.owner, child.owner)
@pytest.mark.xfail(reason="Not implemented")
def test_io_connection_pattern():
    # Placeholder marked ``xfail`` until a real test for
    # `io_connection_pattern` is written; the raise keeps it failing.
    raise AssertionError()
@pytest.mark.xfail(reason="Not implemented")
def test_view_roots():
    # Placeholder marked ``xfail`` until a real test for `view_roots`
    # is written; the raise keeps it failing.
    raise AssertionError()
......@@ -1336,7 +1336,6 @@ def test_grad_useless_sum():
x = TensorType(theano.config.floatX, (True,))("x")
l = tt.log(1.0 - sigmoid(x))[0]
g = tt.grad(l, x)
nodes = theano.gof.graph.ops([x], [g])
f = theano.function([x], g, mode=mode)
test_values = [-100, -1, 0, 1, 100]
......@@ -1349,7 +1348,9 @@ def test_grad_useless_sum():
finally:
TensorType.values_eq_approx = old_values_eq_approx
assert not any([isinstance(node.op, Sum) for node in nodes])
assert not any(
[isinstance(node.op, Sum) for node in theano.gof.graph.ops([x], [g])]
)
assert np.allclose(
outputs, [[-3.72007598e-44], [-0.26894142], [-0.5], [-0.73105858], [-1.0]]
)
......
......@@ -22,7 +22,7 @@ def grad_sources_inputs(sources, inputs):
the new interface so the tests don't need to be rewritten.
"""
if inputs is None:
inputs = theano.gof.graph.inputs([source[0] for source in sources])
inputs = list(theano.gof.graph.inputs([source[0] for source in sources]))
return dict(
zip(
inputs,
......
......@@ -2415,9 +2415,11 @@ class _Maker(FunctionMaker): # inheritance buys a few helper functions
inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs]
_inputs = gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
_inputs = list(
gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
)
# Check if some input variables are unused
......
......@@ -1206,7 +1206,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
}
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs)
all_graph_inputs = list(gof.graph.inputs(fgraph.outputs))
has_destroyers_attr = hasattr(fgraph, "has_destroyers")
for i in range(len(fgraph.outputs)):
......@@ -1553,9 +1553,11 @@ class FunctionMaker:
# Wrap them in In or Out instances if needed.
inputs = [self.wrap_in(i) for i in inputs]
outputs = [self.wrap_out(o) for o in outputs]
_inputs = gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
_inputs = list(
gof.graph.inputs(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
)
)
# Check if some input variables are unused
......@@ -1697,12 +1699,14 @@ class FunctionMaker:
# There should be two categories of variables in inputs:
# - variables that have to be provided (used_inputs)
# - shared variables that will be updated
used_inputs = gof.graph.ancestors(
(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
),
blockers=[i.variable for i in inputs],
used_inputs = list(
gof.graph.ancestors(
(
[o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)]
),
blockers=[i.variable for i in inputs],
)
)
msg = (
......
......@@ -710,7 +710,7 @@ class FunctionGraph(utils.MetaObject):
Call this for a diagnosis if things go awry.
"""
nodes = ops_between(self.inputs, self.outputs)
nodes = set(ops_between(self.inputs, self.outputs))
if self.apply_nodes != nodes:
missing = nodes.difference(self.apply_nodes)
excess = self.apply_nodes.difference(nodes)
......
"""
Node classes (`Apply`, `Variable`) and expression graph algorithms.
"""
"""Core graph classes."""
import contextlib
import warnings
from collections import deque
from copy import copy
from itertools import count
from typing import (
Callable,
Collection,
Deque,
Dict,
Generator,
Hashable,
Iterable,
List,
Optional,
Sequence,
Set,
TypeVar,
Union,
)
import numpy as np
......@@ -23,7 +36,7 @@ from theano.gof.utils import (
from theano.misc.ordered_set import OrderedSet
__docformat__ = "restructuredtext en"
T = TypeVar("T")
NoParams = object()
......@@ -648,76 +661,92 @@ class Constant(Variable):
# index is not defined, because the `owner` attribute must necessarily be None
def stack_search(start, expand, mode="bfs", build_inv=False):
"""
Search through a graph, either breadth- or depth-first.
def stack_search(
nodes: Iterable[T],
expand: Callable[[T], Optional[Sequence[T]]],
bfs: bool = True,
return_children: bool = False,
hash_fn: Callable[[T], Hashable] = id,
) -> Generator[T, None, Dict[T, List[T]]]:
"""Walk through a graph, either breadth- or depth-first.
Parameters
----------
start : deque
Search from these nodes.
expand : callable
When we get to a node, add expand(node) to the list of nodes to visit.
This function should return a list, or None.
mode : string
'bfs' or 'dfs' for breath first search or depth first search.
Returns
-------
list of `Variable` or `Apply` instances (depends on `expend`)
The list of nodes in order of traversal.
nodes: deque
The nodes from which to start walking.
expand: callable
A callable that is applied to each node in `nodes`, the results of
which are either new nodes to visit or ``None``.
bfs: bool
If ``True``, breath first search is used; otherwise, depth first
search.
return_children: bool
If ``True``, each output node will be accompanied by the output of
`expand` (i.e. the corresponding child nodes).
hash_fn: callable
The function used to produce hashes of the elements in `nodes`.
The default is ``id``.
Yields
------
nodes
When `build_inv` is ``True``, a inverse map is returned.
Notes
-----
A node will appear at most once in the return value, even if it
appears multiple times in the start parameter.
:postcondition: every element of start is transferred to the returned list.
:postcondition: start is empty.
appears multiple times in the `nodes` parameter.
"""
if mode not in ("bfs", "dfs"):
raise ValueError("mode should be bfs or dfs", mode)
rval_set = set()
rval_list = list()
if mode == "bfs":
start_pop = start.popleft
nodes = deque(nodes)
rval_set: Set[T] = set()
if bfs:
nodes_pop: Callable[[], T] = nodes.popleft
else:
start_pop = start.pop
expand_inv = {} # var: clients
while start:
l = start_pop()
if id(l) not in rval_set:
rval_list.append(l)
rval_set.add(id(l))
expand_l = expand(l)
if expand_l:
if build_inv:
for r in expand_l:
expand_inv.setdefault(r, []).append(l)
start.extend(expand_l)
assert len(rval_list) == len(rval_set)
if build_inv:
return rval_list, expand_inv
return rval_list
def ancestors(variable_list, blockers=None):
"""
Return the variables that contribute to those in variable_list (inclusive).
nodes_pop: Callable[[], T] = nodes.pop
while nodes:
node: T = nodes_pop()
node_hash: Hashable = hash_fn(node)
if node_hash not in rval_set:
rval_set.add(node_hash)
new_nodes: Sequence[T] = expand(node)
if return_children:
yield node, new_nodes
else:
yield node
if new_nodes:
nodes.extend(new_nodes)
def ancestors(
graphs: Iterable[Variable], blockers: Collection[Variable] = None
) -> Generator[Variable, None, None]:
"""Return the variables that contribute to those in given graphs (inclusive).
Parameters
----------
variable_list : list of `Variable` instances
graphs: list of `Variable` instances
Output `Variable` instances from which to search backward through
owners.
blockers: list of `Variable` instances
A collection of `Variable`s that, when found, prevent the graph search
from preceding from that point.
Returns
-------
list of `Variable` instances
Yields
------
`Variable`s
All input nodes, in the order found by a left-recursive depth-first
search started at the nodes in `variable_list`.
search started at the nodes in `graphs`.
"""
......@@ -725,142 +754,124 @@ def ancestors(variable_list, blockers=None):
if r.owner and (not blockers or r not in blockers):
return reversed(r.owner.inputs)
dfs_variables = stack_search(deque(variable_list), expand, "dfs")
return dfs_variables
yield from stack_search(graphs, expand, False)
def inputs(variable_list, blockers=None):
"""
Return the inputs required to compute the given Variables.
def inputs(
graphs: Iterable[Variable], blockers: Collection[Variable] = None
) -> Generator[Variable, None, None]:
"""Return the inputs required to compute the given Variables.
Parameters
----------
variable_list : list of `Variable` instances
graphs: list of `Variable` instances
Output `Variable` instances from which to search backward through
owners.
blockers: list of `Variable` instances
A collection of `Variable`s that, when found, prevent the graph search
from preceding from that point.
Returns
-------
list of `Variable` instances
Yields
------
`Variable`s
Input nodes with no owner, in the order found by a left-recursive
depth-first search started at the nodes in `variable_list`.
depth-first search started at the nodes in `graphs`.
"""
vlist = ancestors(variable_list, blockers)
rval = [r for r in vlist if r.owner is None]
return rval
yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
def variables_and_orphans(i, o):
"""
Extract list of variables between i and o nodes via
dfs traversal and chooses the orphans among them
def variables(
ins: Collection[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]:
"""Extract the `Variable`s within the sub-graph between input and output nodes.
Parameters
----------
i : list
Input variables.
o : list
Output variables.
ins: list
Input `Variable`s.
outs: list
Output `Variable`s.
Yields
------
`Variable`s
The `Variable`s that are involved in the subgraph that lies
between `ins` and `outs`. This includes `ins`, `outs`,
``orphans(ins, outs)`` and all values of all intermediary steps from
`ins` to `outs`.
"""
def expand(r):
if r.owner and r not in i:
l = list(r.owner.inputs) + list(r.owner.outputs)
l.reverse()
return l
if r.owner and r not in ins:
return reversed(r.owner.inputs + r.owner.outputs)
variables = stack_search(deque(o), expand, "dfs")
orphans = [r for r in variables if r.owner is None and r not in i]
return variables, orphans
yield from stack_search(outs, expand)
def ops(i, o):
"""
Set of Ops contained within the subgraph between i and o
def orphans(
ins: Collection[Variable], outs: Iterable[Variable]
) -> Generator[Variable, None, None]:
"""Extract the `Variable`s not within the sub-graph between input and output nodes.
Parameters
----------
i : list
Input variables.
o : list
Output variables.
ins: list
Input `Variable`s.
outs: list
Output `Variable`s.
Returns
Yields
-------
object
The set of ops that are contained within the subgraph that lies
between i and o, including the owners of the variables in o and
intermediary ops between i and o, but not the owners of the variables
in i.
"""
ops = set()
variables, orphans = variables_and_orphans(i, o)
for r in variables:
if r not in i and r not in orphans:
if r.owner is not None:
ops.add(r.owner)
return ops
`Variable`s
The `Variable`s upon which one or more Variables in `outs`
depend, but are neither in `ins` nor in the sub-graph that lies between
them.
def variables(i, o):
"""
Extracts list of variables within input and output nodes via dfs travesal
Parameters
----------
i : list
Input variables.
o : list
Output variables.
Returns
-------
object
The set of Variables that are involved in the subgraph that lies
between i and o. This includes i, o, orphans(i, o) and all values of
all intermediary steps from i to o.
Examples
--------
>>> orphans([x], [(x+y).out])
[y]
"""
return variables_and_orphans(i, o)[0]
yield from (r for r in variables(ins, outs) if r.owner is None and r not in ins)
def orphans(i, o):
"""
Extracts list of variables within input and output nodes
via dfs travesal and returns the orphans among them
def ops(
ins: Collection[Variable], outs: Iterable[Variable]
) -> Generator[Apply, None, None]:
"""Extract the `Apply`s contained within the sub-graph between given input and output variables.
Parameters
----------
i : list
Input Variables.
o : list
Output Variables.
Returns
-------
object
The set of Variables which one or more Variables in o depend on but are
neither in i nor in the subgraph that lies between i and o.
Examples
--------
orphans([x], [(x+y).out]) => [y]
ins: list
Input `Variable`s.
outs: list
Output `Variable`s.
Yields
------
`Apply`s
The `Apply`s that are contained within the sub-graph that lies
between `ins` and `outs`, including the owners of the `Variable`s in
`outs` and intermediary `Apply`s between `ins` and `outs`, but not the
owners of the `Variable`s in `ins`.
"""
return variables_and_orphans(i, o)[1]
yield from (
r.owner for r in variables(ins, outs) if r not in ins and r.owner is not None
)
def clone(i, o, copy_inputs=True, copy_orphans=None):
"""Copies the subgraph contained between i and o.
def clone(inputs, outputs, copy_inputs=True, copy_orphans=None):
"""Copies the sub-graph contained between inputs and outputs.
Parameters
----------
i : list
inputs : list
Input Variables.
o : list
outputs : list
Output Variables.
copy_inputs : bool
If True, the inputs will be copied (defaults to True).
......@@ -877,15 +888,15 @@ def clone(i, o, copy_inputs=True, copy_orphans=None):
Notes
-----
A constant, if in the ``i`` list is not an orpha. So it will be
copied depending of the ``copy_inputs`` parameter. Otherwise it
will be copied depending of the ``copy_orphans`` parameter.
A constant, if in the `inputs` list is not an orphan. So it will be copied
depending of the `copy_inputs` parameter. Otherwise it will be copied
depending of the `copy_orphans` parameter.
"""
if copy_orphans is None:
copy_orphans = copy_inputs
equiv = clone_get_equiv(i, o, copy_inputs, copy_orphans)
return [equiv[input] for input in i], [equiv[output] for output in o]
equiv = clone_get_equiv(inputs, outputs, copy_inputs, copy_orphans)
return [equiv[input] for input in inputs], [equiv[output] for output in outputs]
def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True, memo=None):
......@@ -951,28 +962,27 @@ def clone_get_equiv(inputs, outputs, copy_inputs=True, copy_orphans=True, memo=N
def general_toposort(
outputs,
deps,
debug_print=False,
compute_deps_cache=None,
deps_cache=None,
clients=None,
):
"""
WRITEME
outputs: Iterable[T],
deps: Callable[[T], Union[OrderedSet, List[T]]],
compute_deps_cache: Optional[Callable[[T], Union[OrderedSet, List[T]]]] = None,
deps_cache: Optional[Dict[T, List[T]]] = None,
clients: Optional[Dict[T, List[T]]] = None,
) -> List[T]:
"""Perform a topological sort of all nodes starting from a given node.
Parameters
----------
deps
deps: callable
A python function that takes a node as input and returns its dependence.
compute_deps_cache : optional
compute_deps_cache: optional
If provided deps_cache should also be provided. This is a function like
deps, but that also cache its results in a dict passed as deps_cache.
deps_cache : dict
Must be used with compute_deps_cache.
clients : dict
If a dict is passed it will be filled with a mapping of node
-> clients for each node in the subgraph.
deps_cache: dict
A dict mapping nodes to their children. This is populated by
`compute_deps_cache`.
clients: dict
If a dict is passed it will be filled with a mapping of
nodes-to-clients for each node in the subgraph.
Notes
-----
......@@ -991,37 +1001,53 @@ def general_toposort(
"""
if compute_deps_cache is None:
deps_cache = {}
if deps_cache is None:
deps_cache = {}
def compute_deps_cache(io):
if io not in deps_cache:
d = deps(io)
if d:
if not isinstance(d, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
"Non-deterministic collections found; make"
" toposort non-deterministic."
)
deps_cache[io] = list(d)
else:
deps_cache[io] = d
deps_cache[io] = None
return d
else:
return deps_cache[io]
assert deps_cache is not None
if deps_cache is None:
raise ValueError("deps_cache cannot be None")
assert isinstance(outputs, (tuple, list, deque))
search_res: List[T, Optional[List[T]]] = list(
stack_search(outputs, compute_deps_cache, bfs=False, return_children=True)
)
_clients: Dict[T, List[T]] = {}
sources: Deque[T] = deque()
search_res_len: int = 0
for node, children in search_res:
search_res_len += 1
if children:
for child in children:
_clients.setdefault(child, []).append(node)
if not deps_cache.get(node):
sources.append(node)
reachable, _clients = stack_search(deque(outputs), compute_deps_cache, "dfs", True)
if clients is not None:
clients.update(_clients)
sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set()
rlist = []
rset: Set[T] = set()
rlist: List[T] = []
while sources:
node = sources.popleft()
node: T = sources.popleft()
if node not in rset:
rlist.append(node)
rset.add(node)
......@@ -1031,31 +1057,31 @@ def general_toposort(
if not d:
sources.append(client)
if len(rlist) != len(reachable):
if debug_print:
print("")
print(reachable)
print(rlist)
if len(rlist) != search_res_len:
raise ValueError("graph contains cycles")
return rlist
def io_toposort(inputs, outputs, orderings=None, clients=None):
"""
Perform topological sort from input and output nodes
def io_toposort(
inputs: List[Variable],
outputs: List[Variable],
orderings: Optional[Dict[Apply, List[Apply]]] = None,
clients: Optional[Dict[Variable, List[Variable]]] = None,
) -> List[Apply]:
"""Perform topological sort from input and output nodes.
Parameters
----------
inputs : list or tuple of Variable instances
Graph inputs.
outputs : list or tuple of Apply instances
Graph outputs.
orderings : dict
Key: Apply instance. Value: list of Apply instance.
It is important that the value be a container with a deterministic
iteration order. No sets allowed!
Keys are `Apply` instances, values are lists of `Apply` instances.
clients : dict
If a dict is provided it will be filled with mappings of
node->clients for each node in the subgraph that is sorted
If provided, it will be filled with mappings of nodes-to-clients for
each node in the subgraph that is sorted.
"""
if not orderings and clients is None: # ordering can be None or empty dict
......@@ -1100,11 +1126,6 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
elif isinstance(obj, Apply):
rval = list(obj.inputs)
if rval:
if not isinstance(rval, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic."
)
deps_cache[obj] = list(rval)
else:
deps_cache[obj] = rval
......@@ -1228,16 +1249,18 @@ def op_as_string(
def as_string(
i, o, leaf_formatter=default_leaf_formatter, node_formatter=default_node_formatter
):
"""
Returns a string representation of the subgraph between i and o
inputs: List[Variable],
outputs: List[Variable],
leaf_formatter=default_leaf_formatter,
node_formatter=default_node_formatter,
) -> List[str]:
"""Returns a string representation of the subgraph between inputs and outputs.
Parameters
----------
i : list
inputs : list
Input `Variable` s.
o : list
outputs : list
Output `Variable` s.
leaf_formatter : callable
Takes a `Variable` and returns a string to describe it.
......@@ -1247,28 +1270,28 @@ def as_string(
Returns
-------
str
Returns a string representation of the subgraph between i and o. If the
same op is used by several other ops, the first occurrence will be
marked as :literal:`*n -> description` and all subsequent occurrences
will be marked as :literal:`*n`, where n is an id number (ids are
attributed in an unspecified order and only exist for viewing
convenience).
list of str
Returns a string representation of the subgraph between `inputs` and
`outputs`. If the same node is used by several other nodes, the first
occurrence will be marked as :literal:`*n -> description` and all
subsequent occurrences will be marked as :literal:`*n`, where n is an id
number (ids are attributed in an unspecified order and only exist for
viewing convenience).
"""
i = set(i)
i = set(inputs)
orph = orphans(i, o)
orph = list(orphans(i, outputs))
multi = set()
seen = set()
for output in o:
for output in outputs:
op = output.owner
if op in seen:
multi.add(op)
else:
seen.add(op)
for op in ops(i, o):
for op in ops(i, outputs):
for input in op.inputs:
op2 = input.owner
if input in i or input in orph or op2 is None:
......@@ -1303,60 +1326,72 @@ def as_string(
else:
return leaf_formatter(r)
return [describe(output) for output in o]
return [describe(output) for output in outputs]
def view_roots(r):
"""
Utility function that returns the leaves of a search through
consecutive view_map()s.
WRITEME
"""
owner = r.owner
def view_roots(node: Variable) -> List[Variable]:
"""Return the leaves from a search through consecutive view-maps."""
owner = node.owner
if owner is not None:
try:
view_map = owner.op.view_map
view_map = {owner.outputs[o]: i for o, i in view_map.items()}
except AttributeError:
return [r]
if r in view_map:
return [node]
if node in view_map:
answer = []
for i in view_map[r]:
for i in view_map[node]:
answer += view_roots(owner.inputs[i])
return answer
else:
return [r]
return [node]
else:
return [r]
return [node]
def list_of_nodes(inputs, outputs):
"""
Return the apply nodes of the graph between inputs and outputs.
def list_of_nodes(
inputs: Collection[Variable], outputs: Iterable[Variable]
) -> List[Apply]:
"""Return the `Apply` nodes of the graph between `inputs` and `outputs`.
Parameters
----------
inputs: list of Variable
Input `Variable`s.
outputs: list of Variable
Output `Variable`s.
"""
return stack_search(
deque([o.owner for o in outputs]),
lambda o: [
inp.owner
for inp in o.inputs
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
],
return list(
stack_search(
[o.owner for o in outputs],
lambda o: [
inp.owner
for inp in o.inputs
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
],
)
)
def is_in_ancestors(l_node, f_node):
r"""
Goes up in the graph and returns True if the apply node f_node is found.
def is_in_ancestors(l_apply: Apply, f_node: Apply) -> bool:
"""Determine if `f_node` is in the graph given by `l_apply`.
Parameters
----------
l_apply: Apply
The node to walk.
f_apply: Apply
The node to find in `l_apply`.
Returns
-------
bool
Use a stack implementation as the vm algo.
We suppose all nodes are not lazy
(i.e. for IfElse we suppose all inputs are computed)
"""
computed = set()
todo = [l_node]
todo = [l_apply]
while todo:
cur = todo.pop()
if cur.outputs[0] in computed:
......@@ -1375,7 +1410,7 @@ def is_in_ancestors(l_node, f_node):
def nodes_constructed():
"""
A contextmanager that is used in inherit_stack_trace and keeps track
of all the newly created varaible nodes inside an optimization. A list
of all the newly created variable nodes inside an optimization. A list
of new_nodes is instantiated but will be filled in a lazy manner (when
Variable.notify_construction_observers is called).
......
......@@ -35,10 +35,6 @@ _logger = logging.getLogger("theano.gof.opt")
_optimizer_idx = [0]
def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
class LocalMetaOptimizerSkipAssertionError(AssertionError):
"""This is an AssertionError, but instead of having the
LocalMetaOptimizer print the error, it just skip that
......@@ -1344,7 +1340,9 @@ class LocalOptGroup(LocalOptimizer):
else: # It must be a dict
new_vars = list(new_repl.values())
if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_vars))
self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_vars))
)
self.applied_true[opt] += 1
break # break from the for loop over optimization.
if not new_repl: # No optimization applied in the last iteration
......@@ -1454,7 +1452,9 @@ class GraphToGPULocalOptGroup(LocalOptGroup):
if not new_repl:
continue
if self.profile:
self.node_created[opt] += len(graph.ops(fgraph.variables, new_repl))
self.node_created[opt] += len(
list(graph.ops(fgraph.variables, new_repl))
)
self.applied_true[opt] += 1
return new_repl
......
......@@ -807,7 +807,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
vars = copied[0:2]
givens = copied[2]
# Create FunctionGraph.
graph_inputs = inputs(vars)
graph_inputs = list(inputs(vars))
# The clone isn't needed as we did a deepcopy and we cloning will
# break the mapping in givens.
fgraph = theano.gof.fg.FunctionGraph(graph_inputs, vars, clone=False)
......
......@@ -637,7 +637,7 @@ class CLinker(Linker):
# We need to include the unused inputs in our variables,
# otherwise we can't pass them to the module.
self.variables = [var for var in self.inputs if not len(fgraph.clients[var])]
self.variables += get_variables(self.inputs, self.outputs)
self.variables += list(get_variables(self.inputs, self.outputs))
# This adds a hidden input which is the params for each node
# that needs it
......
......@@ -820,7 +820,7 @@ def pydotprint(
fct = fct.outputs
assert isinstance(fct, (list, tuple))
assert all(isinstance(v, gof.Variable) for v in fct)
fct = gof.FunctionGraph(inputs=gof.graph.inputs(fct), outputs=fct)
fct = gof.FunctionGraph(inputs=list(gof.graph.inputs(fct)), outputs=fct)
profile = None
outputs = fct.outputs
topo = fct.toposort()
......
......@@ -150,7 +150,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
# Same for the outer graph, initialized w/ number of steps
nw_outer = [node.inputs[0]]
all_ins = gof.graph.inputs(op_outs)
all_ins = list(gof.graph.inputs(op_outs))
for idx in range(op.n_seqs):
node_inp = node.inputs[idx + 1]
if (
......
......@@ -268,7 +268,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
return new_graph
graphs = list(graphs)
inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs)))
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
# perform any desired replacement of input variables. these
# aren't replaced by the local optimizer approach because they are
......@@ -280,7 +280,7 @@ def map_variables(replacer, graphs, additional_inputs=None):
if new_input is not input_
]
graphs = clone(graphs, share_inputs=True, replace=replacements)
inputs_ = list(set(gof.graph.inputs(graphs) + list(additional_inputs)))
inputs_ = list(set(list(gof.graph.inputs(graphs)) + list(additional_inputs)))
fg = gof.fg.FunctionGraph(inputs_, graphs, clone=False)
......@@ -714,7 +714,7 @@ def scan_can_remove_outs(op, out_idxs):
"""
non_removable = [o for i, o in enumerate(op.outputs) if i not in out_idxs]
required_inputs = gof.graph.inputs(non_removable)
required_inputs = list(gof.graph.inputs(non_removable))
out_ins = []
offset = op.n_seqs
......@@ -734,7 +734,7 @@ def scan_can_remove_outs(op, out_idxs):
if out_idxs_mask[pos] and any([x in required_inputs for x in out_ins[idx]]):
# This output is required ..
out_idxs_mask[pos] = 0
required_inputs += gof.graph.inputs([op.outputs[idx]])
required_inputs += list(gof.graph.inputs([op.outputs[idx]]))
added = True
required_outs = [x for i, x in enumerate(out_idxs) if out_idxs_mask[i] == 0]
......@@ -900,7 +900,7 @@ def reconstruct_graph(inputs, outputs, tag=None):
givens = OrderedDict()
for nw_x, x in zip(nw_inputs, inputs):
givens[x] = nw_x
allinputs = theano.gof.graph.inputs(outputs)
allinputs = list(theano.gof.graph.inputs(outputs))
for inp in allinputs:
if isinstance(inp, theano.Constant):
givens[inp] = inp.clone()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论