提交 b75cf2e1 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2823 from nouiz/faster_test

finish gh-2768
......@@ -717,7 +717,7 @@ def clone_get_equiv(inputs, outputs,
def general_toposort(r_out, deps, debug_print=False,
_deps=None, deps_cache=None):
compute_deps_cache=None, deps_cache=None):
"""WRITEME
:note:
......@@ -729,15 +729,24 @@ def general_toposort(r_out, deps, debug_print=False,
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
:param deps: a python function that take a node as input and
return its dependence.
:param compute_deps_cache: Optional,
if provided deps_cache should also be provided. This is a
function like deps, but that also cache its results in a dict
passed as deps_cache.
:param deps_cache: a dict. Must be used with compute_deps_cache.
:note: deps should be provided or can be None and the caller
provide _deps and deps_cache. The second option remove a
Python function call, so is faster.
provide compute_deps_cache and deps_cache. The second option
remove a Python function call, and allow for more specialized
code, so it can be faster.
"""
if _deps is None:
if compute_deps_cache is None:
deps_cache = {}
def _deps(io):
def compute_deps_cache(io):
if io not in deps_cache:
d = deps(io)
if d:
......@@ -751,10 +760,12 @@ def general_toposort(r_out, deps, debug_print=False,
return d
else:
return deps_cache[io]
assert deps_cache is not None
assert isinstance(r_out, (tuple, list, deque))
reachable, clients = stack_search(deque(r_out), _deps, 'dfs', True)
reachable, clients = stack_search(deque(r_out), compute_deps_cache,
'dfs', True)
sources = deque([r for r in reachable if not deps_cache.get(r, None)])
rset = set()
......@@ -800,12 +811,12 @@ def io_toposort(inputs, outputs, orderings=None):
# We build 2 functions as a speed up
deps_cache = {}
deps = None
_deps = None
compute_deps = None
compute_deps_cache = None
if not orderings: # can be None or empty dict
# Specialized function that is faster when no ordering.
# Also include the cache in the function itself for speed up.
def _deps(obj):
def compute_deps_cache(obj):
if obj in deps_cache:
return deps_cache[io]
rval = []
......@@ -827,7 +838,7 @@ def io_toposort(inputs, outputs, orderings=None):
deps_cache[obj] = rval
return rval
else:
def deps(obj):
def compute_deps(obj):
rval = []
if obj not in iset:
if isinstance(obj, Variable):
......@@ -840,7 +851,8 @@ def io_toposort(inputs, outputs, orderings=None):
assert not orderings.get(obj, [])
return rval
topo = general_toposort(outputs, deps=deps, _deps=_deps,
topo = general_toposort(outputs, deps=compute_deps,
compute_deps_cache=compute_deps_cache,
deps_cache=deps_cache)
return [o for o in topo if isinstance(o, Apply)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论