提交 b6834f2f authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed bugs in env and graph

上级 6092703d
...@@ -384,8 +384,8 @@ class Env(utils.object2): ...@@ -384,8 +384,8 @@ class Env(utils.object2):
raise Exception("Input of node should belong to the env.", result, (node, i)) raise Exception("Input of node should belong to the env.", result, (node, i))
if (node, i) not in result.clients: if (node, i) not in result.clients:
raise Exception("Inconsistent clients list.", (node, i), result.clients) raise Exception("Inconsistent clients list.", (node, i), result.clients)
results = graph.results(self.inputs, self.outputs) results = set(graph.results(self.inputs, self.outputs))
if self.results != results: if set(self.results) != results:
missing = results.difference(self.results) missing = results.difference(self.results)
excess = self.results.difference(results) excess = self.results.difference(results)
raise Exception("The results are inappropriately cached. missing, in excess: ", missing, excess) raise Exception("The results are inappropriately cached. missing, in excess: ", missing, excess)
...@@ -414,9 +414,11 @@ class Env(utils.object2): ...@@ -414,9 +414,11 @@ class Env(utils.object2):
return self.clone_get_equiv()[0] return self.clone_get_equiv()[0]
def clone_get_equiv(self): def clone_get_equiv(self):
g, equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
self.check_integrity()
e = Env([equiv[i] for i in self.inputs], e = Env([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs]) [equiv[o] for o in self.outputs])
e.check_integrity()
for feature in self._features: for feature in self._features:
e.extend(feature) e.extend(feature)
return e, equiv return e, equiv
......
...@@ -10,6 +10,36 @@ from env import InconsistencyError ...@@ -10,6 +10,36 @@ from env import InconsistencyError
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper):
def __init__(self):
self.map = {}
def on_attach(self, env):
dh = self.map.setdefault(env, DestroyHandlerHelper())
dh.on_attach(env)
def on_detach(self, env):
self.map[env].on_detach(env)
def on_import(self, env, op):
self.map[env].on_import(env, op)
def on_prune(self, env, op):
self.map[env].on_prune(env, op)
def on_change_input(self, env, node, i, r, new_r):
self.map[env].on_change_input(env, node, i, r, new_r)
def validate(self, env):
self.map[env].validate(env)
def orderings(self, env):
return self.map[env].orderings(env)
class DestroyHandlerHelper(toolbox.Bookkeeper):
""" """
This feature ensures that an env represents a consistent data flow This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over when some Ops overwrite their inputs and/or provide "views" over
......
...@@ -241,7 +241,7 @@ def results_and_orphans(i, o): ...@@ -241,7 +241,7 @@ def results_and_orphans(i, o):
""" """
def expand(r): def expand(r):
if r.owner and r not in i: if r.owner and r not in i:
l = list(r.owner.inputs) l = list(r.owner.inputs) + list(r.owner.outputs)
l.reverse() l.reverse()
return l return l
results = stack_search(deque(o), expand, 'dfs') results = stack_search(deque(o), expand, 'dfs')
...@@ -316,7 +316,7 @@ def clone(i, o, copy_inputs = True): ...@@ -316,7 +316,7 @@ def clone(i, o, copy_inputs = True):
return [equiv[input] for input in i], [equiv[output] for output in o] return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs_and_orphans = False): def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
""" """
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
......
...@@ -264,7 +264,7 @@ class MetaLinker(Linker): ...@@ -264,7 +264,7 @@ class MetaLinker(Linker):
self.wrapper = wrapper self.wrapper = wrapper
self.no_recycling = no_recycling self.no_recycling = no_recycling
def pre(self, order, thunk_groups): def pre(self, f, inputs, order, thunk_groups):
pass pass
def make_thunk(self, **kwargs): def make_thunk(self, **kwargs):
...@@ -272,11 +272,15 @@ class MetaLinker(Linker): ...@@ -272,11 +272,15 @@ class MetaLinker(Linker):
You can pass an alternate env to use with the 'alt_env' You can pass an alternate env to use with the 'alt_env'
option. option.
The 'wrapf' option must be a function that will be used
to wrap the thunk (eg to add methods to it).
The rest of the options will be passed to all the linkers The rest of the options will be passed to all the linkers
associated with this MetaLinker. associated with this MetaLinker.
""" """
env = kwargs.pop("alt_env", self.env) env = kwargs.pop("alt_env", self.env)
wrapf = kwargs.pop("wrapf", None)
no_recycling = self.no_recycling no_recycling = self.no_recycling
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*[linker(env, no_recycling = no_recycling).make_all(**kwargs) fns, input_lists, output_lists, thunk_lists, order_lists = zip(*[linker(env, no_recycling = no_recycling).make_all(**kwargs)
...@@ -308,13 +312,16 @@ class MetaLinker(Linker): ...@@ -308,13 +312,16 @@ class MetaLinker(Linker):
input2.storage[0] = copy(input1.storage[0]) input2.storage[0] = copy(input1.storage[0])
for x in to_reset: for x in to_reset:
x[0] = None x[0] = None
pre(inputs, order, thunk_groups) pre(f, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)): for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try: try:
wrapper(i, node, *thunks) wrapper(f, i, node, *thunks)
except: except:
raise_with_op(node) raise_with_op(node)
if wrapf is not None:
f = wrapf(f)
return f, inputs0, outputs0 return f, inputs0, outputs0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论