提交 703de495 authored 作者: Frederic Bastien's avatar Frederic Bastien

changed CLinker.cmodule_key to do the real computation in a static method to…

changed CLinker.cmodule_key to do the real computation in a static method to allow calling it from elsewhere.
上级 69ec63f0
...@@ -792,14 +792,24 @@ class CLinker(link.Linker): ...@@ -792,14 +792,24 @@ class CLinker(link.Linker):
function raises a KeyError exception. function raises a KeyError exception.
""" """
order = list(self.env.toposort()) return self.cmodule_key_(self.env, self.no_recycling,
env_inputs_dict = dict((i, [-1, pos]) for pos, i in enumerate(self.env.inputs)) compile_args=self.compile_args(),
libraries=self.libraries()
)
@staticmethod
def cmodule_key_(env, no_recycling, compile_args=None, libraries=None):
"""
Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__
"""
order = list(env.toposort())
env_computed_set = set() env_computed_set = set()
env_inputs_dict = dict((i, [-1, pos]) for pos, i in enumerate(env.inputs))
constant_ids = dict() constant_ids = dict()
op_pos = {} # Apply -> topological position op_pos = {} # Apply -> topological position
rval = ['CLinker.cmodule_key'] # will be cast to tuple on return rval = ['CLinker.cmodule_key'] # will be cast to tuple on return
rval.append(tuple(self.compile_args())) if compile_args is not None: rval.append(tuple(compile_args))
rval.append(tuple(self.libraries())) if libraries is not None: rval.append(tuple(libraries))
version = [] version = []
# assert that every input to every node is one of' # assert that every input to every node is one of'
...@@ -822,16 +832,16 @@ class CLinker(link.Linker): ...@@ -822,16 +832,16 @@ class CLinker(link.Linker):
else: else:
if i.owner is None: if i.owner is None:
assert all( all(out is not None for out in o.outputs) for o in order) assert all( all(out is not None for out in o.outputs) for o in order)
assert all( input.owner is None for input in self.env.inputs) assert all( input.owner is None for input in env.inputs)
raise Exception('what is this?', (i, type(i), i.clients, self.env)) raise Exception('what is this?', (i, type(i), i.clients, env))
if i in self.env.outputs: if i in env.outputs:
rval += [op_pos[i.owner], # outputs rval += [op_pos[i.owner], # outputs
i.owner.outputs.index(i), i.owner.outputs.index(i),
self.env.outputs.index(i)] env.outputs.index(i)]
else: else:
rval += [op_pos[i.owner], i.owner.outputs.index(i)] # temps rval += [op_pos[i.owner], i.owner.outputs.index(i)] # temps
assert rval assert rval
rval.append(i in self.no_recycling) rval.append(i in no_recycling)
return tuple(rval) return tuple(rval)
for node_pos, node in enumerate(order): for node_pos, node in enumerate(order):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论