提交 2cbba6b9 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #82 from goodfeli/get_debug_values

Functions for clean interactive debugger code
......@@ -587,3 +587,71 @@ def missing_test_message(msg):
warnings.warn(msg, stacklevel = 2)
else:
assert action in [ 'ignore', 'off' ]
def debug_error_message(msg):
""" Displays a message saying that an error was found in some
test_values. Becomes a warning or a ValueError depending on
config.compute_test_value"""
action = config.compute_test_value
#this message should never be called when the debugger is off
assert action != 'off'
if action in ['raise','ignore']:
raise ValueError(msg)
else:
assert action == 'warn'
warnings.warn(msg, stacklevel = 2)
def debug_assert(condition, msg):
if not condition:
action = config.compute_test_value
if action in ['raise', 'ignore']:
raise AssertionError(msg)
else:
assert action == 'warn'
warnings.warn(msg, stacklevel = 2)
def get_debug_values(*args):
""" Given a list of variables, does one of three things:
1. If the interactive debugger is off, returns an empty list
2. If the interactive debugger is on, and all variables have
debug values, returns a list containing a single element.
This single element is a tuple containing debug values of
all the variables.
3. If the interactive debugger is on, and some variable does
not have a debug value, issue a missing_test_message about
the variable, and, if still in control of execution, return
an empty list
Intended use:
for val_1, ..., val_n in get_debug_values(var_1, ..., var_n):
if some condition on val_1, ..., val_n is not met:
debug_error_message("condition was not met")
"""
if config.compute_test_value == 'off':
return []
rval = []
for i, arg in enumerate(args):
try:
rval.append(get_test_value(arg))
except AttributeError:
if hasattr(arg, 'name') and arg.name is not None:
missing_test_message("Argument " + str(i) + "('" + arg.name + "') has no test value")
else:
missing_test_message("Argument " + str(i) + " has no test value")
return []
return [ tuple(rval) ]
......@@ -222,5 +222,114 @@ def test_test_value_op():
finally:
config.compute_test_value = prev_value
def test_get_debug_values_no_debugger():
'get_debug_values should return [] when debugger is off'
prev_value = config.compute_test_value
try:
config.compute_test_value = 'off'
x = T.vector()
for x_val in op.get_debug_values(x):
assert False
finally:
config.compute_test_value = prev_value
def test_get_det_debug_values_ignore():
"""get_debug_values should return [] when debugger is ignore
and some values are missing """
prev_value = config.compute_test_value
try:
config.compute_test_value = 'ignore'
x = T.vector()
for x_val in op.get_debug_values(x):
assert False
finally:
config.compute_test_value = prev_value
def test_get_debug_values_success():
"""tests that get_debug_value returns values when available
(and the debugger is on)"""
prev_value = config.compute_test_value
for mode in [ 'ignore', 'warn', 'raise' ]:
try:
config.compute_test_value = mode
x = T.vector()
x.tag.test_value = numpy.zeros((4,))
y = numpy.zeros((5,5))
iters = 0
for x_val, y_val in op.get_debug_values(x, y):
assert x_val.shape == (4,)
assert y_val.shape == (5,5)
iters += 1
assert iters == 1
finally:
config.compute_test_value = prev_value
def test_get_debug_values_exc():
"""tests that get_debug_value raises an exception when
debugger is set to raise and a value is missing """
prev_value = config.compute_test_value
try:
config.compute_test_value = 'raise'
x = T.vector()
try:
for x_val in op.get_debug_values(x):
#this assert catches the case where we
#erroneously get a value returned
assert False
raised = False
except AttributeError:
raised = True
#this assert catches the case where we got []
#returned, and possibly issued a warning,
#rather than raising an exception
assert raised
finally:
config.compute_test_value = prev_value
def test_debug_error_message():
"""tests that debug_error_message raises an
exception when it should."""
prev_value = config.compute_test_value
for mode in [ 'ignore', 'raise' ]:
try:
config.compute_test_value = mode
try:
op.debug_error_message('msg')
raised = False
except ValueError:
raised = True
assert raised
finally:
config.compute_test_value = prev_value
if __name__ == '__main__':
unittest.main()
......@@ -618,9 +618,20 @@ class Elemwise(Op):
# of the current apply node is c
ograds = map(as_tensor_variable, ograds)
scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs]
scalar_ograds = [Scalar(dtype = ograd.type.dtype)() for ograd in ograds]
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
prev_setting = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'off'
scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs]
scalar_ograds = [Scalar(dtype = ograd.type.dtype)() for ograd in ograds]
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
finally:
theano.config.compute_test_value = prev_setting
nd = len(inputs[0].type.broadcastable) # this is the same for everyone
def transform(r):
# From a graph of ScalarOps, make a graph of Broadcast ops.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论