提交 eef8dadf authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Patch values_eq_approx to follow changes.

上级 50fdd757
......@@ -301,20 +301,14 @@ class GpuArrayType(Type):
raise NotImplementedError(
"GpuArrayType.values_eq_approx() don't implemented the"
" allow_remove_inf and allow_remove_nan parameter")
if a.dtype == 'float16' or b.dtype == 'float16':
an = numpy.asarray(a)
bn = numpy.asarray(b)
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if rtol is not None:
rtol_ = rtol
if atol is not None:
atol_ = atol
res = elemwise2(a, '', b, a, odtype=numpy.dtype('bool'),
op_tmpl="res[i] = (fabs(%%(a)s - %%(b)s) <"
"(%(atol_)s + %(rtol_)s * fabs(%%(b)s)))" %
op_tmpl="res = (fabs(a - b) <"
"(%(atol_)s + %(rtol_)s * fabs(b)))" %
locals())
ret = numpy.asarray(res).all()
if ret:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论