提交 feadd031 authored 作者: Frederic Bastien's avatar Frederic Bastien 提交者: Reyhane Askari

Use the old slower code when we need the clients.

上级 8ab7da6f
...@@ -982,8 +982,7 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -982,8 +982,7 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
node->clients for each node in the subgraph that is sorted node->clients for each node in the subgraph that is sorted
""" """
if not orderings and clients is None: # ordering can be None or empty dict
if not orderings: # can be None or empty dict
# Specialized function that is faster when more then ~10 nodes # Specialized function that is faster when more then ~10 nodes
# when no ordering. # when no ordering.
...@@ -1005,12 +1004,40 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -1005,12 +1004,40 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
todo.extend(i.owner for i in cur.inputs if i.owner) todo.extend(i.owner for i in cur.inputs if i.owner)
return order return order
# no ordering compute_deps = None
compute_deps_cache = None
iset = set(inputs)
deps_cache = {}
if not orderings: # ordering can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.
def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[obj]
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
if rval:
if not isinstance(rval, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic.")
deps_cache[obj] = list(rval)
else:
deps_cache[obj] = rval
else:
deps_cache[obj] = rval
return rval
else:
# the inputs are used only here in the function that decides what # the inputs are used only here in the function that decides what
# 'predecessors' to explore # 'predecessors' to explore
iset = set(inputs)
def compute_deps(obj): def compute_deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
...@@ -1025,7 +1052,8 @@ def io_toposort(inputs, outputs, orderings=None, clients=None): ...@@ -1025,7 +1052,8 @@ def io_toposort(inputs, outputs, orderings=None, clients=None):
return rval return rval
topo = general_toposort(outputs, deps=compute_deps, topo = general_toposort(outputs, deps=compute_deps,
clients=clients) compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache, clients=clients)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论