提交 9d2ae407 authored 作者: James Bergstra's avatar James Bergstra

minor improvements to DebugMode

上级 5332d39d
...@@ -128,20 +128,44 @@ class BadOptimization(DebugModeError): ...@@ -128,20 +128,44 @@ class BadOptimization(DebugModeError):
def str_diagnostic(self): def str_diagnostic(self):
"""Return a pretty multiline string representating the cause of the exception""" """Return a pretty multiline string representating the cause of the exception"""
sio = StringIO() sio = StringIO()
val_str_len_limit = 800
print >> sio, "BadOptimization Error", super(BadOptimization, self).__str__() print >> sio, "BadOptimization Error", super(BadOptimization, self).__str__()
print >> sio, " Variable: id", id(self.new_r), self.new_r print >> sio, " Variable: id", id(self.new_r), self.new_r
print >> sio, " Op", self.new_r.owner print >> sio, " Op", self.new_r.owner
print >> sio, " Value Type:", type(self.new_r_val) print >> sio, " Value Type:", type(self.new_r_val)
try:
ssio = StringIO()
print >> ssio, " Old Value shape, dtype, strides:",
print >> ssio, self.old_r_val.shape,
print >> ssio, self.old_r_val.dtype,
print >> ssio, self.old_r_val.strides
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except:
pass
str_old_r_val = str(self.old_r_val) str_old_r_val = str(self.old_r_val)
if len(str_old_r_val) > 80: if len(str_old_r_val) > val_str_len_limit:
print >> sio, " Old Value: ", str(self.old_r_val)[:80], '...' print >> sio, " Old Value: ", str(self.old_r_val)[:val_str_len_limit], '...'
else: else:
print >> sio, " Old Value: ", str(self.old_r_val) print >> sio, " Old Value: ", str(self.old_r_val)
try:
ssio = StringIO()
print >> ssio, " New Value shape, dtype, strides:",
print >> ssio, self.new_r_val.shape,
print >> ssio, self.new_r_val.dtype,
print >> ssio, self.new_r_val.strides
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except:
pass
str_new_r_val = str(self.new_r_val) str_new_r_val = str(self.new_r_val)
if len(str_new_r_val) > 80: if len(str_new_r_val) > val_str_len_limit:
print >> sio, " New Value: ", str(self.new_r_val)[:80], '...' print >> sio, " New Value: ", str(self.new_r_val)[:val_str_len_limit], '...'
else: else:
print >> sio, " New Value: ", str(self.new_r_val) print >> sio, " New Value: ", str(self.new_r_val)
print >> sio, " Reason: ", str(self.reason) print >> sio, " Reason: ", str(self.reason)
print >> sio, " Old Graph:" print >> sio, " Old Graph:"
print >> sio, self.old_graph print >> sio, self.old_graph
...@@ -177,7 +201,7 @@ class BadDestroyMap(DebugModeError): ...@@ -177,7 +201,7 @@ class BadDestroyMap(DebugModeError):
print >> sio, " value min (new-old):", (npy_new_val-npy_old_val).min() print >> sio, " value min (new-old):", (npy_new_val-npy_old_val).min()
print >> sio, " value max (new-old):", (npy_new_val-npy_old_val).max() print >> sio, " value max (new-old):", (npy_new_val-npy_old_val).max()
print >> sio, "" print >> sio, ""
print >> sio, " Hint: this can also be caused by a deficient values_eq_approx() or __eq__() implementation that compares node input values" print >> sio, " Hint: this can also be caused by a deficient values_eq_approx() or __eq__() implementation [which compared input values]"
return sio.getvalue() return sio.getvalue()
except Exception, e: except Exception, e:
return str(e) return str(e)
...@@ -220,24 +244,45 @@ class StochasticOrder(DebugModeError): ...@@ -220,24 +244,45 @@ class StochasticOrder(DebugModeError):
class InvalidValueError(DebugModeError): class InvalidValueError(DebugModeError):
"""Exception: some Op an output value that is inconsistent with the Type of that output""" """Exception: some Op an output value that is inconsistent with the Type of that output"""
def __init__(self, r, v, client_node=None): def __init__(self, r, v, client_node=None, hint='none'):
super(InvalidValueError, self).__init__() super(InvalidValueError, self).__init__()
self.r = r self.r = r
self.v = v self.v = v
self.client_node = client_node self.client_node = client_node
self.hint=hint
def __str__(self): def __str__(self):
r, v = self.r, self.v r, v = self.r, self.v
type_r = type(r) type_r = r.type
type_v = type(v) type_v = type(v)
v_val = str(v)[0:100] v_val = str(v)[0:100]
v_dtype = 'N/A'
v_shape = 'N/A'
v_min = 'N/A'
v_max = 'N/A'
v_isfinite = 'N/A'
try:
v_shape = v.shape
v_dtype = v.dtype
v_min = v.min()
v_max = v.max()
v_isfinite = numpy.all(numpy.isfinite(v))
except:
pass
client_node = self.client_node client_node = self.client_node
hint = self.hint
return """InvalidValueError return """InvalidValueError
type(variable) = %(type_r)s type(variable) = %(type_r)s
variable = %(r)s variable = %(r)s
type(value) = %(type_v)s type(value) = %(type_v)s
dtype(value) = %(v_dtype)s
shape(value) = %(v_shape)s
value = %(v_val)s value = %(v_val)s
min(value) = %(v_min)s
max(value) = %(v_max)s
isfinite = %(v_isfinite)s
client_node = %(client_node)s client_node = %(client_node)s
hint = %(hint)s
""" % locals() """ % locals()
######################## ########################
...@@ -930,7 +975,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -930,7 +975,7 @@ class _Linker(gof.link.LocalLinker):
# check output values for type-correctness # check output values for type-correctness
for r in node.outputs: for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0]) raise InvalidValueError(r, storage_map[r][0], hint='perform output')
#if r in r_vals: #if r in r_vals:
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set, _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
...@@ -963,7 +1008,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -963,7 +1008,7 @@ class _Linker(gof.link.LocalLinker):
for r in node.outputs: for r in node.outputs:
# check output values for type-correctness # check output values for type-correctness
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0]) raise InvalidValueError(r, storage_map[r][0], hint='c output')
if thunk_py: if thunk_py:
assert r in r_vals #because we put it in during the thunk_py branch assert r in r_vals #because we put it in during the thunk_py branch
...@@ -1140,7 +1185,7 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions ...@@ -1140,7 +1185,7 @@ class _Maker(FunctionMaker): #inheritance buys a few helper functions
print >> infolog, "-----------------------------------------------------" print >> infolog, "-----------------------------------------------------"
for j in xrange(max(len(li), len(l0))): for j in xrange(max(len(li), len(l0))):
if j >= len(li) or j >= len(l0) or li[j] != l0[j]: if j >= len(li) or j >= len(l0) or li[j] != l0[j]:
print >> infolog, "* ", j, print >> infolog, "* ", j
if j < len(li): if j < len(li):
msg = str(li[j]) msg = str(li[j])
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论