提交 6757ec46 authored 作者: Frederic's avatar Frederic

Speed up io_toposort()

上级 ea2b1686
...@@ -716,7 +716,8 @@ def clone_get_equiv(inputs, outputs, ...@@ -716,7 +716,8 @@ def clone_get_equiv(inputs, outputs,
return memo return memo
def general_toposort(r_out, deps, debug_print=False): def general_toposort(r_out, deps, debug_print=False,
_deps=None, deps_cache=None):
"""WRITEME """WRITEME
:note: :note:
...@@ -727,22 +728,29 @@ def general_toposort(r_out, deps, debug_print=False): ...@@ -727,22 +728,29 @@ def general_toposort(r_out, deps, debug_print=False):
:note: :note:
The order of the return value list is determined by the order of nodes returned by the deps() function. The order of the return value list is determined by the order of nodes returned by the deps() function.
"""
deps_cache = {}
def _deps(io): :note: deps should be provided or can be None and the caller
if io not in deps_cache: provide _deps and deps_cache. The second option remove a
d = deps(io) Python function call, so is faster.
if d:
if not isinstance(d, (list, OrderedSet)): """
raise TypeError("Non-deterministic collections here make" if _deps is None:
deps_cache = {}
def _deps(io):
if io not in deps_cache:
d = deps(io)
if d:
if not isinstance(d, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic.") " toposort non-deterministic.")
deps_cache[io] = list(d) deps_cache[io] = list(d)
else:
deps_cache[io] = d
return d
else: else:
deps_cache[io] = d return deps_cache[io]
return d
else:
return deps_cache[io]
assert isinstance(r_out, (tuple, list, deque)) assert isinstance(r_out, (tuple, list, deque))
...@@ -786,26 +794,54 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -786,26 +794,54 @@ def io_toposort(inputs, outputs, orderings=None):
order. no sets allowed! order. no sets allowed!
""" """
if orderings is None:
orderings = {}
# the inputs are used only here in the function that decides what 'predecessors' to explore # the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(inputs) iset = set(inputs)
def deps(obj): # We build 2 functions as a speed up
rval = [] deps_cache = {}
if obj not in iset:
if isinstance(obj, Variable): deps = None
if obj.owner: _deps = None
rval = [obj.owner] if not orderings: # can be None or empty dict
elif isinstance(obj, Apply): # Specialized function that is faster when no ordering.
rval = list(obj.inputs) # Also include the cache in the function itself for speed up.
rval.extend(orderings.get(obj, [])) def _deps(obj):
else: if obj in deps_cache:
assert not orderings.get(obj, []) return deps_cache[io]
return rval rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
if rval:
if not isinstance(rval, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic.")
deps_cache[obj] = list(rval)
else:
deps_cache[obj] = rval
else:
deps_cache[obj] = rval
return rval
else:
def deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
topo = general_toposort(outputs, deps) topo = general_toposort(outputs, deps=deps, _deps=_deps,
deps_cache=deps_cache)
return [o for o in topo if isinstance(o, Apply)] return [o for o in topo if isinstance(o, Apply)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论