提交 121bea82 authored 作者: Frederic's avatar Frederic

Refactoring to have BadThunkOutput reuse str_diagnostic()

上级 309224d0
...@@ -171,7 +171,8 @@ class BadThunkOutput(DebugModeError): ...@@ -171,7 +171,8 @@ class BadThunkOutput(DebugModeError):
of the exception""" of the exception"""
sio = StringIO() sio = StringIO()
print >> sio, "BadThunkOutput" print >> sio, "BadThunkOutput"
print >> sio, " variable :", self.r print >> sio, " Apply :", self.r.owner
print >> sio, " op :", self.offending_op()
print >> sio, " Outputs Type:", self.r.type print >> sio, " Outputs Type:", self.r.type
print >> sio, " Outputs Shape:", getattr(self.val1, 'shape', None) print >> sio, " Outputs Shape:", getattr(self.val1, 'shape', None)
print >> sio, " Outputs Strides:", getattr(self.val1, 'strides', None) print >> sio, " Outputs Strides:", getattr(self.val1, 'strides', None)
...@@ -180,60 +181,15 @@ class BadThunkOutput(DebugModeError): ...@@ -180,60 +181,15 @@ class BadThunkOutput(DebugModeError):
for val in self.inputs_val] for val in self.inputs_val]
print >> sio, " Inputs Strides:", [getattr(val, 'strides', None) print >> sio, " Inputs Strides:", [getattr(val, 'strides', None)
for val in self.inputs_val] for val in self.inputs_val]
print >> sio, " Apply :", self.r.owner print >> sio, " Bad Variable:", self.r
print >> sio, " thunk1 :", self.thunk1 print >> sio, " thunk1 :", self.thunk1
print >> sio, " thunk2 :", self.thunk2 print >> sio, " thunk2 :", self.thunk2
print >> sio, " val1 :", self.val1
print >> sio, " val2 :", self.val2 #Don't import it at the top of the file to prevent circular import.
print >> sio, " op :", self.offending_op() utt = theano.tests.unittest_tools
try: print >> sio, utt.str_diagnostic(self.val1, self.val2, None, None)
ssio = StringIO() ret = sio.getvalue()
print >> ssio, " Value 1 : shape, dtype, strides, min, max, n_inf, n_nan:", return ret
print >> ssio, self.val1.shape,
print >> ssio, self.val1.dtype,
print >> ssio, self.val1.strides,
print >> ssio, self.val1.min(),
print >> ssio, self.val1.max(),
print >> ssio, numpy.isinf(self.val1).sum(),
print >> ssio, numpy.isnan(self.val1).sum(),
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except Exception:
pass
try:
ssio = StringIO()
print >> ssio, " Value 2 : shape, dtype, strides, min, max, n_inf, n_nan:",
print >> ssio, self.val2.shape,
print >> ssio, self.val2.dtype,
print >> ssio, self.val2.strides,
print >> ssio, self.val2.min(),
print >> ssio, self.val2.max(),
print >> ssio, numpy.isinf(self.val2).sum(),
print >> ssio, numpy.isnan(self.val2).sum(),
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except Exception:
pass
try:
ov = numpy.asarray(self.val1)
nv = numpy.asarray(self.val2)
ssio = StringIO()
absdiff = numpy.absolute(nv - ov)
print >> ssio, " Max Abs Diff: ", numpy.max(absdiff)
print >> ssio, " Mean Abs Diff: ", numpy.mean(absdiff)
print >> ssio, " Median Abs Diff: ", numpy.median(absdiff)
print >> ssio, " Std Abs Diff: ", numpy.std(absdiff)
reldiff = numpy.absolute(nv - ov) / (numpy.absolute(nv) +
numpy.absolute(ov))
print >> ssio, " Max Rel Diff: ", numpy.max(reldiff)
print >> ssio, " Mean Rel Diff: ", numpy.mean(reldiff)
print >> ssio, " Median Rel Diff: ", numpy.median(reldiff)
print >> ssio, " Std Rel Diff: ", numpy.std(reldiff)
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except Exception:
pass
return sio.getvalue()
class BadOptimization(DebugModeError): class BadOptimization(DebugModeError):
......
...@@ -242,33 +242,22 @@ class InferShapeTester(unittest.TestCase): ...@@ -242,33 +242,22 @@ class InferShapeTester(unittest.TestCase):
assert numpy.all(out.shape == shape) assert numpy.all(out.shape == shape)
class WrongValue(Exception): def str_diagnostic(val1, val2, rtol, atol):
def __init__(self, expected_val, val, rtol, atol):
self.val1 = expected_val
self.val2 = val
self.rtol = rtol
self.atol = atol
def __str__(self):
return self.str_diagnostic()
def str_diagnostic(self):
"""Return a pretty multiline string representating the cause """Return a pretty multiline string representating the cause
of the exception""" of the exception"""
sio = StringIO() sio = StringIO()
print >> sio, self.__class__.__name__
try: try:
ssio = StringIO() ssio = StringIO()
print >> ssio, " : shape, dtype, strides, min, max, n_inf, n_nan:" print >> ssio, " : shape, dtype, strides, min, max, n_inf, n_nan:"
print >> ssio, " Expected :", print >> ssio, " Expected :",
print >> ssio, self.val1.shape, print >> ssio, val1.shape,
print >> ssio, self.val1.dtype, print >> ssio, val1.dtype,
print >> ssio, self.val1.strides, print >> ssio, val1.strides,
print >> ssio, self.val1.min(), print >> ssio, val1.min(),
print >> ssio, self.val1.max(), print >> ssio, val1.max(),
print >> ssio, numpy.isinf(self.val1).sum(), print >> ssio, numpy.isinf(val1).sum(),
print >> ssio, numpy.isnan(self.val1).sum(), print >> ssio, numpy.isnan(val1).sum(),
# only if all succeeds to we add anything to sio # only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue() print >> sio, ssio.getvalue()
except Exception: except Exception:
...@@ -276,20 +265,24 @@ class WrongValue(Exception): ...@@ -276,20 +265,24 @@ class WrongValue(Exception):
try: try:
ssio = StringIO() ssio = StringIO()
print >> ssio, " Value :", print >> ssio, " Value :",
print >> ssio, self.val2.shape, print >> ssio, val2.shape,
print >> ssio, self.val2.dtype, print >> ssio, val2.dtype,
print >> ssio, self.val2.strides, print >> ssio, val2.strides,
print >> ssio, self.val2.min(), print >> ssio, val2.min(),
print >> ssio, self.val2.max(), print >> ssio, val2.max(),
print >> ssio, numpy.isinf(self.val2).sum(), print >> ssio, numpy.isinf(val2).sum(),
print >> ssio, numpy.isnan(self.val2).sum(), print >> ssio, numpy.isnan(val2).sum(),
# only if all succeeds to we add anything to sio # only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue() print >> sio, ssio.getvalue()
except Exception: except Exception:
pass pass
print >> sio, " val1 :", val1
print >> sio, " val2 :", val2
try: try:
ov = numpy.asarray(self.val1) ov = numpy.asarray(val1)
nv = numpy.asarray(self.val2) nv = numpy.asarray(val2)
ssio = StringIO() ssio = StringIO()
absdiff = numpy.absolute(nv - ov) absdiff = numpy.absolute(nv - ov)
print >> ssio, " Max Abs Diff: ", numpy.max(absdiff) print >> ssio, " Max Abs Diff: ", numpy.max(absdiff)
...@@ -308,21 +301,33 @@ class WrongValue(Exception): ...@@ -308,21 +301,33 @@ class WrongValue(Exception):
pass pass
#Use the same formula as in _allclose to find the tolerance used #Use the same formula as in _allclose to find the tolerance used
narrow = 'float32', 'complex64' narrow = 'float32', 'complex64'
if ((str(self.val1.dtype) in narrow) or if ((str(val1.dtype) in narrow) or
(str(self.val2.dtype) in narrow)): (str(val2.dtype) in narrow)):
atol_ = T.basic.float32_atol atol_ = T.basic.float32_atol
rtol_ = T.basic.float32_rtol rtol_ = T.basic.float32_rtol
else: else:
atol_ = T.basic.float64_atol atol_ = T.basic.float64_atol
rtol_ = T.basic.float64_rtol rtol_ = T.basic.float64_rtol
if self.rtol is not None: if rtol is not None:
rtol_ = self.rtol rtol_ = rtol
if self.atol is not None: if atol is not None:
atol_ = self.atol atol_ = atol
print >> sio, " rtol, atol:", rtol_, atol_ print >> sio, " rtol, atol:", rtol_, atol_
return sio.getvalue() return sio.getvalue()
class WrongValue(Exception):
def __init__(self, expected_val, val, rtol, atol):
self.val1 = expected_val
self.val2 = val
self.rtol = rtol
self.atol = atol
def __str__(self):
s = "WrongValue\n"
return s + str_diagnostic(self.val1, self.val2, self.rtol, self.atol)
def assert_allclose(val1, val2, rtol=None, atol=None): def assert_allclose(val1, val2, rtol=None, atol=None):
if not T.basic._allclose(val1, val2, rtol, atol): if not T.basic._allclose(val1, val2, rtol, atol):
raise WrongValue(val1, val2, rtol, atol) raise WrongValue(val1, val2, rtol, atol)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论