提交 e89f6631 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed the way the trace is communicated

...@@ -188,9 +188,9 @@ class T_pow(unittest.TestCase): ...@@ -188,9 +188,9 @@ class T_pow(unittest.TestCase):
verify_grad(self, DivElemwise, [numpy.random.rand(3,4), numpy.random.rand(3,4)+0.1]) verify_grad(self, DivElemwise, [numpy.random.rand(3,4), numpy.random.rand(3,4)+0.1])
verify_grad(self, PowElemwise, [numpy.random.rand(3,4), numpy.random.rand(3,4)]) verify_grad(self, PowElemwise, [numpy.random.rand(3,4), numpy.random.rand(3,4)])
def test_scalar_l(self): def test_scalar_l(self):
verify_grad(self, PowScalarL, [numpy.random.rand(3), 3.0]) verify_grad(self, PowScalarL, [numpy.random.rand(3), numpy.asarray(3.0)])
def test_scalar_r(self): def test_scalar_r(self):
verify_grad(self, PowScalarR, [numpy.random.rand(3), 3.0]) verify_grad(self, PowScalarR, [numpy.random.rand(3), numpy.asarray(3.0)])
class _testCase_matinv:#(unittest.TestCase): class _testCase_matinv:#(unittest.TestCase):
......
...@@ -180,4 +180,9 @@ class _test_OpWiseCLinker(unittest.TestCase): ...@@ -180,4 +180,9 @@ class _test_OpWiseCLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0) self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() # unittest.main()
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
fn(2.0, 0.0, 2.0)
...@@ -444,40 +444,16 @@ class CLinker(Linker): ...@@ -444,40 +444,16 @@ class CLinker(Linker):
failure = cutils.run_cthunk(cthunk) failure = cutils.run_cthunk(cthunk)
if failure: if failure:
task, taskname, id = self.find_task(failure) task, taskname, id = self.find_task(failure)
#exc = traceback.format_exception_only(error_storage[0], error_storage[1])
try: try:
trace = task.trace trace = task.trace
except AttributeError: except AttributeError:
trace = () trace = ()
class X:pass exc_type, _exc_value, exc_trace = error_storage
__x = X() exc_value = exc_type(_exc_value, task)
__x.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
__x.__str__ = lambda: str(error_storage[1]) raise exc_type, exc_value, exc_trace
raise error_storage[0], __x
## raise ThunkException, (error_storage[0], error_storage[1], trace)
# for stack_element in traceback.format_list(trace):
# print >>sys.stderr, stack_element,
# raise error_storage[0], error_storage[1] + " (error occurred in: " + str(task) + ")"
return execute, in_results, out_results return execute, in_results, out_results
# def make_function(self, inplace = False, unpack_single = True):
# cthunk, in_results, out_results, error_storage = self.__compile__(inplace)
# # out_storage = [result._data for result in out_results]
# def execute(*args):
# for arg, result in zip(args, in_results):
# result.data = arg
# failure = cutils.run_cthunk(cthunk)
# if failure:
# raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
# if unpack_single:
# return utils.to_return_values([result.data for result in out_results])
# else:
# return [result.data for result in out_results]
# # return utils.to_return_values([storage[0] for storage in out_storage])
# return execute
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
if not getattr(self, 'instantiate', False): if not getattr(self, 'instantiate', False):
......
...@@ -10,9 +10,9 @@ import traceback ...@@ -10,9 +10,9 @@ import traceback
__excepthook = sys.excepthook __excepthook = sys.excepthook
def thunk_hook(type, value, trace): def thunk_hook(type, value, trace):
if len(value.args) > 0 and hasattr(value[0], '__thunk_trace__'): if hasattr(value, '__thunk_trace__'):
# such a hack :( # such a hack :(
trace2 = value[0].__thunk_trace__ #.exc_info trace2 = value.__thunk_trace__
if trace2 is None: if trace2 is None:
print>>sys.stderr, "Could not find where this Op was defined." print>>sys.stderr, "Could not find where this Op was defined."
print>>sys.stderr, " * You might have instantiated this Op directly instead of using a constructor." print>>sys.stderr, " * You might have instantiated this Op directly instead of using a constructor."
...@@ -24,6 +24,22 @@ def thunk_hook(type, value, trace): ...@@ -24,6 +24,22 @@ def thunk_hook(type, value, trace):
__excepthook(type, value, trace) __excepthook(type, value, trace)
sys.excepthook = thunk_hook sys.excepthook = thunk_hook
# __excepthook = sys.excepthook
# def thunk_hook(type, value, trace):
# if len(value.args) > 0 and hasattr(value[0], '__thunk_trace__'):
# # such a hack :(
# trace2 = value[0].__thunk_trace__ #.exc_info
# if trace2 is None:
# print>>sys.stderr, "Could not find where this Op was defined."
# print>>sys.stderr, " * You might have instantiated this Op directly instead of using a constructor."
# print>>sys.stderr, " * The Op you constructed might have been optimized. Try turning off optimizations."
# elif trace2:
# print>>sys.stderr, "Definition in: "
# for line in traceback.format_list(trace2):
# print>>sys.stderr, line,
# __excepthook(type, value, trace)
# sys.excepthook = thunk_hook
class Linker: class Linker:
...@@ -110,11 +126,9 @@ class PerformLinker(Linker): ...@@ -110,11 +126,9 @@ class PerformLinker(Linker):
trace = op.trace trace = op.trace
except AttributeError: except AttributeError:
trace = () trace = ()
class X:pass exc_value.__thunk_trace__ = trace
__x = X() exc_value.args = exc_value.args + (op, )
__x.__thunk_trace__ = trace raise exc_type, exc_value, exc_trace
__x.__str__ = lambda: str(exc_value) + " (in op: " + str(op) + ")"
raise exc_type, __x, exc_trace
return f, env.inputs, env.outputs return f, env.inputs, env.outputs
......
...@@ -119,6 +119,8 @@ class _Op(Op): ...@@ -119,6 +119,8 @@ class _Op(Op):
nin = -1 nin = -1
nout = 1 nout = 1
_destroy_map = {}
def __init__(self, *inputs): def __init__(self, *inputs):
def as_tensor(obj): def as_tensor(obj):
...@@ -148,17 +150,28 @@ class _Op(Op): ...@@ -148,17 +150,28 @@ class _Op(Op):
def propagate_dtype(self, *i_dtypes): def propagate_dtype(self, *i_dtypes):
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
#print '----', self.__class__
#print type(z), dtype
for dtype in dtypes: for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype) z = z + numpy.zeros((), dtype = dtype)
#print type(z), type(dtype), dtype
return str(z.dtype) return str(z.dtype)
for dtype in i_dtypes: for dtype in i_dtypes:
if dtype is None: if dtype is None:
raise TypeError("Expected a Tensor.") raise TypeError("Expected a Tensor.")
rval = upcast(*i_dtypes) upcasted = upcast(*i_dtypes)
return rval return [upcasted] * self.nout
# try:
# dmap = self.destroy_map()
# except AttributeError:
# dmap = {}
# rval = []
# for i in xrange(self.nout):
# if i in dmap:
# destroyed = dmap[output]
# if len(destroyed) != 1:
# raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# rval.append(destroyed[0])
# else:
# rval.append(upcasted)
# return rval
def impl(self, *inputs): def impl(self, *inputs):
raise AbstractFunctionError() raise AbstractFunctionError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论