Unverified 提交 42275cba authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #6570 from abergeron/fix_values_debug_2

support allow_remove_{inf,nan}
......@@ -524,23 +524,20 @@ def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False,
if str(a.dtype) in theano.tensor.discrete_dtypes:
return GpuArrayType.values_eq(a, b)
else:
if allow_remove_inf or allow_remove_nan:
raise NotImplementedError(
"GpuArrayType.values_eq_approx() don't implemented the"
" allow_remove_inf and allow_remove_nan parameter")
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if rtol is not None:
rtol_ = rtol
if atol is not None:
atol_ = atol
res = elemwise2(a, '', b, a, odtype=np.dtype('bool'),
op_tmpl="res = (fabs(a - b) <"
"(%(atol_)s + %(rtol_)s * fabs(b)))" %
locals())
ret = np.asarray(res).all()
if ret:
return True
# maybe the trouble is that there are NaNs
if not (allow_remove_inf or allow_remove_nan):
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if rtol is not None:
rtol_ = rtol
if atol is not None:
atol_ = atol
res = elemwise2(a, '', b, a, odtype=np.dtype('bool'),
op_tmpl="res = (fabs(a - b) <"
"(%(atol_)s + %(rtol_)s * fabs(b)))" %
locals())
ret = np.asarray(res).all()
if ret:
return True
an = np.asarray(a)
bn = np.asarray(b)
return tensor.TensorType.values_eq_approx(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论