提交 cc51a643 authored 作者: Frederic's avatar Frederic

Better BadOptimization error (add max abs and rel diff values, and keep common ids in debugprint)

上级 87ad4106
......@@ -290,9 +290,13 @@ class BadOptimization(DebugModeError):
print >> ssio, " Mean Abs Diff: ", numpy.mean(numpy.absolute(nv -
ov))
print >> ssio, " Median Abs Diff: ", numpy.median(numpy.absolute(
nv - ov))
nv - ov))
print >> ssio, " Std Abs Diff: ", numpy.std(numpy.absolute(
nv - ov))
nv - ov))
arg_max_val = numpy.argmax(numpy.absolute(nv - ov))
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print >> ssio, " Value at Max Diff: ", values_at_max
# N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0
......@@ -304,6 +308,10 @@ class BadOptimization(DebugModeError):
print >> ssio, " Mean Rel Diff: ", numpy.mean(reldiff)
print >> ssio, " Median Rel Diff: ", numpy.median(reldiff)
print >> ssio, " Std Rel Diff: ", numpy.std(reldiff)
arg_max_val = numpy.argmax(reldiff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print >> ssio, " Value at Max Diff: ", values_at_max
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except Exception:
......@@ -1559,10 +1567,13 @@ class _VariableEquivalenceTracker(object):
if append_reason:
# N.B. compute the debugprint now, because future
# optimizations will change the graph
done = dict()
self.reasons[new_r].append((reason,
r,
debugprint(r, prefix=' ', depth=6, file=StringIO()).getvalue(),
debugprint(new_r, prefix=' ', depth=6, file=StringIO()).getvalue()))
debugprint(r, prefix=' ', depth=6,
file=StringIO(), done=done).getvalue(),
debugprint(new_r, prefix=' ', depth=6,
file=StringIO(), done=done).getvalue()))
self.replaced_by[r].append((reason, new_r))
if r in self.equiv:
......
......@@ -34,7 +34,8 @@ VALID_ASSOC = set(['left', 'right', 'either'])
def debugprint(obj, depth=-1, print_type=False,
file=None, ids='CHAR', stop_on_name=False):
file=None, ids='CHAR', stop_on_name=False,
done=None):
"""Print a computation graph as text to stdout or a file.
:type obj: Variable, Apply, or Function instance
......@@ -53,6 +54,9 @@ def debugprint(obj, depth=-1, print_type=False,
"" - don't print an identifier
:param stop_on_name: When True, if a node in the graph has a name,
we don't print anything below it.
:type done: None or dict
:param done: A dict where we store the ids of printed node.
Useful to have multiple call to debugprint share the ids for the same node.
:returns: string if `file` == 'str', else file arg
......@@ -80,7 +84,8 @@ def debugprint(obj, depth=-1, print_type=False,
_file = sys.stdout
else:
_file = file
done = dict()
if done is None:
done = dict()
results_to_print = []
profile_list = []
order = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论