提交 e4a5bf7d authored 作者: Frederic Bastien's avatar Frederic Bastien

Move BadOptimization error to allow to reuse

上级 56da8ca8
...@@ -153,165 +153,8 @@ class BadThunkOutput(DebugModeError): ...@@ -153,165 +153,8 @@ class BadThunkOutput(DebugModeError):
return ret return ret
class BadOptimization(DebugModeError): class BadOptimization(DebugModeError, theano.gof.toolbox.BadOptimization):
""" pass
Exception: some variable and its substitute take different runtime values.
"""
new_r = None
"""
A `Variable` instance that took a different value from `old_r`,
but which replaced `old_r`.
"""
old_r = None
"""
A `Variable` instance that was replaced by `new_r`.
"""
old_r_val = None
"""
The value computed for `old_r`.
"""
new_r_val = None
"""
The value computed for `new_r`.
"""
reason = None
"""
An object that indicates why old_r was turned into new_r.
Convention is that this is the name of the optimization that
requested the replacement.
"""
old_graph = ""
"""
A multiline string representation of the graph leading to
old_r, at the time of the replacement.
"""
new_graph = ""
"""
A multiline string representation of the graph leading to
new_r, at the time of the replacement.
"""
def __init__(self, old_r, new_r, old_r_val, new_r_val, reason,
old_graph, new_graph):
super(BadOptimization, self).__init__()
self.old_r = old_r
self.new_r = new_r
self.old_r_val = old_r_val
self.new_r_val = new_r_val
self.reason = reason
self.old_graph = old_graph
self.new_graph = new_graph
def __str__(self):
return self.str_diagnostic()
def str_diagnostic(self):
"""
Return a pretty multiline string representating the cause
of the exception.
"""
sio = StringIO()
val_str_len_limit = 800
print("BadOptimization Error", super(BadOptimization,
self).__str__(), file=sio)
print(" Variable: id", id(self.new_r), self.new_r, file=sio)
print(" Op", self.new_r.owner, file=sio)
print(" Value Type:", type(self.new_r_val), file=sio)
try:
ssio = StringIO()
print(" Old Value shape, dtype, strides:", end=' ', file=ssio)
print(self.old_r_val.shape, end=' ', file=ssio)
print(self.old_r_val.dtype, end=' ', file=ssio)
print(self.old_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_old_r_val = str(self.old_r_val)
if len(str_old_r_val) > val_str_len_limit:
print(" Old Value: ", str(self.old_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" Old Value: ", str(self.old_r_val), file=sio)
try:
ssio = StringIO()
print(" New Value shape, dtype, strides:", end=' ', file=ssio)
print(self.new_r_val.shape, end=' ', file=ssio)
print(self.new_r_val.dtype, end=' ', file=ssio)
print(self.new_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_new_r_val = str(self.new_r_val)
if len(str_new_r_val) > val_str_len_limit:
print(" New Value: ", str(self.new_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" New Value: ", str(self.new_r_val), file=sio)
try:
ov = np.asarray(self.old_r_val)
nv = np.asarray(self.new_r_val)
ssio = StringIO()
abs_diff = np.absolute(nv - ov)
print(" Max Abs Diff: ", np.max(abs_diff), file=ssio)
print(" Mean Abs Diff: ", np.mean(abs_diff), file=ssio)
print(" Median Abs Diff: ", np.median(abs_diff), file=ssio)
print(" Std Abs Diff: ", np.std(abs_diff), file=ssio)
arg_max_val = np.argmax(abs_diff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0
reldiff = (abs_diff /
np.maximum(np.absolute(nv) + np.absolute(ov),
1e-8))
print(" Max Rel Diff: ", np.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", np.mean(reldiff), file=ssio)
print(" Median Rel Diff: ", np.median(reldiff), file=ssio)
print(" Std Rel Diff: ", np.std(reldiff), file=ssio)
arg_max_val = np.argmax(reldiff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
print(" Reason: ", str(self.reason), file=sio)
print(" Old Graph:", file=sio)
print(self.old_graph, file=sio)
print(" New Graph:", file=sio)
print(self.new_graph, file=sio)
print("", file=sio)
print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1",
file=sio)
print(" or even tensor.cmp_sloppy=2 for less-strict comparison",
file=sio)
return sio.getvalue()
class BadDestroyMap(DebugModeError): class BadDestroyMap(DebugModeError):
......
...@@ -6,6 +6,9 @@ import sys ...@@ -6,6 +6,9 @@ import sys
import time import time
import inspect import inspect
import numpy as np
from six.moves import StringIO
import theano import theano
from theano import config from theano import config
from theano.gof import graph from theano.gof import graph
...@@ -33,6 +36,167 @@ class ReplacementDidntRemovedError(Exception): ...@@ -33,6 +36,167 @@ class ReplacementDidntRemovedError(Exception):
pass pass
class BadOptimization(Exception):
"""
Exception: some variable and its substitute take different runtime values.
"""
new_r = None
"""
A `Variable` instance that took a different value from `old_r`,
but which replaced `old_r`.
"""
old_r = None
"""
A `Variable` instance that was replaced by `new_r`.
"""
old_r_val = None
"""
The value computed for `old_r`.
"""
new_r_val = None
"""
The value computed for `new_r`.
"""
reason = None
"""
An object that indicates why old_r was turned into new_r.
Convention is that this is the name of the optimization that
requested the replacement.
"""
old_graph = ""
"""
A multiline string representation of the graph leading to
old_r, at the time of the replacement.
"""
new_graph = ""
"""
A multiline string representation of the graph leading to
new_r, at the time of the replacement.
"""
def __init__(self, old_r, new_r, old_r_val, new_r_val, reason,
old_graph, new_graph):
super(BadOptimization, self).__init__()
self.old_r = old_r
self.new_r = new_r
self.old_r_val = old_r_val
self.new_r_val = new_r_val
self.reason = reason
self.old_graph = old_graph
self.new_graph = new_graph
def __str__(self):
return self.str_diagnostic()
def str_diagnostic(self):
"""
Return a pretty multiline string representating the cause
of the exception.
"""
sio = StringIO()
val_str_len_limit = 800
print("BadOptimization Error", super(BadOptimization,
self).__str__(), file=sio)
print(" Variable: id", id(self.new_r), self.new_r, file=sio)
print(" Op", self.new_r.owner, file=sio)
print(" Value Type:", type(self.new_r_val), file=sio)
try:
ssio = StringIO()
print(" Old Value shape, dtype, strides:", end=' ', file=ssio)
print(self.old_r_val.shape, end=' ', file=ssio)
print(self.old_r_val.dtype, end=' ', file=ssio)
print(self.old_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_old_r_val = str(self.old_r_val)
if len(str_old_r_val) > val_str_len_limit:
print(" Old Value: ", str(self.old_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" Old Value: ", str(self.old_r_val), file=sio)
try:
ssio = StringIO()
print(" New Value shape, dtype, strides:", end=' ', file=ssio)
print(self.new_r_val.shape, end=' ', file=ssio)
print(self.new_r_val.dtype, end=' ', file=ssio)
print(self.new_r_val.strides, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
str_new_r_val = str(self.new_r_val)
if len(str_new_r_val) > val_str_len_limit:
print(" New Value: ", str(self.new_r_val)[
:val_str_len_limit], '...', file=sio)
else:
print(" New Value: ", str(self.new_r_val), file=sio)
try:
ov = np.asarray(self.old_r_val)
nv = np.asarray(self.new_r_val)
ssio = StringIO()
abs_diff = np.absolute(nv - ov)
print(" Max Abs Diff: ", np.max(abs_diff), file=ssio)
print(" Mean Abs Diff: ", np.mean(abs_diff), file=ssio)
print(" Median Abs Diff: ", np.median(abs_diff), file=ssio)
print(" Std Abs Diff: ", np.std(abs_diff), file=ssio)
arg_max_val = np.argmax(abs_diff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# N.B. the maximum(..., 1e-8) protects against div by 0 when
# nv == ov == 0
reldiff = (abs_diff /
np.maximum(np.absolute(nv) + np.absolute(ov),
1e-8))
print(" Max Rel Diff: ", np.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", np.mean(reldiff), file=ssio)
print(" Median Rel Diff: ", np.median(reldiff), file=ssio)
print(" Std Rel Diff: ", np.std(reldiff), file=ssio)
arg_max_val = np.argmax(reldiff)
values_at_max = (nv.flatten()[arg_max_val],
ov.flatten()[arg_max_val])
print(" Value at Max Diff: ", values_at_max, file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
print(" Reason: ", str(self.reason), file=sio)
print(" Old Graph:", file=sio)
print(self.old_graph, file=sio)
print(" New Graph:", file=sio)
print(self.new_graph, file=sio)
print("", file=sio)
print("Hint: relax the tolerance by setting tensor.cmp_sloppy=1",
file=sio)
print(" or even tensor.cmp_sloppy=2 for less-strict comparison",
file=sio)
return sio.getvalue()
class Feature(object): class Feature(object):
""" """
Base class for FunctionGraph extensions. Base class for FunctionGraph extensions.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论