提交 c0fd55c9 authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #44 from goodfeli/refactor_execute

Refactored _execute to make it more readable (using class instead of closure)
......@@ -784,7 +784,7 @@ class CLinker(link.Linker):
init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__(input_storage, output_storage,
keep_lock=keep_lock)
res = _execute(cthunk, init_tasks, tasks, error_storage), in_storage, out_storage
res = _CThunk(cthunk, init_tasks, tasks, error_storage), in_storage, out_storage
return res
def cmodule_key(self):
......@@ -1110,48 +1110,65 @@ class CLinker(link.Linker):
print >> code, " return thunk; }"
return code.getvalue()
def _execute(cthunk, init_tasks, tasks, error_storage):
"""WRITEME"""
global run_cthunk
if run_cthunk is None:
# Lazy import to avoid compilation when importing theano.
from theano.gof.cutils import run_cthunk
class _CThunk(object):
"""
A thunk with a C implementation
"""
def find_task(failure_code):
def __init__(self, cthunk, init_tasks, tasks, error_storage):
"""
Parameters
----------
cthunk: the CObject pointer used by run_cthunk
init_tasks: WRITEME
tasks: WRITEME
error_storage: WRITEME
"""
global run_cthunk
if run_cthunk is None:
# Lazy import to avoid compilation when importing theano.
from theano.gof.cutils import run_cthunk
self.cthunk = cthunk
self.init_tasks = init_tasks
self.tasks = tasks
self.error_storage = error_storage
def find_task(self, failure_code):
"""
Maps a failure code to the task that is associated to it.
"""
failure_code -= 1
n = len(init_tasks)
n = len(self.init_tasks)
# note that the failure code is distributed in two lists
if failure_code < 2 * n:
return [init_tasks, tasks][failure_code % 2][failure_code/2]
return [self.init_tasks, self.tasks][failure_code % 2][failure_code/2]
else:
return tasks[failure_code - n]
def execute():
failure = run_cthunk(cthunk)
return self.tasks[failure_code - n]
def __call__(self):
failure = run_cthunk(self.cthunk)
if failure:
task, taskname, id = find_task(failure)
task, taskname, id = self.find_task(failure)
try:
trace = task.trace
except AttributeError:
trace = ()
try:
exc_type, _exc_value, exc_trace = error_storage
exc_type, _exc_value, exc_trace = self.error_storage
if hasattr(task, "outputs"):
exc_value = exc_type(_exc_value, task, task.outputs)
else:
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace # this can be used to retrieve the location the Op was declared
except Exception:
print >> sys.stderr, 'ERROR retrieving error_storage', error_storage
print >> sys.stderr, 'ERROR retrieving error_storage', self.error_storage
raise
raise exc_type, exc_value, exc_trace
execute.cthunk = cthunk
return execute
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论