CLinker tolerates duplicate inputs

上级 d676edaf
...@@ -234,16 +234,12 @@ class CLinker(Linker): ...@@ -234,16 +234,12 @@ class CLinker(Linker):
env = self.env env = self.env
self.inputs = env.inputs self.inputs = env.inputs
if len(set(self.inputs)) != len(self.inputs):
raise Exception("CLinker doesn't support duplicate inputs.")
self.outputs = env.outputs self.outputs = env.outputs
if len(set(self.outputs)) != len(self.outputs):
raise Exception("CLinker doesn't support duplicate outputs.")
try: self.results = list(env.results()) try: self.results = list(env.results())
except AttributeError: self.results = self.inputs + self.outputs except AttributeError: self.results = self.inputs + self.outputs
try: self.orphans = list(env.orphans()) try: self.orphans = list(env.orphans().difference(self.outputs))
except AttributeError: self.orphans = [] except AttributeError: self.orphans = []
try: self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans)) try: self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
...@@ -274,7 +270,8 @@ class CLinker(Linker): ...@@ -274,7 +270,8 @@ class CLinker(Linker):
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
for result in self.results: for result in set(self.results):
if getattr(result, 'constant', False): if getattr(result, 'constant', False):
if result in self.outputs or result in self.temps: if result in self.outputs or result in self.temps:
raise Exception("Temporaries and outputs should not be marked constant. Check your graph.") raise Exception("Temporaries and outputs should not be marked constant. Check your graph.")
...@@ -366,7 +363,7 @@ class CLinker(Linker): ...@@ -366,7 +363,7 @@ class CLinker(Linker):
args = [] args = []
in_arg_order = [] in_arg_order = []
args += ["storage_%s" % symbol[result] for result in self.inputs + self.outputs + self.orphans] args += ["storage_%s" % symbol[result] for result in utils.uniq(self.inputs + self.outputs + self.orphans)]
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var)) struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var))
...@@ -384,6 +381,8 @@ class CLinker(Linker): ...@@ -384,6 +381,8 @@ class CLinker(Linker):
self.init_tasks = init_tasks self.init_tasks = init_tasks
self.blocks = blocks self.blocks = blocks
self.tasks = tasks self.tasks = tasks
all = self.inputs + self.outputs + self.orphans
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
def find_task(self, failure_code): def find_task(self, failure_code):
...@@ -445,29 +444,32 @@ class CLinker(Linker): ...@@ -445,29 +444,32 @@ class CLinker(Linker):
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1)) raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
return execute, in_results, out_results return execute, in_results, out_results
def make_function(self, inplace = False, unpack_single = True): # def make_function(self, inplace = False, unpack_single = True):
cthunk, in_results, out_results, error_storage = self.__compile__(inplace) # cthunk, in_results, out_results, error_storage = self.__compile__(inplace)
# out_storage = [result._data for result in out_results] # # out_storage = [result._data for result in out_results]
def execute(*args): # def execute(*args):
for arg, result in zip(args, in_results): # for arg, result in zip(args, in_results):
result.data = arg # result.data = arg
failure = cutils.run_cthunk(cthunk) # failure = cutils.run_cthunk(cthunk)
if failure: # if failure:
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1)) # raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
if unpack_single: # if unpack_single:
return utils.to_return_values([result.data for result in out_results]) # return utils.to_return_values([result.data for result in out_results])
else: # else:
return [result.data for result in out_results] # return [result.data for result in out_results]
# return utils.to_return_values([storage[0] for storage in out_storage]) # # return utils.to_return_values([storage[0] for storage in out_storage])
return execute # return execute
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
if not getattr(self, 'instantiate', False): if not getattr(self, 'instantiate', False):
self.code_gen() self.code_gen()
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
cthunk = object() cthunk = object()
module_name = self.hash module_name = self.hash
mod = weave.ext_tools.ext_module(module_name) mod = weave.ext_tools.ext_module(module_name)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论