提交 b6834f2f authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed bugs in env and graph

上级 6092703d
......@@ -384,8 +384,8 @@ class Env(utils.object2):
raise Exception("Input of node should belong to the env.", result, (node, i))
if (node, i) not in result.clients:
raise Exception("Inconsistent clients list.", (node, i), result.clients)
results = graph.results(self.inputs, self.outputs)
if self.results != results:
results = set(graph.results(self.inputs, self.outputs))
if set(self.results) != results:
missing = results.difference(self.results)
excess = self.results.difference(results)
raise Exception("The results are inappropriately cached. missing, in excess: ", missing, excess)
......@@ -414,9 +414,11 @@ class Env(utils.object2):
return self.clone_get_equiv()[0]
def clone_get_equiv(self):
g, equiv = graph.clone_get_equiv(self.inputs, self.outputs)
equiv = graph.clone_get_equiv(self.inputs, self.outputs)
self.check_integrity()
e = Env([equiv[i] for i in self.inputs],
[equiv[o] for o in self.outputs])
e.check_integrity()
for feature in self._features:
e.extend(feature)
return e, equiv
......
......@@ -10,6 +10,36 @@ from env import InconsistencyError
class DestroyHandler(toolbox.Bookkeeper):
def __init__(self):
self.map = {}
def on_attach(self, env):
dh = self.map.setdefault(env, DestroyHandlerHelper())
dh.on_attach(env)
def on_detach(self, env):
self.map[env].on_detach(env)
def on_import(self, env, op):
self.map[env].on_import(env, op)
def on_prune(self, env, op):
self.map[env].on_prune(env, op)
def on_change_input(self, env, node, i, r, new_r):
self.map[env].on_change_input(env, node, i, r, new_r)
def validate(self, env):
self.map[env].validate(env)
def orderings(self, env):
return self.map[env].orderings(env)
class DestroyHandlerHelper(toolbox.Bookkeeper):
"""
This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over
......
......@@ -241,7 +241,7 @@ def results_and_orphans(i, o):
"""
def expand(r):
if r.owner and r not in i:
l = list(r.owner.inputs)
l = list(r.owner.inputs) + list(r.owner.outputs)
l.reverse()
return l
results = stack_search(deque(o), expand, 'dfs')
......@@ -316,7 +316,7 @@ def clone(i, o, copy_inputs = True):
return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
"""
@type i: list
@param i: input L{Result}s
......
......@@ -264,7 +264,7 @@ class MetaLinker(Linker):
self.wrapper = wrapper
self.no_recycling = no_recycling
def pre(self, order, thunk_groups):
def pre(self, f, inputs, order, thunk_groups):
pass
def make_thunk(self, **kwargs):
......@@ -272,11 +272,15 @@ class MetaLinker(Linker):
You can pass an alternate env to use with the 'alt_env'
option.
The 'wrapf' option must be a function that will be used
to wrap the thunk (eg to add methods to it).
The rest of the options will be passed to all the linkers
associated with this MetaLinker.
"""
env = kwargs.pop("alt_env", self.env)
wrapf = kwargs.pop("wrapf", None)
no_recycling = self.no_recycling
fns, input_lists, output_lists, thunk_lists, order_lists = zip(*[linker(env, no_recycling = no_recycling).make_all(**kwargs)
......@@ -308,12 +312,15 @@ class MetaLinker(Linker):
input2.storage[0] = copy(input1.storage[0])
for x in to_reset:
x[0] = None
pre(inputs, order, thunk_groups)
pre(f, [input.data for input in input_lists[0]], order, thunk_groups)
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try:
wrapper(i, node, *thunks)
wrapper(f, i, node, *thunks)
except:
raise_with_op(node)
if wrapf is not None:
f = wrapf(f)
return f, inputs0, outputs0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论