提交 c7c57776 authored 作者: James Bergstra's avatar James Bergstra

cleaned up the printing of errors during testing tensor/test_basic

上级 24718ae9
......@@ -900,9 +900,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
lopt_change = self.process_node(env, node, lopt)
process_count[lopt] += 1 if lopt_change else 0
changed |= lopt_change
except:
finally:
self.detach_updater(env, u)
raise
self.detach_updater(env, u)
if max_use_abort:
print >> sys.stderr, "WARNING: EquilibriumOptimizer max'ed out"
......
......@@ -120,7 +120,7 @@ class EquilibriumDB(DB):
return opt.EquilibriumOptimizer(opts,
max_depth=5,
max_use_ratio=10,
failure_callback=opt.NavigatorOptimizer.warn)
failure_callback=opt.NavigatorOptimizer.warn_inplace)
class SequenceDB(DB):
......
......@@ -858,7 +858,7 @@ class MaxAndArgmax(Op):
axis = x.type.ndim - 1
axis = _as_tensor_variable(axis)
inputs = [x, axis]
broadcastable = [False] * (x.type.ndim - 1)
broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative
outputs = [tensor(x.type.dtype, broadcastable),
tensor(axis.type.dtype, broadcastable)]
return Apply(self, inputs, outputs)
......
......@@ -628,18 +628,26 @@ class T_max_and_argmax(unittest.TestCase):
self.failUnless(numpy.all(i == numpy.argmax(data,0)))
def test2_invalid(self):
n = as_tensor_variable(numpy.random.rand(2,3))
old_stderr = sys.stderr
sys.stderr = StringIO.StringIO()
try:
eval_outputs(max_and_argmax(n,3))
assert False
except ValueError, e:
return
self.fail()
pass
finally:
sys.stderr = old_stderr
def test2_invalid_neg(self):
n = as_tensor_variable(numpy.random.rand(2,3))
old_stderr = sys.stderr
sys.stderr = StringIO.StringIO()
try:
eval_outputs(max_and_argmax(n,-3))
assert False
except ValueError, e:
return
self.fail()
pass
finally:
sys.stderr = old_stderr
def test2_valid_neg(self):
n = as_tensor_variable(numpy.random.rand(2,3))
v,i = eval_outputs(max_and_argmax(n,-1))
......@@ -678,13 +686,16 @@ class T_subtensor(unittest.TestCase):
n = as_tensor_variable(numpy.ones(3))
t = n[7]
self.failUnless(isinstance(t.owner.op, Subtensor))
old_stderr = sys.stderr
sys.stderr = StringIO.StringIO()
try:
tval = eval_outputs([t])
assert 0
except Exception, e:
if e[0] != 'index out of bounds':
raise
return
self.fail()
finally:
sys.stderr = old_stderr
def test1_err_subslice(self):
n = as_tensor_variable(numpy.ones(3))
try:
......@@ -748,20 +759,28 @@ class T_subtensor(unittest.TestCase):
n = as_tensor_variable(numpy.ones((2,3))*5)
t = n[0,4]
self.failUnless(isinstance(t.owner.op, Subtensor))
old_stderr = sys.stderr
sys.stderr = StringIO.StringIO()
try:
tval = eval_outputs([t])
assert 0
except IndexError, e:
return
self.fail()
pass
finally:
sys.stderr = old_stderr
def test2_err_bounds1(self):
n = as_tensor_variable(numpy.ones((2,3))*5)
t = n[4:5,2]
self.failUnless(isinstance(t.owner.op, Subtensor))
old_stderr = sys.stderr
sys.stderr = StringIO.StringIO()
try:
tval = eval_outputs([t])
except Exception, e:
if e[0] != 'index out of bounds':
raise
finally:
sys.stderr = old_stderr
def test2_ok_elem(self):
n = as_tensor_variable(numpy.asarray(range(6)).reshape((2,3)))
t = n[0,2]
......@@ -1364,9 +1383,9 @@ class t_dot(unittest.TestCase):
def not_aligned(self, x, y):
z = dot(x,y)
old_stderr = sys.stderr
# constant folding will complain to stderr that things are not aligned
# this is normal, testers are not interested in seeing that output.
old_stderr = sys.stderr
sys.stderr = StringIO.StringIO()
try:
tz = eval_outputs([z])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论