提交 c0fd55c9 authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #44 from goodfeli/refactor_execute

Refactored _execute to make it more readable (using class instead of closure)
...@@ -784,7 +784,7 @@ class CLinker(link.Linker): ...@@ -784,7 +784,7 @@ class CLinker(link.Linker):
init_tasks, tasks = self.get_init_tasks() init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage, cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage,
keep_lock=keep_lock) keep_lock=keep_lock)
res = _execute(cthunk, init_tasks, tasks, error_storage), in_storage, out_storage res = _CThunk(cthunk, init_tasks, tasks, error_storage), in_storage, out_storage
return res return res
def cmodule_key(self): def cmodule_key(self):
...@@ -1110,48 +1110,65 @@ class CLinker(link.Linker): ...@@ -1110,48 +1110,65 @@ class CLinker(link.Linker):
print >> code, " return thunk; }" print >> code, " return thunk; }"
return code.getvalue() return code.getvalue()
def _execute(cthunk, init_tasks, tasks, error_storage): class _CThunk(object):
"""WRITEME""" """
A thunk with a C implementation
"""
def __init__(self, cthunk, init_tasks, tasks, error_storage):
"""
Parameters
----------
cthunk: the CObject pointer used by run_cthunk
init_tasks: WRITEME
tasks: WRITEME
error_storage: WRITEME
"""
global run_cthunk global run_cthunk
if run_cthunk is None: if run_cthunk is None:
# Lazy import to avoid compilation when importing theano. # Lazy import to avoid compilation when importing theano.
from theano.gof.cutils import run_cthunk from theano.gof.cutils import run_cthunk
self.cthunk = cthunk
self.init_tasks = init_tasks
self.tasks = tasks
self.error_storage = error_storage
def find_task(failure_code): def find_task(self, failure_code):
""" """
Maps a failure code to the task that is associated to it. Maps a failure code to the task that is associated to it.
""" """
failure_code -= 1 failure_code -= 1
n = len(init_tasks) n = len(self.init_tasks)
# note that the failure code is distributed in two lists # note that the failure code is distributed in two lists
if failure_code < 2 * n: if failure_code < 2 * n:
return [init_tasks, tasks][failure_code % 2][failure_code/2] return [self.init_tasks, self.tasks][failure_code % 2][failure_code/2]
else: else:
return tasks[failure_code - n] return self.tasks[failure_code - n]
def execute():
failure = run_cthunk(cthunk) def __call__(self):
failure = run_cthunk(self.cthunk)
if failure: if failure:
task, taskname, id = find_task(failure) task, taskname, id = self.find_task(failure)
try: try:
trace = task.trace trace = task.trace
except AttributeError: except AttributeError:
trace = () trace = ()
try: try:
exc_type, _exc_value, exc_trace = error_storage exc_type, _exc_value, exc_trace = self.error_storage
if hasattr(task, "outputs"): if hasattr(task, "outputs"):
exc_value = exc_type(_exc_value, task, task.outputs) exc_value = exc_type(_exc_value, task, task.outputs)
else: else:
exc_value = exc_type(_exc_value, task) exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
except Exception: except Exception:
print >> sys.stderr, 'ERROR retrieving error_storage', error_storage print >> sys.stderr, 'ERROR retrieving error_storage', self.error_storage
raise raise
raise exc_type, exc_value, exc_trace raise exc_type, exc_value, exc_trace
execute.cthunk = cthunk
return execute
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论