提交 cc51a643 authored 作者: Frederic's avatar Frederic

Better BadOptimization error (add max abs and rel diff values, and keep common ids in debugprint)

上级 87ad4106
...@@ -293,6 +293,10 @@ class BadOptimization(DebugModeError): ...@@ -293,6 +293,10 @@ class BadOptimization(DebugModeError):
nv - ov)) nv - ov))
print >> ssio, " Std Abs Diff: ", numpy.std(numpy.absolute( print >> ssio, " Std Abs Diff: ", numpy.std(numpy.absolute(
nv - ov)) nv - ov))
arg_max_val = numpy.argmax(numpy.absolute(nv - ov))
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print >> ssio, " Value at Max Diff: ", values_at_max
# N.B. the maximum(..., 1e-8) protects against div by 0 when # N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0 # nv == ov == 0
...@@ -304,6 +308,10 @@ class BadOptimization(DebugModeError): ...@@ -304,6 +308,10 @@ class BadOptimization(DebugModeError):
print >> ssio, " Mean Rel Diff: ", numpy.mean(reldiff) print >> ssio, " Mean Rel Diff: ", numpy.mean(reldiff)
print >> ssio, " Median Rel Diff: ", numpy.median(reldiff) print >> ssio, " Median Rel Diff: ", numpy.median(reldiff)
print >> ssio, " Std Rel Diff: ", numpy.std(reldiff) print >> ssio, " Std Rel Diff: ", numpy.std(reldiff)
arg_max_val = numpy.argmax(reldiff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print >> ssio, " Value at Max Diff: ", values_at_max
# only if all succeeds to we add anything to sio # only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue() print >> sio, ssio.getvalue()
except Exception: except Exception:
...@@ -1559,10 +1567,13 @@ class _VariableEquivalenceTracker(object): ...@@ -1559,10 +1567,13 @@ class _VariableEquivalenceTracker(object):
if append_reason: if append_reason:
# N.B. compute the debugprint now, because future # N.B. compute the debugprint now, because future
# optimizations will change the graph # optimizations will change the graph
done = dict()
self.reasons[new_r].append((reason, self.reasons[new_r].append((reason,
r, r,
debugprint(r, prefix=' ', depth=6, file=StringIO()).getvalue(), debugprint(r, prefix=' ', depth=6,
debugprint(new_r, prefix=' ', depth=6, file=StringIO()).getvalue())) file=StringIO(), done=done).getvalue(),
debugprint(new_r, prefix=' ', depth=6,
file=StringIO(), done=done).getvalue()))
self.replaced_by[r].append((reason, new_r)) self.replaced_by[r].append((reason, new_r))
if r in self.equiv: if r in self.equiv:
......
...@@ -34,7 +34,8 @@ VALID_ASSOC = set(['left', 'right', 'either']) ...@@ -34,7 +34,8 @@ VALID_ASSOC = set(['left', 'right', 'either'])
def debugprint(obj, depth=-1, print_type=False, def debugprint(obj, depth=-1, print_type=False,
file=None, ids='CHAR', stop_on_name=False): file=None, ids='CHAR', stop_on_name=False,
done=None):
"""Print a computation graph as text to stdout or a file. """Print a computation graph as text to stdout or a file.
:type obj: Variable, Apply, or Function instance :type obj: Variable, Apply, or Function instance
...@@ -53,6 +54,9 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -53,6 +54,9 @@ def debugprint(obj, depth=-1, print_type=False,
"" - don't print an identifier "" - don't print an identifier
:param stop_on_name: When True, if a node in the graph has a name, :param stop_on_name: When True, if a node in the graph has a name,
we don't print anything below it. we don't print anything below it.
:type done: None or dict
:param done: A dict where we store the ids of printed node.
Useful to have multiple call to debugprint share the ids for the same node.
:returns: string if `file` == 'str', else file arg :returns: string if `file` == 'str', else file arg
...@@ -80,6 +84,7 @@ def debugprint(obj, depth=-1, print_type=False, ...@@ -80,6 +84,7 @@ def debugprint(obj, depth=-1, print_type=False,
_file = sys.stdout _file = sys.stdout
else: else:
_file = file _file = file
if done is None:
done = dict() done = dict()
results_to_print = [] results_to_print = []
profile_list = [] profile_list = []
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论