traceback is now stored in the Ops and printed out by a custom excepthook when an exception occurs

上级 da3a7ffd
import unittest
import numpy
from tensor import tensor, Tensor
import gof
from gof import modes, Env
from elemwise import *
class ElemwiseAdd(Elemwise):
def var_desc(self):
return [('x', 1), ('y', 1)], [('z', 1)]
# def destroy_map(self):
# return {self.out: [self.inputs[0]]}
def c_code_foreach(self):
return "%(z)s_i = %(x)s_i + %(y)s_i;"
def inputs():
l1 = [[1.0, 2.0], [3.0, 4.0]]
l2 = [[3.0, 4.0], [1.0, 2.0]]
l3 = numpy.ones((2, 3))
x = modes.build(tensor(l1, name = 'x'))
y = modes.build(tensor(l2, name = 'y'))
z = modes.build(tensor(l3, name = 'z'))
return x, y, z
def env(inputs, outputs, validate = True, features = []):
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_Elemwise(unittest.TestCase):
def test_0(self):
x, y, z = inputs()
e = ElemwiseAdd(x, y).out
fn, i, o = gof.CLinker(env([x, y], [e])).make_thunk(True)
fn()
assert (e.data == numpy.array([[4, 6], [4, 6]])).all()
x.data.resize((1, 4))
y.data.resize((1, 4))
fn()
assert (e.data == numpy.array([[4, 6, 4, 6]])).all()
if __name__ == '__main__':
unittest.main()
......@@ -31,7 +31,7 @@ class Double(ResultBase):
return """
%(name)s = 0;
%(name)s_bad_thing = malloc(100000);
printf("Initializing %(name)s\\n");
//printf("Initializing %(name)s\\n");
"""
def c_literal(self):
......@@ -45,7 +45,7 @@ class Double(ResultBase):
}
%(name)s = PyFloat_AsDouble(py_%(name)s);
%(name)s_bad_thing = NULL;
printf("Extracting %(name)s\\n");
//printf("Extracting %(name)s\\n");
"""
def c_sync(self):
......@@ -54,12 +54,12 @@ class Double(ResultBase):
py_%(name)s = PyFloat_FromDouble(%(name)s);
if (!py_%(name)s)
py_%(name)s = Py_None;
printf("Syncing %(name)s\\n");
//printf("Syncing %(name)s\\n");
"""
def c_cleanup(self):
return """
printf("Cleaning up %(name)s\\n");
//printf("Cleaning up %(name)s\\n");
if (%(name)s_bad_thing)
free(%(name)s_bad_thing);
"""
......@@ -101,6 +101,13 @@ class Mul(Binary):
return "%(z)s = %(x)s * %(y)s;"
class Div(Binary):
def c_validate_update(self):
return """
if (%(y)s == 0.0) {
PyErr_SetString(PyExc_ZeroDivisionError, "division by zero");
%(fail)s
}
"""
def c_code(self):
return "%(z)s = %(x)s / %(y)s;"
......@@ -115,48 +122,62 @@ def inputs():
return x, y, z
def env(inputs, outputs, validate = True, features = []):
# inputs = [input.r for input in inputs]
# outputs = [output.r for output in outputs]
return Env(inputs, outputs, features = features, consistency_check = validate)
class _test_CLinker(unittest.TestCase):
def test_0(self):
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e])) #, [x.r, y.r, z.r], [e.r])
cgen = lnk.code_gen()
fn = lnk.make_function() #[x.r, y.r, z.r], [e.r])
print fn(2.0, 2.0, 2.0)
# fn = 0
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
def test_1(self):
def test_orphan(self):
x, y, z = inputs()
z.constant = True
z.data = 4.12345678
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e])) #, [x.r, y.r], [e.r])
cgen = lnk.code_gen()
fn = lnk.make_function() #[x.r, y.r], [e.r])
print fn(2.0, 2.0)
# fn = 0
lnk = CLinker(env([x, y], [e]))
fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" not in lnk.code_gen()) # we do not expect the number to be inlined
def test_2(self):
def test_literal_inlining(self):
x, y, z = inputs()
z.data = 4.12345678
z.constant = True # this should tell the compiler to inline z as a literal
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y], [e]))
fn = lnk.make_function()
self.failUnless(abs(fn(2.0, 2.0) + 0.12345678) < 1e-9)
self.failUnless("4.12345678" in lnk.code_gen()) # we expect the number to be inlined
def test_single_op(self):
x, y, z = inputs()
op = Add(x, y)
lnk = CLinker(op)
cgen = lnk.code_gen()
fn = lnk.make_function() #[x.r, y.r], [op.out])
print fn(2.0, 7.0)
# fn = 0
fn = lnk.make_function()
self.failUnless(fn(2.0, 7.0) == 9)
def test_dups(self):
# Testing that duplicate inputs are allowed.
x, y, z = inputs()
op = Add(x, x)
lnk = CLinker(op)
fn = lnk.make_function()
self.failUnless(fn(2.0, 2.0) == 4)
# note: for now the behavior of fn(2.0, 7.0) is undefined
class _test_OpWiseCLinker(unittest.TestCase):
def test_3(self):
def test_straightforward(self):
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = OpWiseCLinker(env([x, y, z], [e]))
fn = lnk.make_function()
print fn(2.0, 2.0, 2.0)
# fn = 0
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
if __name__ == '__main__':
unittest.main()
......@@ -9,6 +9,7 @@ import platform
from scipy import weave
import cutils
import utils
import traceback
def compile_dir():
......@@ -266,7 +267,7 @@ class CLinker(Linker):
blocks = []
failure_var = "__failure"
id = 0
id = 1
sub = dict(failure_var = failure_var)
......@@ -319,10 +320,10 @@ class CLinker(Linker):
builder, block = struct_result_codeblocks(result, policy, id, symbol, sub)
init_tasks.append((result, 'init'))
init_tasks.append((result, 'init', id))
init_blocks.append(builder)
tasks.append((result, 'get'))
tasks.append((result, 'get', id + 1))
blocks.append(block)
id += 2
......@@ -345,7 +346,7 @@ class CLinker(Linker):
sub['id'] = id
blocks.append(CodeBlock("", validate_behavior, validate_cleanup, sub))
tasks.append((op, 'validate_update'))
tasks.append((op, 'validate_update', id))
id += 1
# c_code
......@@ -357,7 +358,7 @@ class CLinker(Linker):
sub['id'] = id
blocks.append(CodeBlock("", behavior, cleanup, sub))
tasks.append((op, 'code'))
tasks.append((op, 'code', id))
id += 1
args = []
......@@ -386,6 +387,7 @@ class CLinker(Linker):
def find_task(self, failure_code):
failure_code -= 1
n = len(self.init_tasks)
if failure_code < 2 * n:
return [self.init_tasks, self.tasks][failure_code % 2][failure_code/2]
......@@ -441,7 +443,21 @@ class CLinker(Linker):
def execute():
failure = cutils.run_cthunk(cthunk)
if failure:
raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
task, taskname, id = self.find_task(failure)
#exc = traceback.format_exception_only(error_storage[0], error_storage[1])
try:
trace = task.trace
except AttributeError:
trace = ()
class X:pass
__x = X()
__x.__thunk_trace__ = trace
__x.__str__ = lambda: str(error_storage[1])
raise error_storage[0], __x
## raise ThunkException, (error_storage[0], error_storage[1], trace)
# for stack_element in traceback.format_list(trace):
# print >>sys.stderr, stack_element,
# raise error_storage[0], error_storage[1] + " (error occurred in: " + str(task) + ")"
return execute, in_results, out_results
# def make_function(self, inplace = False, unpack_single = True):
......
......@@ -4,6 +4,50 @@
from utils import AbstractFunctionError
import utils
import sys
import traceback
__excepthook = sys.excepthook
def thunk_hook(type, value, trace):
if len(value.args) > 0 and hasattr(value[0], '__thunk_trace__'):
# such a hack :(
trace2 = value[0].__thunk_trace__ #.exc_info
print>>sys.stderr, "Definition in: "
for line in traceback.format_list(trace2):
print>>sys.stderr, line,
__excepthook(type, value, trace)
sys.excepthook = thunk_hook
class Thunk:
def __init__(self):
self.results = None
self.is_valid = False
self.exc_info = ()
self.inputs = []
self.outputs = []
def call_thunk(self):
raise AbstractFunctionError
def exc_print(self, f = sys.stderr):
if self.is_valid:
return
type, value, trace = self.exc_info
for line in traceback.format_list(trace):
print>>f, line,
print>>f, traceback.format_exception_only(type, value)
def call_thunk_and_raise(self):
self.call_thunk()
if not self.is_valid:
type, value, trace = self.exc_info
raise self.type, self.value
def __call__(self, *inputs):
raise AbstractFunctionError
class Linker:
......
import utils
import traceback
from op import Op
__all__ = ['ModalConstructor',
......@@ -48,15 +49,21 @@ def add_modal_members(cls, *members):
setattr(cls, member, fn(member))
def attach_trace(op):
stack = traceback.extract_stack()[:-3]
op.trace = stack
def build_mode(op):
pass
attach_trace(op)
def eval_mode(op):
attach_trace(op)
op.perform()
for output in op.outputs:
output._role = None
def build_eval_mode(op):
attach_trace(op)
op.perform()
......
......@@ -14,6 +14,13 @@ class AbstractFunctionError(Exception):
function has been left out of an implementation class.
"""
def uniq(seq):
return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2):
return [x for x in seq1 if x not in seq2]
def attr_checker(*attrs):
def f(candidate):
for attr in attrs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论