提交 feadd031 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Reyhane Askari

Use the old slower code when we need the clients.

上级 8ab7da6f
...@@ -982,8 +982,7 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -982,8 +982,7 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
node->clients for each node in the subgraph that is sorted node->clients for each node in the subgraph that is sorted
""" """
if not orderings and clients is None: # ordering can be None or empty dict
if not orderings: # can be None or empty dict
# Specialized function that is faster when more then ~10 nodes # Specialized function that is faster when more then ~10 nodes
# when no ordering. # when no ordering.
...@@ -1005,27 +1004,56 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -1005,27 +1004,56 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
todo.extend(i.owner for i in cur.inputs if i.owner) todo.extend(i.owner for i in cur.inputs if i.owner)
return order return order
# no ordering compute_deps = None
compute_deps_cache = None
# the inputs are used only here in the function that decides what
# 'predecessors' to explore
iset = set(inputs) iset = set(inputs)
deps_cache = {}
if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.
def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
if rval:
if not isinstance(rval, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic.")
deps_cache[obj] = list(rval)
else:
deps_cache[obj] = rval
else:
deps_cache[obj] = rval
return rval
else:
def compute_deps(obj): # the inputs are used only here in the function that decides what
rval = [] # 'predecessors' to explore
if obj not in iset: def compute_deps(obj):
if isinstance(obj, Variable): rval = []
if obj.owner: if obj not in iset:
rval = [obj.owner] if isinstance(obj, Variable):
elif isinstance(obj, Apply): if obj.owner:
rval = list(obj.inputs) rval = [obj.owner]
rval.extend(orderings.get(obj, [])) elif isinstance(obj, Apply):
else: rval = list(obj.inputs)
assert not orderings.get(obj, None) rval.extend(orderings.get(obj, []))
return rval else:
assert not orderings.get(obj, None)
return rval
topo = general_toposort(outputs, deps=compute_deps, topo = general_toposort(outputs, deps=compute_deps,
clients=clients) compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache, clients=clients)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论