提交 964dcf94 authored 作者: Frederic's avatar Frederic

rename var for clarity

上级 121bea82
......@@ -242,7 +242,7 @@ class InferShapeTester(unittest.TestCase):
assert numpy.all(out.shape == shape)
def str_diagnostic(val1, val2, rtol, atol):
def str_diagnostic(expected, value, rtol, atol):
"""Return a pretty multiline string representating the cause
of the exception"""
sio = StringIO()
......@@ -251,13 +251,13 @@ def str_diagnostic(val1, val2, rtol, atol):
ssio = StringIO()
print >> ssio, " : shape, dtype, strides, min, max, n_inf, n_nan:"
print >> ssio, " Expected :",
print >> ssio, val1.shape,
print >> ssio, val1.dtype,
print >> ssio, val1.strides,
print >> ssio, val1.min(),
print >> ssio, val1.max(),
print >> ssio, numpy.isinf(val1).sum(),
print >> ssio, numpy.isnan(val1).sum(),
print >> ssio, expected.shape,
print >> ssio, expected.dtype,
print >> ssio, expected.strides,
print >> ssio, expected.min(),
print >> ssio, expected.max(),
print >> ssio, numpy.isinf(expected).sum(),
print >> ssio, numpy.isnan(expected).sum(),
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except Exception:
......@@ -265,24 +265,24 @@ def str_diagnostic(val1, val2, rtol, atol):
try:
ssio = StringIO()
print >> ssio, " Value :",
print >> ssio, val2.shape,
print >> ssio, val2.dtype,
print >> ssio, val2.strides,
print >> ssio, val2.min(),
print >> ssio, val2.max(),
print >> ssio, numpy.isinf(val2).sum(),
print >> ssio, numpy.isnan(val2).sum(),
print >> ssio, value.shape,
print >> ssio, value.dtype,
print >> ssio, value.strides,
print >> ssio, value.min(),
print >> ssio, value.max(),
print >> ssio, numpy.isinf(value).sum(),
print >> ssio, numpy.isnan(value).sum(),
# only if all succeeds to we add anything to sio
print >> sio, ssio.getvalue()
except Exception:
pass
print >> sio, " val1 :", val1
print >> sio, " val2 :", val2
print >> sio, " expected :", expected
print >> sio, " value :", value
try:
ov = numpy.asarray(val1)
nv = numpy.asarray(val2)
ov = numpy.asarray(expected)
nv = numpy.asarray(value)
ssio = StringIO()
absdiff = numpy.absolute(nv - ov)
print >> ssio, " Max Abs Diff: ", numpy.max(absdiff)
......@@ -301,8 +301,8 @@ def str_diagnostic(val1, val2, rtol, atol):
pass
#Use the same formula as in _allclose to find the tolerance used
narrow = 'float32', 'complex64'
if ((str(val1.dtype) in narrow) or
(str(val2.dtype) in narrow)):
if ((str(expected.dtype) in narrow) or
(str(value.dtype) in narrow)):
atol_ = T.basic.float32_atol
rtol_ = T.basic.float32_rtol
else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论