提交 d0d41c3b authored 作者: Frederic's avatar Frederic

tensor.basic._allclose now accept an rtol and atol parameter.

上级 0edaad7b
......@@ -390,20 +390,24 @@ else:
#more strict. Atleast float32 precision.
float64_rtol = 1.0000000000000001e-06
def _allclose(a, b):
def _allclose(a, b, rtol=None, atol=None):
narrow = 'float32', 'complex64'
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol = float32_atol
rtol = float32_rtol
atol_ = float32_atol
rtol_ = float32_rtol
else:
atol = float64_atol
rtol = float64_rtol
atol_ = float64_atol
rtol_ = float64_rtol
if rtol_ is not None:
rtol_ = rtol
if atol_ is not None:
atol_ = atol
# Work around bug in Numpy, see http://projects.scipy.org/numpy/ticket/1684
if str(b.dtype) in int_dtypes and (numpy.absolute(b) < 0).any():
b = theano._asarray(b, dtype='float64')
return numpy.allclose(a,b, atol=atol, rtol=rtol)
return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论