提交 6757ec46 authored 作者: Frederic's avatar Frederic

Speed up io_toposort()

上级 ea2b1686
...@@ -716,7 +716,8 @@ def clone_get_equiv(inputs, outputs, ...@@ -716,7 +716,8 @@ def clone_get_equiv(inputs, outputs,
return memo return memo
def general_toposort(r_out, deps, debug_print=False): def general_toposort(r_out, deps, debug_print=False,
_deps=None, deps_cache=None):
"""WRITEME """WRITEME
:note: :note:
...@@ -727,7 +728,13 @@ def general_toposort(r_out, deps, debug_print=False): ...@@ -727,7 +728,13 @@ def general_toposort(r_out, deps, debug_print=False):
:note: :note:
The order of the return value list is determined by the order of nodes returned by the deps() function. The order of the return value list is determined by the order of nodes returned by the deps() function.
:note: deps should be provided, or it can be None if the caller
provides _deps and deps_cache instead. The second option removes a
Python function call, so it is faster.
""" """
if _deps is None:
deps_cache = {} deps_cache = {}
def _deps(io): def _deps(io):
...@@ -735,7 +742,8 @@ def general_toposort(r_out, deps, debug_print=False): ...@@ -735,7 +742,8 @@ def general_toposort(r_out, deps, debug_print=False):
d = deps(io) d = deps(io)
if d: if d:
if not isinstance(d, (list, OrderedSet)): if not isinstance(d, (list, OrderedSet)):
raise TypeError("Non-deterministic collections here make" raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic.") " toposort non-deterministic.")
deps_cache[io] = list(d) deps_cache[io] = list(d)
else: else:
...@@ -786,12 +794,39 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -786,12 +794,39 @@ def io_toposort(inputs, outputs, orderings=None):
order. no sets allowed! order. no sets allowed!
""" """
if orderings is None:
orderings = {}
# the inputs are used only here in the function that decides what 'predecessors' to explore # the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(inputs) iset = set(inputs)
# We build 2 functions as a speed up
deps_cache = {}
deps = None
_deps = None
if not orderings: # can be None or empty dict
# Specialized function that is faster when there are no orderings.
# It also performs the cache lookup itself, as a speed-up.
def _deps(obj):
    """Return the direct dependencies of ``obj``, memoized in ``deps_cache``.

    This is the fast path used when no extra orderings are given.  It
    reads ``iset`` and ``deps_cache`` from the enclosing scope.

    :param obj: a graph object; ``Variable`` and ``Apply`` instances are
        the only kinds that can yield dependencies here.
    :return: a list of dependencies: ``[obj.owner]`` for a ``Variable``
        that has an owner, ``list(obj.inputs)`` for an ``Apply``, and an
        empty list for anything in ``iset`` or of any other kind.
    """
    # Cache hit: index with the same key that was just tested.  Using any
    # other name here would raise NameError on the second visit of a node.
    if obj in deps_cache:
        return deps_cache[obj]

    rval = []
    # Members of ``iset`` are treated as having no dependencies.
    if obj not in iset:
        if isinstance(obj, Variable):
            if obj.owner:
                rval = [obj.owner]
        elif isinstance(obj, Apply):
            rval = list(obj.inputs)

    # ``rval`` is always a list built above, so no check for
    # non-deterministic collection types is needed.  A non-empty result
    # is cached as a copy; the empty list is cached as-is.
    deps_cache[obj] = list(rval) if rval else rval
    return rval
else:
def deps(obj): def deps(obj):
rval = [] rval = []
if obj not in iset: if obj not in iset:
...@@ -805,7 +840,8 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -805,7 +840,8 @@ def io_toposort(inputs, outputs, orderings=None):
assert not orderings.get(obj, []) assert not orderings.get(obj, [])
return rval return rval
topo = general_toposort(outputs, deps) topo = general_toposort(outputs, deps=deps, _deps=_deps,
deps_cache=deps_cache)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论