提交 e4a5bf7d authored 作者: Frederic Bastien's avatar Frederic Bastien

Move BadOptimization error to allow to reuse

上级 56da8ca8
......@@ -153,165 +153,8 @@ class BadThunkOutput(DebugModeError):
return ret
class BadOptimization(DebugModeError):
"""
Exception: some variable and its substitute take different runtime values.
"""
new_r = None
"""
A `Variable` instance that took a different value from `old_r`,
but which replaced `old_r`.
"""
old_r = None
"""
A `Variable` instance that was replaced by `new_r`.
"""
old_r_val = None
"""
The value computed for `old_r`.
"""
new_r_val = None
"""
The value computed for `new_r`.
"""
reason = None
"""
An object that indicates why old_r was turned into new_r.
Convention is that this is the name of the optimization that
requested the replacement.
"""
old_graph = ""
"""
A multiline string representation of the graph leading to
old_r, at the time of the replacement.
"""
new_graph = ""
"""
A multiline string representation of the graph leading to
new_r, at the time of the replacement.
"""
def __init__(self, old_r, new_r, old_r_val, new_r_val, reason,
old_graph, new_graph):
super(BadOptimization, self).__init__()
self.old_r = old_r
self.new_r = new_r
self.old_r_val = old_r_val
self.new_r_val = new_r_val
self.reason = reason
self.old_graph = old_graph
self.new_graph = new_graph
def __str__(self):
return self.str_diagnostic()
def str_diagnostic(self):
"""
Return a pretty multiline string representating the cause
of the exception.
"""
sio = StringIO()
val_str_len_limit = 800
print("BadOptimization Error", super(BadOptimization,
self).__str__(), file=sio)
print(" Variable: id", id(self.new_r), self.new_r, file=sio)
print(" Op", self.new_r.owner, file=sio)
print(" Value Type:", type(self.new_r_val), file=sio)
try:
ssio = StringIO()
print(" Old Value shape, dtype, strides:", end=' ', file=ssio)
print(self.old_r_val.shape, end=' ', file=ssio)
print(self.old_r_val.dtype, end=' ', file=ssio)
print(self.old_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_old_r_val = str(self.old_r_val)
if len(str_old_r_val) > val_str_len_limit:
print(" Old Value: ", str(self.old_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" Old Value: ", str(self.old_r_val), file=sio)
try:
ssio = StringIO()
print(" New Value shape, dtype, strides:", end=' ', file=ssio)
print(self.new_r_val.shape, end=' ', file=ssio)
print(self.new_r_val.dtype, end=' ', file=ssio)
print(self.new_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_new_r_val = str(self.new_r_val)
if len(str_new_r_val) > val_str_len_limit:
print(" New Value: ", str(self.new_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" New Value: ", str(self.new_r_val), file=sio)
try:
ov = np.asarray(self.old_r_val)
nv = np.asarray(self.new_r_val)
ssio = StringIO()
abs_diff = np.absolute(nv - ov)
print(" Max Abs Diff: ", np.max(abs_diff), file=ssio)
print(" Mean Abs Diff: ", np.mean(abs_diff), file=ssio)
print(" Median Abs Diff: ", np.median(abs_diff), file=ssio)
print(" Std Abs Diff: ", np.std(abs_diff), file=ssio)
arg_max_val = np.argmax(abs_diff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0
reldiff = (abs_diff /
np.maximum(np.absolute(nv) + np.absolute(ov),
1e-8))
print(" Max Rel Diff: ", np.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", np.mean(reldiff), file=ssio)
print(" Median Rel Diff: ", np.median(reldiff), file=ssio)
print(" Std Rel Diff: ", np.std(reldiff), file=ssio)
arg_max_val = np.argmax(reldiff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
print(" Reason: ", str(self.reason), file=sio)
print(" Old Graph:", file=sio)
print(self.old_graph, file=sio)
print(" New Graph:", file=sio)
print(self.new_graph, file=sio)
print("", file=sio)
print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1",
file=sio)
print(" or even tensor.cmp_sloppy=2 for less-strict comparison",
file=sio)
return sio.getvalue()
class BadOptimization(DebugModeError, theano.gof.toolbox.BadOptimization):
pass
class BadDestroyMap(DebugModeError):
......
......@@ -6,6 +6,9 @@ import sys
import time
import inspect
import numpy as np
from six.moves import StringIO
import theano
from theano import config
from theano.gof import graph
......@@ -33,6 +36,167 @@ class ReplacementDidntRemovedError(Exception):
pass
class BadOptimization(Exception):
"""
Exception: some variable and its substitute take different runtime values.
"""
new_r = None
"""
A `Variable` instance that took a different value from `old_r`,
but which replaced `old_r`.
"""
old_r = None
"""
A `Variable` instance that was replaced by `new_r`.
"""
old_r_val = None
"""
The value computed for `old_r`.
"""
new_r_val = None
"""
The value computed for `new_r`.
"""
reason = None
"""
An object that indicates why old_r was turned into new_r.
Convention is that this is the name of the optimization that
requested the replacement.
"""
old_graph = ""
"""
A multiline string representation of the graph leading to
old_r, at the time of the replacement.
"""
new_graph = ""
"""
A multiline string representation of the graph leading to
new_r, at the time of the replacement.
"""
def __init__(self, old_r, new_r, old_r_val, new_r_val, reason,
old_graph, new_graph):
super(BadOptimization, self).__init__()
self.old_r = old_r
self.new_r = new_r
self.old_r_val = old_r_val
self.new_r_val = new_r_val
self.reason = reason
self.old_graph = old_graph
self.new_graph = new_graph
def __str__(self):
return self.str_diagnostic()
def str_diagnostic(self):
"""
Return a pretty multiline string representating the cause
of the exception.
"""
sio = StringIO()
val_str_len_limit = 800
print("BadOptimization Error", super(BadOptimization,
self).__str__(), file=sio)
print(" Variable: id", id(self.new_r), self.new_r, file=sio)
print(" Op", self.new_r.owner, file=sio)
print(" Value Type:", type(self.new_r_val), file=sio)
try:
ssio = StringIO()
print(" Old Value shape, dtype, strides:", end=' ', file=ssio)
print(self.old_r_val.shape, end=' ', file=ssio)
print(self.old_r_val.dtype, end=' ', file=ssio)
print(self.old_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_old_r_val = str(self.old_r_val)
if len(str_old_r_val) > val_str_len_limit:
print(" Old Value: ", str(self.old_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" Old Value: ", str(self.old_r_val), file=sio)
try:
ssio = StringIO()
print(" New Value shape, dtype, strides:", end=' ', file=ssio)
print(self.new_r_val.shape, end=' ', file=ssio)
print(self.new_r_val.dtype, end=' ', file=ssio)
print(self.new_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_new_r_val = str(self.new_r_val)
if len(str_new_r_val) > val_str_len_limit:
print(" New Value: ", str(self.new_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" New Value: ", str(self.new_r_val), file=sio)
try:
ov = np.asarray(self.old_r_val)
nv = np.asarray(self.new_r_val)
ssio = StringIO()
abs_diff = np.absolute(nv - ov)
print(" Max Abs Diff: ", np.max(abs_diff), file=ssio)
print(" Mean Abs Diff: ", np.mean(abs_diff), file=ssio)
print(" Median Abs Diff: ", np.median(abs_diff), file=ssio)
print(" Std Abs Diff: ", np.std(abs_diff), file=ssio)
arg_max_val = np.argmax(abs_diff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0
reldiff = (abs_diff /
np.maximum(np.absolute(nv) + np.absolute(ov),
1e-8))
print(" Max Rel Diff: ", np.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", np.mean(reldiff), file=ssio)
print(" Median Rel Diff: ", np.median(reldiff), file=ssio)
print(" Std Rel Diff: ", np.std(reldiff), file=ssio)
arg_max_val = np.argmax(reldiff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
print(" Reason: ", str(self.reason), file=sio)
print(" Old Graph:", file=sio)
print(self.old_graph, file=sio)
print(" New Graph:", file=sio)
print(self.new_graph, file=sio)
print("", file=sio)
print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1",
file=sio)
print(" or even tensor.cmp_sloppy=2 for less-strict comparison",
file=sio)
return sio.getvalue()
class Feature(object):
"""
Base class for FunctionGraph extensions.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论