CLinker tolerates duplicate inputs

上级 d676edaf
......@@ -234,16 +234,12 @@ class CLinker(Linker):
env = self.env
self.inputs = env.inputs
if len(set(self.inputs)) != len(self.inputs):
raise Exception("CLinker doesn't support duplicate inputs.")
self.outputs = env.outputs
if len(set(self.outputs)) != len(self.outputs):
raise Exception("CLinker doesn't support duplicate outputs.")
try: self.results = list(env.results())
except AttributeError: self.results = self.inputs + self.outputs
try: self.orphans = list(env.orphans())
try: self.orphans = list(env.orphans().difference(self.outputs))
except AttributeError: self.orphans = []
try: self.temps = list(set(self.results).difference(self.inputs).difference(self.outputs).difference(self.orphans))
......@@ -273,8 +269,9 @@ class CLinker(Linker):
id = 0
sub = dict(failure_var = failure_var)
for result in set(self.results):
for result in self.results:
if getattr(result, 'constant', False):
if result in self.outputs or result in self.temps:
raise Exception("Temporaries and outputs should not be marked constant. Check your graph.")
......@@ -366,7 +363,7 @@ class CLinker(Linker):
args = []
in_arg_order = []
args += ["storage_%s" % symbol[result] for result in self.inputs + self.outputs + self.orphans]
args += ["storage_%s" % symbol[result] for result in utils.uniq(self.inputs + self.outputs + self.orphans)]
struct_code = struct_gen(args, init_blocks, blocks, dict(failure_var = failure_var))
......@@ -384,6 +381,8 @@ class CLinker(Linker):
self.init_tasks = init_tasks
self.blocks = blocks
self.tasks = tasks
all = self.inputs + self.outputs + self.orphans
self.dupidx = [i for i, x in enumerate(all) if all.count(x) > 1 and all.index(x) != i]
def find_task(self, failure_code):
......@@ -445,28 +444,31 @@ class CLinker(Linker):
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
return execute, in_results, out_results
def make_function(self, inplace = False, unpack_single = True):
cthunk, in_results, out_results, error_storage = self.__compile__(inplace)
# out_storage = [result._data for result in out_results]
# def make_function(self, inplace = False, unpack_single = True):
# cthunk, in_results, out_results, error_storage = self.__compile__(inplace)
# # out_storage = [result._data for result in out_results]
def execute(*args):
for arg, result in zip(args, in_results):
result.data = arg
failure = cutils.run_cthunk(cthunk)
if failure:
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
if unpack_single:
return utils.to_return_values([result.data for result in out_results])
else:
return [result.data for result in out_results]
# return utils.to_return_values([storage[0] for storage in out_storage])
return execute
# def execute(*args):
# for arg, result in zip(args, in_results):
# result.data = arg
# failure = cutils.run_cthunk(cthunk)
# if failure:
# raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
# if unpack_single:
# return utils.to_return_values([result.data for result in out_results])
# else:
# return [result.data for result in out_results]
# # return utils.to_return_values([storage[0] for storage in out_storage])
# return execute
def cthunk_factory(self, error_storage, in_storage, out_storage):
if not getattr(self, 'instantiate', False):
self.code_gen()
out_storage = [x for i, x in enumerate(out_storage) if (i+len(in_storage)) not in self.dupidx]
in_storage = [x for i, x in enumerate(in_storage) if i not in self.dupidx]
cthunk = object()
module_name = self.hash
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论