提交 f5c7aa57 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Allow callers of io_toposort to get a dictionary of clients as a side effect.

上级 0aa5ff77
......@@ -873,7 +873,8 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
def general_toposort(r_out, deps, debug_print=False,
compute_deps_cache=None, deps_cache=None):
compute_deps_cache=None, deps_cache=None,
clients=None):
"""
WRITEME
......@@ -886,6 +887,9 @@ def general_toposort(r_out, deps, debug_print=False,
deps, but that also cache its results in a dict passed as deps_cache.
deps_cache : dict
Must be used with compute_deps_cache.
clients : dict
If a dict is passed it will be filled with a mapping of node
-> clients for each node in the subgraph.
Notes
-----
......@@ -924,8 +928,10 @@ def general_toposort(r_out, deps, debug_print=False,
assert isinstance(r_out, (tuple, list, deque))
reachable, clients = stack_search(deque(r_out), compute_deps_cache,
reachable, _clients = stack_search(deque(r_out), compute_deps_cache,
'dfs', True)
if clients is not None:
clients.update(_clients)
sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set()
......@@ -935,7 +941,7 @@ def general_toposort(r_out, deps, debug_print=False,
if node not in rset:
rlist.append(node)
rset.add(node)
for client in clients.get(node, []):
for client in _clients.get(node, []):
deps_cache[client] = [a for a in deps_cache[client]
if a is not node]
if not deps_cache[client]:
......@@ -951,7 +957,7 @@ def general_toposort(r_out, deps, debug_print=False,
return rlist
def io_toposort(inputs, outputs, orderings=None):
def io_toposort(inputs, outputs, orderings=None, clients=None):
"""
WRITEME
......@@ -959,10 +965,13 @@ def io_toposort(inputs, outputs, orderings=None):
----------
inputs : list or tuple of Variable instances
outputs : list or tuple of Apply instances
orderings: dict
orderings : dict
Key: Apply instance. Value: list of Apply instance.
It is important that the value be a container with a deterministic
iteration order. No sets allowed!
clients : dict
If a dict is provided it will be filled with mappings of
node->clients for each node in the subgraph that is sorted
"""
# the inputs are used only here in the function that decides what 'predecessors' to explore
......@@ -1013,7 +1022,7 @@ def io_toposort(inputs, outputs, orderings=None):
topo = general_toposort(outputs, deps=compute_deps,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache)
deps_cache=deps_cache, clients=clients)
return [o for o in topo if isinstance(o, Apply)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论