提交 f8ce40ae authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Relocate str_diagnostic function from tests to theano.compile.debugmode

上级 f683d1ac
...@@ -12,9 +12,9 @@ from functools import wraps ...@@ -12,9 +12,9 @@ from functools import wraps
from copy import copy, deepcopy from copy import copy, deepcopy
from six import integer_types from six import integer_types
from six.moves import StringIO
from theano import config from theano import config
from theano.compile.debugmode import str_diagnostic
_logger = logging.getLogger("tests.unittest_tools") _logger = logging.getLogger("tests.unittest_tools")
...@@ -282,71 +282,6 @@ class InferShapeTester: ...@@ -282,71 +282,6 @@ class InferShapeTester:
assert np.all(out.shape == shape), (out.shape, shape) assert np.all(out.shape == shape), (out.shape, shape)
def str_diagnostic(expected, value, rtol, atol):
"""Return a pretty multiline string representating the cause
of the exception"""
sio = StringIO()
try:
ssio = StringIO()
print(" : shape, dtype, strides, min, max, n_inf, n_nan:", file=ssio)
print(" Expected :", end=" ", file=ssio)
print(expected.shape, end=" ", file=ssio)
print(expected.dtype, end=" ", file=ssio)
print(expected.strides, end=" ", file=ssio)
print(expected.min(), end=" ", file=ssio)
print(expected.max(), end=" ", file=ssio)
print(np.isinf(expected).sum(), end=" ", file=ssio)
print(np.isnan(expected).sum(), end=" ", file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
try:
ssio = StringIO()
print(" Value :", end=" ", file=ssio)
print(value.shape, end=" ", file=ssio)
print(value.dtype, end=" ", file=ssio)
print(value.strides, end=" ", file=ssio)
print(value.min(), end=" ", file=ssio)
print(value.max(), end=" ", file=ssio)
print(np.isinf(value).sum(), end=" ", file=ssio)
print(np.isnan(value).sum(), end=" ", file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
print(" expected :", expected, file=sio)
print(" value :", value, file=sio)
try:
ov = np.asarray(expected)
nv = np.asarray(value)
ssio = StringIO()
absdiff = np.absolute(nv - ov)
print(" Max Abs Diff: ", np.max(absdiff), file=ssio)
print(" Mean Abs Diff: ", np.mean(absdiff), file=ssio)
print(" Median Abs Diff: ", np.median(absdiff), file=ssio)
print(" Std Abs Diff: ", np.std(absdiff), file=ssio)
reldiff = np.absolute(nv - ov) / np.absolute(ov)
print(" Max Rel Diff: ", np.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", np.mean(reldiff), file=ssio)
print(" Median Rel Diff: ", np.median(reldiff), file=ssio)
print(" Std Rel Diff: ", np.std(reldiff), file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
atol_, rtol_ = T.basic._get_atol_rtol(expected, value)
if rtol is not None:
rtol_ = rtol
if atol is not None:
atol_ = atol
print(" rtol, atol:", rtol_, atol_, file=sio)
return sio.getvalue()
class WrongValue(Exception): class WrongValue(Exception):
def __init__(self, expected_val, val, rtol, atol): def __init__(self, expected_val, val, rtol, atol):
Exception.__init__(self) # to be compatible with python2.4 Exception.__init__(self) # to be compatible with python2.4
......
...@@ -157,10 +157,7 @@ class BadThunkOutput(DebugModeError): ...@@ -157,10 +157,7 @@ class BadThunkOutput(DebugModeError):
print(" thunk1 :", self.thunk1, file=sio) print(" thunk1 :", self.thunk1, file=sio)
print(" thunk2 :", self.thunk2, file=sio) print(" thunk2 :", self.thunk2, file=sio)
# Don't import it at the top of the file to prevent circular import. print(str_diagnostic(self.val1, self.val2, None, None), file=sio)
import tests.unittest_tools as utt
print(utt.str_diagnostic(self.val1, self.val2, None, None), file=sio)
ret = sio.getvalue() ret = sio.getvalue()
return ret return ret
...@@ -382,6 +379,71 @@ class InvalidValueError(DebugModeError): ...@@ -382,6 +379,71 @@ class InvalidValueError(DebugModeError):
######################## ########################
def str_diagnostic(expected, value, rtol, atol):
"""Return a pretty multiline string representating the cause
of the exception"""
sio = StringIO()
try:
ssio = StringIO()
print(" : shape, dtype, strides, min, max, n_inf, n_nan:", file=ssio)
print(" Expected :", end=" ", file=ssio)
print(expected.shape, end=" ", file=ssio)
print(expected.dtype, end=" ", file=ssio)
print(expected.strides, end=" ", file=ssio)
print(expected.min(), end=" ", file=ssio)
print(expected.max(), end=" ", file=ssio)
print(np.isinf(expected).sum(), end=" ", file=ssio)
print(np.isnan(expected).sum(), end=" ", file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
try:
ssio = StringIO()
print(" Value :", end=" ", file=ssio)
print(value.shape, end=" ", file=ssio)
print(value.dtype, end=" ", file=ssio)
print(value.strides, end=" ", file=ssio)
print(value.min(), end=" ", file=ssio)
print(value.max(), end=" ", file=ssio)
print(np.isinf(value).sum(), end=" ", file=ssio)
print(np.isnan(value).sum(), end=" ", file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
print(" expected :", expected, file=sio)
print(" value :", value, file=sio)
try:
ov = np.asarray(expected)
nv = np.asarray(value)
ssio = StringIO()
absdiff = np.absolute(nv - ov)
print(" Max Abs Diff: ", np.max(absdiff), file=ssio)
print(" Mean Abs Diff: ", np.mean(absdiff), file=ssio)
print(" Median Abs Diff: ", np.median(absdiff), file=ssio)
print(" Std Abs Diff: ", np.std(absdiff), file=ssio)
reldiff = np.absolute(nv - ov) / np.absolute(ov)
print(" Max Rel Diff: ", np.max(reldiff), file=ssio)
print(" Mean Rel Diff: ", np.mean(reldiff), file=ssio)
print(" Median Rel Diff: ", np.median(reldiff), file=ssio)
print(" Std Rel Diff: ", np.std(reldiff), file=ssio)
# only if all succeeds to we add anything to sio
print(ssio.getvalue(), file=sio)
except Exception:
pass
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(expected, value)
if rtol is not None:
rtol_ = rtol
if atol is not None:
atol_ = atol
print(" rtol, atol:", rtol_, atol_, file=sio)
return sio.getvalue()
def char_from_number(number): def char_from_number(number):
""" """
Converts number to string by rendering it in base 26 using Converts number to string by rendering it in base 26 using
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论