traceback is now stored in the Ops and printed out by a custom excepthook when an exception occurs

上级 da3a7ffd
import unittest
import numpy
from tensor import tensor, Tensor
import gof
from gof import modes, Env
from elemwise import *
class ElemwiseAdd(Elemwise):
def var_desc(self):
return [('x', 1), ('y', 1)], [('z', 1)]
# def destroy_map(self):
# return {self.out: [self.inputs[0]]}
def c_code_foreach(self):
return "%(z)s_i = %(x)s_i + %(y)s_i;"
def inputs():
l1 = [[1.0, 2.0], [3.0, 4.0]]
l2 = [[3.0, 4.0], [1.0, 2.0]]
l3 = numpy.ones((2, 3))
x = modes.build(tensor(l1, name = 'x'))
y = modes.build(tensor(l2, name = 'y'))
z = modes.build(tensor(l3, name = 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_Elemwise(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = ElemwiseAdd(x, y).out
fn, i, o = gof.CLinker(env([x, y], [e])).make_thunk(True)
fn()
assert (e.data == numpy.array([[4, 6], [4, 6]])).all()
x.data.resize((1, 4))
y.data.resize((1, 4))
fn()
assert (e.data == numpy.array([[4, 6, 4, 6]])).all()
if __name__ == '__main__':
unittest.main()
...@@ -31,7 +31,7 @@ class Double(ResultBase): ...@@ -31,7 +31,7 @@ class Double(ResultBase):
return """ return """
%(name)s = 0; %(name)s = 0;
%(name)s_bad_thing = malloc(100000); %(name)s_bad_thing = malloc(100000);
printf("Initializing %(name)s\\n"); //printf("Initializing %(name)s\\n");
""" """
def c_literal(self): def c_literal(self):
...@@ -45,7 +45,7 @@ class Double(ResultBase): ...@@ -45,7 +45,7 @@ class Double(ResultBase):
} }
%(name)s = PyFloat_AsDouble(py_%(name)s); %(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s_bad_thing = NULL; %(name)s_bad_thing = NULL;
printf("Extracting %(name)s\\n"); //printf("Extracting %(name)s\\n");
""" """
def c_sync(self): def c_sync(self):
...@@ -54,12 +54,12 @@ class Double(ResultBase): ...@@ -54,12 +54,12 @@ class Double(ResultBase):
py_%(name)s = PyFloat_FromDouble(%(name)s); py_%(name)s = PyFloat_FromDouble(%(name)s);
if (!py_%(name)s) if (!py_%(name)s)
py_%(name)s = Py_None; py_%(name)s = Py_None;
printf("Syncing %(name)s\\n"); //printf("Syncing %(name)s\\n");
""" """
def c_cleanup(self): def c_cleanup(self):
return """ return """
printf("Cleaning up %(name)s\\n"); //printf("Cleaning up %(name)s\\n");
if (%(name)s_bad_thing) if (%(name)s_bad_thing)
free(%(name)s_bad_thing); free(%(name)s_bad_thing);
""" """
...@@ -101,6 +101,13 @@ class Mul(Binary): ...@@ -101,6 +101,13 @@ class Mul(Binary):
return "%(z)s = %(x)s * %(y)s;" return "%(z)s = %(x)s * %(y)s;"
class Div(Binary): class Div(Binary):
def c_validate_update(self):
return """
if (%(y)s == 0.0) {
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
%(fail)s
}
"""
def c_code(self): def c_code(self):
return "%(z)s = %(x)s / %(y)s;" return "%(z)s = %(x)s / %(y)s;"
...@@ -115,48 +122,62 @@ def inputs(): ...@@ -115,48 +122,62 @@ def inputs():
return x, y, z return x, y, z
def env(inputs, outputs, validate = True, features = []): def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate) return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_CLinker(unittest.TestCase): class _test_CLinker(unittest.TestCase):
def test_0(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e])) #, [x.r, y.r, z.r], [e.r]) lnk = CLinker(env([x, y, z], [e]))
cgen = lnk.code_gen() fn = lnk.make_function()
fn = lnk.make_function() #[x.r, y.r, z.r], [e.r]) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
print fn(2.0, 2.0, 2.0)
# fn = 0
def test_1(self): def test_orphan(self):
x, y, z = inputs() x, y, z = inputs()
z.constant = True z.data = 4.12345678
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e])) #, [x.r, y.r], [e.r]) lnk = CLinker(env([x, y], [e]))
cgen = lnk.code_gen() fn = lnk.make_function()
fn = lnk.make_function() #[x.r, y.r], [e.r]) self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
print fn(2.0, 2.0) self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined
# fn = 0
def test_2(self): def test_literal_inlining(self):
x, y, z = inputs()
z.data = 4.12345678
z.constant = True # this should tell the compiler to inline z as a literal
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e]))
fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined
def test_single_op(self):
x, y, z = inputs() x, y, z = inputs()
op = Add(x, y) op = Add(x, y)
lnk = CLinker(op) lnk = CLinker(op)
cgen = lnk.code_gen() fn = lnk.make_function()
fn = lnk.make_function() #[x.r, y.r], [op.out]) self.failUnless(fn(2.0, 7.0) == 9)
print fn(2.0, 7.0)
# fn = 0 def test_dups(self):
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
op = Add(x, x)
lnk = CLinker(op)
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined
class _test_OpWiseCLinker(unittest.TestCase):
def test_3(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z)) e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(env([x, y, z], [e])) lnk = OpWiseCLinker(env([x, y, z], [e]))
fn = lnk.make_function() fn = lnk.make_function()
print fn(2.0, 2.0, 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
# fn = 0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -9,6 +9,7 @@ import platform ...@@ -9,6 +9,7 @@ import platform
from scipy import weave from scipy import weave
import cutils import cutils
import utils import utils
import traceback
def compile_dir(): def compile_dir():
...@@ -266,7 +267,7 @@ class CLinker(Linker): ...@@ -266,7 +267,7 @@ class CLinker(Linker):
blocks = [] blocks = []
failure_var = "__failure" failure_var = "__failure"
id = 0 id = 1
sub = dict(failure_var = failure_var) sub = dict(failure_var = failure_var)
...@@ -319,10 +320,10 @@ class CLinker(Linker): ...@@ -319,10 +320,10 @@ class CLinker(Linker):
builder, block = struct_result_codeblocks(result, policy, id, symbol, sub) builder, block = struct_result_codeblocks(result, policy, id, symbol, sub)
init_tasks.append((result, 'init')) init_tasks.append((result, 'init', id))
init_blocks.append(builder) init_blocks.append(builder)
tasks.append((result, 'get')) tasks.append((result, 'get', id + 1))
blocks.append(block) blocks.append(block)
id += 2 id += 2
...@@ -345,7 +346,7 @@ class CLinker(Linker): ...@@ -345,7 +346,7 @@ class CLinker(Linker):
sub['id'] = id sub['id'] = id
blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub)) blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub))
tasks.append((op, 'validate_update')) tasks.append((op, 'validate_update', id))
id += 1 id += 1
# c_code # c_code
...@@ -357,7 +358,7 @@ class CLinker(Linker): ...@@ -357,7 +358,7 @@ class CLinker(Linker):
sub['id'] = id sub['id'] = id
blocks.append(CodeBlock("", behavior, cleanup, sub)) blocks.append(CodeBlock("", behavior, cleanup, sub))
tasks.append((op, 'code')) tasks.append((op, 'code', id))
id += 1 id += 1
args = [] args = []
...@@ -386,6 +387,7 @@ class CLinker(Linker): ...@@ -386,6 +387,7 @@ class CLinker(Linker):
def find_task(self, failure_code): def find_task(self, failure_code):
failure_code -= 1
n = len(self.init_tasks) n = len(self.init_tasks)
if failure_code < 2 * n: if failure_code < 2 * n:
return [self.init_tasks, self.tasks][failure_code % 2][failure_code/2] return [self.init_tasks, self.tasks][failure_code % 2][failure_code/2]
...@@ -441,7 +443,21 @@ class CLinker(Linker): ...@@ -441,7 +443,21 @@ class CLinker(Linker):
def execute(): def execute():
failure = cutils.run_cthunk(cthunk) failure = cutils.run_cthunk(cthunk)
if failure: if failure:
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1)) task, taskname, id = self.find_task(failure)
#exc = traceback.format_exception_only(error_storage[0], error_storage[1])
try:
trace = task.trace
except AttributeError:
trace = ()
class X:pass
__x = X()
__x.__thunk_trace__ = trace
__x.__str__ = lambda: str(error_storage[1])
raise error_storage[0], __x
## raise ThunkException, (error_storage[0], error_storage[1], trace)
# for stack_element in traceback.format_list(trace):
# print >>sys.stderr, stack_element,
# raise error_storage[0], error_storage[1] + " (error occurred in: " + str(task) + ")"
return execute, in_results, out_results return execute, in_results, out_results
# def make_function(self, inplace = False, unpack_single = True): # def make_function(self, inplace = False, unpack_single = True):
......
...@@ -4,6 +4,50 @@ ...@@ -4,6 +4,50 @@
from utils import AbstractFunctionError from utils import AbstractFunctionError
import utils import utils
import sys
import traceback
__excepthook = sys.excepthook
def thunk_hook(type, value, trace):
if len(value.args) > 0 and hasattr(value[0], '__thunk_trace__'):
# such a hack :(
trace2 = value[0].__thunk_trace__ #.exc_info
print>>sys.stderr, "Definition in: "
for line in traceback.format_list(trace2):
print>>sys.stderr, line,
__excepthook(type, value, trace)
sys.excepthook = thunk_hook
class Thunk:
def __init__(self):
self.results = None
self.is_valid = False
self.exc_info = ()
self.inputs = []
self.outputs = []
def call_thunk(self):
raise AbstractFunctionError
def exc_print(self, f = sys.stderr):
if self.is_valid:
return
type, value, trace = self.exc_info
for line in traceback.format_list(trace):
print>>f, line,
print>>f, traceback.format_exception_only(type, value)
def call_thunk_and_raise(self):
self.call_thunk()
if not self.is_valid:
type, value, trace = self.exc_info
raise self.type, self.value
def __call__(self, *inputs):
raise AbstractFunctionError
class Linker: class Linker:
......
import utils import utils
import traceback
from op import Op from op import Op
__all__ = ['ModalConstructor', __all__ = ['ModalConstructor',
...@@ -48,15 +49,21 @@ def add_modal_members(cls, *members): ...@@ -48,15 +49,21 @@ def add_modal_members(cls, *members):
setattr(cls, member, fn(member)) setattr(cls, member, fn(member))
def attach_trace(op):
stack = traceback.extract_stack()[:-3]
op.trace = stack
def build_mode(op): def build_mode(op):
pass attach_trace(op)
def eval_mode(op): def eval_mode(op):
attach_trace(op)
op.perform() op.perform()
for output in op.outputs: for output in op.outputs:
output._role = None output._role = None
def build_eval_mode(op): def build_eval_mode(op):
attach_trace(op)
op.perform() op.perform()
......
...@@ -14,6 +14,13 @@ class AbstractFunctionError(Exception): ...@@ -14,6 +14,13 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
def uniq(seq):
return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2):
return [x for x in seq1 if x not in seq2]
def attr_checker(*attrs): def attr_checker(*attrs):
def f(candidate): def f(candidate):
for attr in attrs: for attr in attrs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论