提交 8f23c938 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #106 from jaberg/traceback

Traceback comments
...@@ -17,18 +17,23 @@ def thunk_hook(type, value, trace): ...@@ -17,18 +17,23 @@ def thunk_hook(type, value, trace):
and prints it out on L{stderr}. and prints it out on L{stderr}.
The normal excepthook is then called. The normal excepthook is then called.
:note: This hook replaced by nosetests, so it does not run in nose tests.
""" """
if hasattr(value, '__thunk_trace__'): if hasattr(value, '__thunk_trace__'):
trace2 = value.__thunk_trace__ trace2 = value.__thunk_trace__
if trace2 is None: if trace2 is None:
print>>sys.stderr, "Could not find where this Op was defined." print>>sys.stderr, "Could not find where this Op was defined."
print>>sys.stderr, " * You might have instantiated this Op directly instead of using a constructor." print>>sys.stderr, (" * You might have instantiated this Op "
print>>sys.stderr, " * The Op you constructed might have been optimized. Try turning off optimizations." "directly instead of using a constructor.")
print>>sys.stderr, (" * The Op you constructed might have been"
" optimized. Try turning off optimizations.")
elif trace2: elif trace2:
print>>sys.stderr, "Definition in: " print>>sys.stderr, "Definition in: "
for line in traceback.format_list(trace2): for line in traceback.format_list(trace2):
print>>sys.stderr, line, print>>sys.stderr, line,
print>>sys.stderr, "For the full definition stack trace set the Theano flags traceback.limit to -1" print>>sys.stderr, ("For the full definition stack trace set"
" the Theano flags traceback.limit to -1")
__excepthook(type, value, trace) __excepthook(type, value, trace)
sys.excepthook = thunk_hook sys.excepthook = thunk_hook
......
import unittest
from theano.gof import graph from theano.gof import graph
from theano.gof.graph import Variable, Apply, Constant from theano.gof.graph import Variable, Apply, Constant
...@@ -8,26 +9,25 @@ from theano.gof import toolbox ...@@ -8,26 +9,25 @@ from theano.gof import toolbox
from theano.gof.link import * from theano.gof.link import *
#from _test_variable import Double
def as_variable(x): def as_variable(x):
assert isinstance(x, Variable) assert isinstance(x, Variable)
return x return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
tdouble = TDouble() tdouble = TDouble()
def double(name): def double(name):
return Variable(tdouble, None, None, name = name) return Variable(tdouble, None, None, name=name)
class MyOp(Op): class MyOp(Op):
def __init__(self, nin, name, impl=None):
def __init__(self, nin, name, impl = None):
self.nin = nin self.nin = nin
self.name = name self.name = name
if impl: if impl:
...@@ -54,11 +54,12 @@ sub = MyOp(2, 'Sub', lambda x, y: x - y) ...@@ -54,11 +54,12 @@ sub = MyOp(2, 'Sub', lambda x, y: x - y)
mul = MyOp(2, 'Mul', lambda x, y: x * y) mul = MyOp(2, 'Mul', lambda x, y: x * y)
div = MyOp(2, 'Div', lambda x, y: x / y) div = MyOp(2, 'Div', lambda x, y: x / y)
def notimpl(self, x): def notimpl(self, x):
raise NotImplementedError() raise NotImplementedError()
raise_err = MyOp(1, 'RaiseErr', notimpl)
raise_err = MyOp(1, 'RaiseErr', notimpl)
def inputs(): def inputs():
...@@ -67,17 +68,18 @@ def inputs(): ...@@ -67,17 +68,18 @@ def inputs():
z = double('z') z = double('z')
return x, y, z return x, y, z
def perform_linker(env): def perform_linker(env):
lnk = PerformLinker().accept(env) lnk = PerformLinker().accept(env)
return lnk return lnk
def Env(inputs, outputs): def Env(inputs, outputs):
e = env.Env(inputs, outputs) e = env.Env(inputs, outputs)
return e return e
class TestPerformLinker: class TestPerformLinker(unittest.TestCase):
def test_thunk(self): def test_thunk(self):
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
...@@ -107,34 +109,37 @@ class TestPerformLinker: ...@@ -107,34 +109,37 @@ class TestPerformLinker:
def test_input_dependency0(self): def test_input_dependency0(self):
x, y, z = inputs() x, y, z = inputs()
a,d = add(x,y), div(x,y) a, d = add(x, y), div(x, y)
e = mul(a,d) e = mul(a, d)
fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y, a], [e]))).make_function()
assert fn(1.0,2.0,9.0) == 4.5 assert fn(1.0, 2.0, 9.0) == 4.5
def test_skiphole(self): def test_skiphole(self):
x,y,z = inputs() x, y, z = inputs()
a = add(x,y) a = add(x, y)
r = raise_err(a) r = raise_err(a)
e = add(r,a) e = add(r, a)
fn = perform_linker(Env(*graph.clone([x, y,r], [e]))).make_function() fn = perform_linker(Env(*graph.clone([x, y, r], [e]))).make_function()
assert fn(1.0,2.0,4.5) == 7.5 assert fn(1.0, 2.0, 4.5) == 7.5
def wrap_linker(env, linkers, wrapper): def wrap_linker(env, linkers, wrapper):
lnk = WrapLinker(linkers, wrapper).accept(env) lnk = WrapLinker(linkers, wrapper).accept(env)
return lnk return lnk
class TestWrapLinker:
class TestWrapLinker(unittest.TestCase):
def test_0(self): def test_0(self):
nodes = [] nodes = []
def wrap(i, node, th): def wrap(i, node, th):
nodes.append(node.op) nodes.append(node.op)
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk() fn, i, o = wrap_linker(
Env([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
...@@ -143,20 +148,18 @@ class TestWrapLinker: ...@@ -143,20 +148,18 @@ class TestWrapLinker:
def test_1(self): def test_1(self):
nodes = [] nodes = []
def wrap(i, node, th): def wrap(i, node, th):
nodes.append(node.op) nodes.append(node.op)
th() th()
x, y, z = inputs() x, y, z = inputs()
e = mul(add(x, y), div(x, y)) e = mul(add(x, y), div(x, y))
fn, i, o = wrap_linker(Env([x, y, z], [e]), [PerformLinker(allow_gc=False)], wrap).make_thunk() fn, i, o = wrap_linker(
Env([x, y, z], [e]),
[PerformLinker(allow_gc=False)], wrap).make_thunk()
i[0].data = 1 i[0].data = 1
i[1].data = 2 i[1].data = 2
fn() fn()
assert nodes == [div, add, mul] assert nodes == [div, add, mul]
assert o[0].data == 1.5 assert o[0].data == 1.5
...@@ -5005,7 +5005,7 @@ class AdvancedIncSubtensor1(Op): ...@@ -5005,7 +5005,7 @@ class AdvancedIncSubtensor1(Op):
gx = g_output gx = g_output
gy = advanced_subtensor1(g_output, *idx_list) gy = advanced_subtensor1(g_output, *idx_list)
return [gx, gy] + [None]*len(idx_list) return [gx, gy] + [None] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
......
...@@ -743,7 +743,7 @@ class Elemwise(Op): ...@@ -743,7 +743,7 @@ class Elemwise(Op):
# Since numpy 1.6, function created with numpy.frompyfunc # Since numpy 1.6, function created with numpy.frompyfunc
# always return an ndarray with dtype object # always return an ndarray with dtype object
variable = numpy.asarray(variable, dtype=nout.dtype) variable = numpy.asarray(variable, dtype=nout.dtype)
if hasattr(variable,'shape') and storage[0].shape != variable.shape: if hasattr(variable, 'shape') and storage[0].shape != variable.shape:
if numpy.prod(variable.shape) == 0: if numpy.prod(variable.shape) == 0:
# numpy don't resize from a shape (1,5) to (0,5) # numpy don't resize from a shape (1,5) to (0,5)
# This bypass the inplace... But I it is important in this case. # This bypass the inplace... But I it is important in this case.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论