提交 6818d012 authored 作者: James Bergstra's avatar James Bergstra

refactored orderings calculation in Env

上级 98213004
...@@ -311,6 +311,9 @@ class Env(utils.object2): ...@@ -311,6 +311,9 @@ class Env(utils.object2):
self.__import_r__([new_r]) self.__import_r__([new_r])
self.__add_clients__(new_r, [(node, i)]) self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
# Precondition: the substitution is semantically valid
# However it may introduce cycles to the graph, in which case the
# transaction will be reverted later.
self.execute_callbacks('on_change_input', node, i, r, new_r, reason=reason) self.execute_callbacks('on_change_input', node, i, r, new_r, reason=reason)
if prune: if prune:
...@@ -438,16 +441,32 @@ class Env(utils.object2): ...@@ -438,16 +441,32 @@ class Env(utils.object2):
if len(self.nodes) < 2: if len(self.nodes) < 2:
# optimization # optimization
# when there are 0 or 1 nodes, no sorting is necessary # when there are 0 or 1 nodes, no sorting is necessary
# This special case happens a lot because the OpWiseCLinker produces
# 1-element graphs.
return list(self.nodes) return list(self.nodes)
env = self env = self
ords = {} ords = self.orderings()
for feature in env._features:
if hasattr(feature, 'orderings'):
for op, prereqs in feature.orderings(env).items():
ords.setdefault(op, []).extend(prereqs)
order = graph.io_toposort(env.inputs, env.outputs, ords) order = graph.io_toposort(env.inputs, env.outputs, ords)
return order return order
def orderings(self):
"""
Return dict d s.t. d[node] is a list of nodes that must be evaluated
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that all
clients of any destroyed inputs have already computed their outputs.
"""
ords = {}
for feature in self._features:
if hasattr(feature, 'orderings'):
for node, prereqs in feature.orderings(self).items():
ords.setdefault(node, []).extend(prereqs)
# eliminate duplicate prereqs
for (node,prereqs) in ords.items():
ords[node] = list(set(prereqs))
return ords
def nclients(self, r): def nclients(self, r):
"""WRITEME Same as len(self.clients(r)).""" """WRITEME Same as len(self.clients(r))."""
return len(self.clients(r)) return len(self.clients(r))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论