提交 d0d41c3b authored 作者: Frederic's avatar Frederic

tensor.basic._allclose now accept an rtol and atol parameter.

上级 0edaad7b
...@@ -390,20 +390,24 @@ else: ...@@ -390,20 +390,24 @@ else:
#more strict. Atleast float32 precision. #more strict. Atleast float32 precision.
float64_rtol = 1.0000000000000001e-06 float64_rtol = 1.0000000000000001e-06
def _allclose(a, b): def _allclose(a, b, rtol=None, atol=None):
narrow = 'float32', 'complex64' narrow = 'float32', 'complex64'
if (str(a.dtype) in narrow) or (str(b.dtype) in narrow): if (str(a.dtype) in narrow) or (str(b.dtype) in narrow):
atol = float32_atol atol_ = float32_atol
rtol = float32_rtol rtol_ = float32_rtol
else: else:
atol = float64_atol atol_ = float64_atol
rtol = float64_rtol rtol_ = float64_rtol
if rtol_ is not None:
rtol_ = rtol
if atol_ is not None:
atol_ = atol
# Work around bug in Numpy, see http://projects.scipy.org/numpy/ticket/1684 # Work around bug in Numpy, see http://projects.scipy.org/numpy/ticket/1684
if str(b.dtype) in int_dtypes and (numpy.absolute(b) < 0).any(): if str(b.dtype) in int_dtypes and (numpy.absolute(b) < 0).any():
b = theano._asarray(b, dtype='float64') b = theano._asarray(b, dtype='float64')
return numpy.allclose(a,b, atol=atol, rtol=rtol) return numpy.allclose(a, b, atol=atol_, rtol=rtol_)
def get_constant_value(v): def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v` """return the constant scalar(0-D) value underlying variable `v`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论