提交 6757ec46 authored 作者: Frederic's avatar Frederic

Speed up io_toposort()

上级 ea2b1686
......@@ -716,7 +716,8 @@ def clone_get_equiv(inputs, outputs,
return memo
def general_toposort(r_out, deps, debug_print=False):
def general_toposort(r_out, deps, debug_print=False,
_deps=None, deps_cache=None):
"""WRITEME
:note:
......@@ -727,22 +728,29 @@ def general_toposort(r_out, deps, debug_print=False):
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
"""
deps_cache = {}
def _deps(io):
if io not in deps_cache:
d = deps(io)
if d:
if not isinstance(d, (list, OrderedSet)):
raise TypeError("Non-deterministic collections here make"
:note: deps should be provided or can be None and the caller
provide _deps and deps_cache. The second option remove a
Python function call, so is faster.
"""
if _deps is None:
deps_cache = {}
def _deps(io):
if io not in deps_cache:
d = deps(io)
if d:
if not isinstance(d, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic.")
deps_cache[io] = list(d)
deps_cache[io] = list(d)
else:
deps_cache[io] = d
return d
else:
deps_cache[io] = d
return d
else:
return deps_cache[io]
return deps_cache[io]
assert isinstance(r_out, (tuple, list, deque))
......@@ -786,26 +794,54 @@ def io_toposort(inputs, outputs, orderings=None):
order. no sets allowed!
"""
if orderings is None:
orderings = {}
# the inputs are used only here in the function that decides what 'predecessors' to explore
iset = set(inputs)
def deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
# We build 2 functions as a speed up
deps_cache = {}
deps = None
_deps = None
if not orderings: # can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.
def _deps(obj):
if obj in deps_cache:
return deps_cache[io]
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
if rval:
if not isinstance(rval, (list, OrderedSet)):
raise TypeError(
"Non-deterministic collections here make"
" toposort non-deterministic.")
deps_cache[obj] = list(rval)
else:
deps_cache[obj] = rval
else:
deps_cache[obj] = rval
return rval
else:
def deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
elif isinstance(obj, Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
topo = general_toposort(outputs, deps)
topo = general_toposort(outputs, deps=deps, _deps=_deps,
deps_cache=deps_cache)
return [o for o in topo if isinstance(o, Apply)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论