提交 eef8dadf authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Patch values_eq_approx to follow changes.

上级 50fdd757
...@@ -301,20 +301,14 @@ class GpuArrayType(Type): ...@@ -301,20 +301,14 @@ class GpuArrayType(Type):
raise NotImplementedError( raise NotImplementedError(
"GpuArrayType.values_eq_approx() don't implemented the" "GpuArrayType.values_eq_approx() don't implemented the"
" allow_remove_inf and allow_remove_nan parameter") " allow_remove_inf and allow_remove_nan parameter")
if a.dtype == 'float16' or b.dtype == 'float16':
an = numpy.asarray(a)
bn = numpy.asarray(b)
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b) atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if rtol is not None: if rtol is not None:
rtol_ = rtol rtol_ = rtol
if atol is not None: if atol is not None:
atol_ = atol atol_ = atol
res = elemwise2(a, '', b, a, odtype=numpy.dtype('bool'), res = elemwise2(a, '', b, a, odtype=numpy.dtype('bool'),
op_tmpl="res[i] = (fabs(%%(a)s - %%(b)s) <" op_tmpl="res = (fabs(a - b) <"
"(%(atol_)s + %(rtol_)s * fabs(%%(b)s)))" % "(%(atol_)s + %(rtol_)s * fabs(b)))" %
locals()) locals())
ret = numpy.asarray(res).all() ret = numpy.asarray(res).all()
if ret: if ret:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论