提交 55dde830 authored 作者: James Bergstra's avatar James Bergstra

destroyhandler - improved performance with inlined toposort

上级 0626f521
...@@ -7,6 +7,7 @@ if sys.version_info[:2] >= (2,5): ...@@ -7,6 +7,7 @@ if sys.version_info[:2] >= (2,5):
import toolbox import toolbox
import graph import graph
from theano.gof import deque
from env import InconsistencyError from env import InconsistencyError
...@@ -44,6 +45,92 @@ class DestroyHandler(object): ...@@ -44,6 +45,92 @@ class DestroyHandler(object):
return self.map[env].orderings(env) return self.map[env].orderings(env)
def _dfs_toposort(i, r_out, orderings):
"""
i - list of inputs
o - list of outputs
orderings - dict of additions to the normal inputs and outputs
Returns nothing. Raises exception for graph with cycles
"""
#this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C.
assert isinstance(r_out, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings
iset = set(i)
if 0:
def expand(obj):
rval = []
if obj not in iset:
if isinstance(obj, graph.Variable):
if obj.owner:
rval = [obj.owner]
if isinstance(obj, graph.Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
assert not orderings.get(obj, [])
return rval
expand_cache = {}
# reachable, clients = stack_search( deque(r_out), deps, 'dfs', True)
start=deque(r_out)
rval_set = set()
rval_set.add(id(None))
rval_list = list()
expand_inv = {}
sources = deque()
while start:
l = start.pop()# this makes the search dfs
if id(l) not in rval_set:
rval_list.append(l)
rval_set.add(id(l))
if l in iset:
assert not orderings.get(l, [])
expand_l = []
else:
try:
if l.owner:
expand_l = [l.owner]
else:
expand_l = []
except AttributeError:
expand_l = list(l.inputs)
expand_l.extend(orderings.get(l, []))
if expand_l:
for r in expand_l:
expand_inv.setdefault(r, []).append(l)
start.extend(expand_l)
else:
sources.append(l)
expand_cache[l] = expand_l
assert len(rval_list) == len(rval_set)-1
rset = set()
rlist = []
while sources:
node = sources.popleft()
if node not in rset:
rlist.append(node)
rset.add(node)
for client in expand_inv.get(node, []):
expand_cache[client] = [a for a in expand_cache[client] if a is not node]
if not expand_cache[client]:
sources.append(client)
if len(rlist) != len(rval_list):
raise ValueError('graph contains cycles')
#return [o for o in rlist if isinstance(o, graph.Apply)]
def getroot(r, view_i): def getroot(r, view_i):
""" """
For views: Return non-view variable which is ultimatly viewed by r. For views: Return non-view variable which is ultimatly viewed by r.
...@@ -303,7 +390,8 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper): ...@@ -303,7 +390,8 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
raise raise
#print 'orderings:', ords #print 'orderings:', ords
try: try:
graph.io_toposort(env.inputs, env.outputs, ords) ### graph.io_toposort(env.inputs, env.outputs, ords)
_dfs_toposort(env.inputs, env.outputs, ords)
except ValueError, e: except ValueError, e:
#print 'not passing.', ords #print 'not passing.', ords
if 'cycles' in str(e): if 'cycles' in str(e):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论