提交 e89f6631 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

changed the way the trace is communicated

......@@ -188,9 +188,9 @@ class T_pow(unittest.TestCase):
verify_grad(self, DivElemwise, [numpy.random.rand(3,4), numpy.random.rand(3,4)+0.1])
verify_grad(self, PowElemwise, [numpy.random.rand(3,4), numpy.random.rand(3,4)])
def test_scalar_l(self):
verify_grad(self, PowScalarL, [numpy.random.rand(3), 3.0])
verify_grad(self, PowScalarL, [numpy.random.rand(3), numpy.asarray(3.0)])
def test_scalar_r(self):
verify_grad(self, PowScalarR, [numpy.random.rand(3), 3.0])
verify_grad(self, PowScalarR, [numpy.random.rand(3), numpy.asarray(3.0)])
class _testCase_matinv:#(unittest.TestCase):
......
......@@ -180,4 +180,9 @@ class _test_OpWiseCLinker(unittest.TestCase):
self.failUnless(fn(2.0, 2.0, 2.0) == 2.0)
if __name__ == '__main__':
unittest.main()
# unittest.main()
x, y, z = inputs()
e = add(mul(add(x, y), div(x, y)), sub(sub(x, y), z))
lnk = CLinker(env([x, y, z], [e]))
fn = lnk.make_function()
fn(2.0, 0.0, 2.0)
......@@ -444,39 +444,15 @@ class CLinker(Linker):
failure = cutils.run_cthunk(cthunk)
if failure:
task, taskname, id = self.find_task(failure)
#exc = traceback.format_exception_only(error_storage[0], error_storage[1])
try:
trace = task.trace
except AttributeError:
trace = ()
class X:pass
__x = X()
__x.__thunk_trace__ = trace
__x.__str__ = lambda: str(error_storage[1])
raise error_storage[0], __x
## raise ThunkException, (error_storage[0], error_storage[1], trace)
# for stack_element in traceback.format_list(trace):
# print >>sys.stderr, stack_element,
# raise error_storage[0], error_storage[1] + " (error occurred in: " + str(task) + ")"
exc_type, _exc_value, exc_trace = error_storage
exc_value = exc_type(_exc_value, task)
exc_value.__thunk_trace__ = trace
raise exc_type, exc_value, exc_trace
return execute, in_results, out_results
# def make_function(self, inplace = False, unpack_single = True):
# cthunk, in_results, out_results, error_storage = self.__compile__(inplace)
# # out_storage = [result._data for result in out_results]
# def execute(*args):
# for arg, result in zip(args, in_results):
# result.data = arg
# failure = cutils.run_cthunk(cthunk)
# if failure:
# raise error_storage[0], error_storage[1] + " " + str(self.find_task(failure - 1))
# if unpack_single:
# return utils.to_return_values([result.data for result in out_results])
# else:
# return [result.data for result in out_results]
# # return utils.to_return_values([storage[0] for storage in out_storage])
# return execute
def cthunk_factory(self, error_storage, in_storage, out_storage):
......
......@@ -10,9 +10,9 @@ import traceback
__excepthook = sys.excepthook
def thunk_hook(type, value, trace):
if len(value.args) > 0 and hasattr(value[0], '__thunk_trace__'):
if hasattr(value, '__thunk_trace__'):
# such a hack :(
trace2 = value[0].__thunk_trace__ #.exc_info
trace2 = value.__thunk_trace__
if trace2 is None:
print>>sys.stderr, "Could not find where this Op was defined."
print>>sys.stderr, " * You might have instantiated this Op directly instead of using a constructor."
......@@ -24,6 +24,22 @@ def thunk_hook(type, value, trace):
__excepthook(type, value, trace)
sys.excepthook = thunk_hook
# __excepthook = sys.excepthook
# def thunk_hook(type, value, trace):
# if len(value.args) > 0 and hasattr(value[0], '__thunk_trace__'):
# # such a hack :(
# trace2 = value[0].__thunk_trace__ #.exc_info
# if trace2 is None:
# print>>sys.stderr, "Could not find where this Op was defined."
# print>>sys.stderr, " * You might have instantiated this Op directly instead of using a constructor."
# print>>sys.stderr, " * The Op you constructed might have been optimized. Try turning off optimizations."
# elif trace2:
# print>>sys.stderr, "Definition in: "
# for line in traceback.format_list(trace2):
# print>>sys.stderr, line,
# __excepthook(type, value, trace)
# sys.excepthook = thunk_hook
class Linker:
......@@ -110,11 +126,9 @@ class PerformLinker(Linker):
trace = op.trace
except AttributeError:
trace = ()
class X:pass
__x = X()
__x.__thunk_trace__ = trace
__x.__str__ = lambda: str(exc_value) + " (in op: " + str(op) + ")"
raise exc_type, __x, exc_trace
exc_value.__thunk_trace__ = trace
exc_value.args = exc_value.args + (op, )
raise exc_type, exc_value, exc_trace
return f, env.inputs, env.outputs
......
......@@ -119,6 +119,8 @@ class _Op(Op):
nin = -1
nout = 1
_destroy_map = {}
def __init__(self, *inputs):
def as_tensor(obj):
......@@ -148,17 +150,28 @@ class _Op(Op):
def propagate_dtype(self, *i_dtypes):
def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype)
#print '----', self.__class__
#print type(z), dtype
for dtype in dtypes:
z = z + numpy.zeros((), dtype = dtype)
#print type(z), type(dtype), dtype
return str(z.dtype)
for dtype in i_dtypes:
if dtype is None:
raise TypeError("Expected a Tensor.")
rval = upcast(*i_dtypes)
return rval
upcasted = upcast(*i_dtypes)
return [upcasted] * self.nout
# try:
# dmap = self.destroy_map()
# except AttributeError:
# dmap = {}
# rval = []
# for i in xrange(self.nout):
# if i in dmap:
# destroyed = dmap[output]
# if len(destroyed) != 1:
# raise TypeError("Cannot infer dtype of output %s because it destroys more than one input." % output)
# rval.append(destroyed[0])
# else:
# rval.append(upcasted)
# return rval
def impl(self, *inputs):
raise AbstractFunctionError()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论