提交 bac35334 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

support allow_remove_{inf,nan}

上级 7b648eb9
...@@ -524,10 +524,7 @@ def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False, ...@@ -524,10 +524,7 @@ def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False,
if str(a.dtype) in theano.tensor.discrete_dtypes: if str(a.dtype) in theano.tensor.discrete_dtypes:
return GpuArrayType.values_eq(a, b) return GpuArrayType.values_eq(a, b)
else: else:
if allow_remove_inf or allow_remove_nan: if not (allow_remove_inf or allow_remove_nan):
raise NotImplementedError(
"GpuArrayType.values_eq_approx() don't implemented the"
" allow_remove_inf and allow_remove_nan parameter")
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b) atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if rtol is not None: if rtol is not None:
rtol_ = rtol rtol_ = rtol
...@@ -540,7 +537,7 @@ def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False, ...@@ -540,7 +537,7 @@ def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False,
ret = np.asarray(res).all() ret = np.asarray(res).all()
if ret: if ret:
return True return True
# maybe the trouble is that there are NaNs
an = np.asarray(a) an = np.asarray(a)
bn = np.asarray(b) bn = np.asarray(b)
return tensor.TensorType.values_eq_approx( return tensor.TensorType.values_eq_approx(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论