提交 5f889573 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added some convenience methods for writing clean code using the

interactive debugger
上级 8036de32
......@@ -587,3 +587,54 @@ def missing_test_message(msg):
warnings.warn(msg, stacklevel = 2)
else:
assert action in [ 'ignore', 'off' ]
def debug_error_message(msg):
""" Displays a message saying that an error was found in some
test_values. Becomes a warning or a ValueError depending on
config.compute_test_value"""
action = config.compute_test_value
#this message should never be called when the debugger is off
assert action != 'off'
if action in ['raise','ignore']:
raise ValueError(msg)
else:
assert action == 'warn'
warnings.warn(msg, stacklevel = 2)
def get_debug_values(*args):
""" Given a list of variables, does one of three things:
1. If the interactive debugger is off, returns an empty list
2. If the interactive debugger is on, and all variables have
debug values, returns a list containing a single element.
This single element is a tuple containing debug values of
all the variables.
3. If the interactive debugger is on, and some variable does
have a debug value, issue a missing_test_message about
the variable, and, if still in control of execution, return
an empty list
Intended use:
for val_1, ..., val_n in get_debug_values(var_1, ..., var_n):
if some condition on val_1, ..., val_n is not met:
debug_error_message("condition was not met")
"""
if config.compute_test_value == 'off':
return []
rval = []
for i, arg in enumerate(args):
try:
rval.append(get_test_value(arg))
except AttributeError:
missing_test_message("Argument "+str(i)+" has no test value")
return []
return [ tuple(rval) ]
......@@ -222,5 +222,109 @@ def test_test_value_op():
finally:
config.compute_test_value = prev_value
def test_get_debug_values_no_debugger():
'get_debug_values should return [] when debugger is off'
prev_value = config.compute_test_value
try:
config.compute_test_value = 'off'
x = T.vector()
for x_val in op.get_debug_values(x):
assert False
finally:
config.compute_test_value = prev_value
def test_get_det_debug_values_ignore():
"""get_debug_values should return [] when debugger is ignore
and some values are missing """
prev_value = config.compute_test_value
try:
config.compute_test_value = 'ignore'
x = T.vector()
for x_val in op.get_debug_values(x):
assert False
finally:
config.compute_test_value = prev_value
def test_get_debug_values_success():
"""tests that get_debug_value returns values when available
(and the debugger is on)"""
prev_value = config.compute_test_value
for mode in [ 'ignore', 'warn', 'raise' ]:
try:
config.compute_test_value = mode
x = T.vector()
x.tag.test_value = numpy.zeros((4,))
y = numpy.zeros((5,5))
iters = 0
for x_val, y_val in op.get_debug_values(x, y):
assert x_val.shape == (4,)
assert y_val.shape == (5,5)
iters += 1
assert iters == 1
finally:
config.compute_test_value = prev_value
def test_get_debug_values_exc():
"""tests that get_debug_value raises an exception when
debugger is set to raise and a value is missing """
prev_value = config.compute_test_value
try:
config.compute_test_value = 'raise'
x = T.vector()
try:
for x_val in op.get_debug_values(x):
assert False
raised = False
except AttributeError:
raised = True
assert raised
finally:
config.compute_test_value = prev_value
def test_debug_error_message():
"""tests that debug_error_message raises an
exception when it should."""
prev_value = config.compute_test_value
for mode in [ 'ignore', 'raise' ]:
try:
config.compute_test_value = mode
try:
op.debug_error_message('msg')
raised = False
except ValueError:
raised = True
assert raised
finally:
config.compute_test_value = prev_value
if __name__ == '__main__':
unittest.main()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论