Unverified commit 7b648eb9 · authored by Frédéric Bastien · committed by GitHub

Merge pull request #6569 from abergeron/fix_values_debug

Add a map of values_eq_approx functions from CPU to GPU.
......@@ -30,7 +30,7 @@ except ImportError:
pass
from .type import (GpuArrayType, GpuArrayConstant, gpu_context_type,
get_context, ContextNotDefined)
get_context, ContextNotDefined, EQ_MAP)
from .fp16_help import write_w
......@@ -581,7 +581,8 @@ class HostFromGpu(Op):
# Keep the special comparison if there is one.
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
out_var.tag.values_eq_approx = values_eq_approx
out_var.tag.values_eq_approx = EQ_MAP.get(values_eq_approx,
values_eq_approx)
return Apply(self, [x], [out_var])
def perform(self, node, inp, out):
......@@ -674,7 +675,8 @@ class GpuFromHost(Op):
# Keep the special comparison if there is one.
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
out_var.tag.values_eq_approx = values_eq_approx
out_var.tag.values_eq_approx = EQ_MAP.get(values_eq_approx,
values_eq_approx)
return Apply(self, [x], [out_var])
def get_params(self, node):
......
......@@ -360,33 +360,8 @@ class GpuArrayType(Type):
def values_eq_approx(a, b,
allow_remove_inf=False, allow_remove_nan=False,
rtol=None, atol=None):
if a.shape != b.shape or a.dtype != b.dtype:
return False
if str(a.dtype) in theano.tensor.discrete_dtypes:
return GpuArrayType.values_eq(a, b)
else:
if allow_remove_inf or allow_remove_nan:
raise NotImplementedError(
"GpuArrayType.values_eq_approx() don't implemented the"
" allow_remove_inf and allow_remove_nan parameter")
atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
if rtol is not None:
rtol_ = rtol
if atol is not None:
atol_ = atol
res = elemwise2(a, '', b, a, odtype=np.dtype('bool'),
op_tmpl="res = (fabs(a - b) <"
"(%(atol_)s + %(rtol_)s * fabs(b)))" %
locals())
ret = np.asarray(res).all()
if ret:
return True
# maybe the trouble is that there are NaNs
an = np.asarray(a)
bn = np.asarray(b)
return tensor.TensorType.values_eq_approx(
an, bn, allow_remove_inf=allow_remove_inf,
allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan,
rtol, atol)
@staticmethod
def may_share_memory(a, b):
......@@ -542,6 +517,65 @@ class GpuArrayType(Type):
return (2, ver[0])
def values_eq_approx(a, b, allow_remove_inf=False, allow_remove_nan=False,
                     rtol=None, atol=None):
    """
    Approximate equality test between two arrays.

    Parameters
    ----------
    a, b
        The arrays to compare. They must expose ``shape`` and ``dtype``.
    allow_remove_inf, allow_remove_nan : bool
        Forwarded to ``tensor.TensorType.values_eq_approx``. The
        elementwise kernel used as a fast path does not take these flags
        into account, so when either one is set the fast path is skipped
        and the comparison is done on ``np.asarray`` copies instead.
    rtol, atol
        Optional overrides for the tolerances returned by
        ``theano.tensor.basic._get_atol_rtol``.

    Returns
    -------
    ``False`` when shapes or dtypes differ; otherwise the result of the
    exact comparison (discrete dtypes) or of the tolerance comparison.

    """
    if a.shape != b.shape or a.dtype != b.dtype:
        return False
    if str(a.dtype) in theano.tensor.discrete_dtypes:
        # Discrete values are compared exactly.
        return GpuArrayType.values_eq(a, b)

    if not (allow_remove_inf or allow_remove_nan):
        # Fast path: one elementwise kernel computing
        # |a - b| < atol + rtol * |b|.
        atol_, rtol_ = theano.tensor.basic._get_atol_rtol(a, b)
        if rtol is not None:
            rtol_ = rtol
        if atol is not None:
            atol_ = atol
        res = elemwise2(a, '', b, a, odtype=np.dtype('bool'),
                        op_tmpl="res = (fabs(a - b) <"
                        "(%(atol_)s + %(rtol_)s * fabs(b)))" %
                        dict(atol_=atol_, rtol_=rtol_))
        if np.asarray(res).all():
            return True

    # Either the fast path reported a mismatch (maybe the trouble is that
    # there are NaNs) or the caller asked to ignore inf/nan: defer to the
    # ndarray implementation, which receives both flags.
    an = np.asarray(a)
    bn = np.asarray(b)
    return tensor.TensorType.values_eq_approx(
        an, bn, allow_remove_inf=allow_remove_inf,
        allow_remove_nan=allow_remove_nan, rtol=rtol, atol=atol)
def values_eq_approx_remove_inf(a, b):
    """
    Call `values_eq_approx` on `a` and `b` with `allow_remove_inf` set.

    """
    return values_eq_approx(a, b, allow_remove_inf=True)
def values_eq_approx_remove_nan(a, b):
    """
    Call `values_eq_approx` on `a` and `b` with `allow_remove_nan` set
    and `allow_remove_inf` unset.

    """
    return values_eq_approx(a, b,
                            allow_remove_inf=False,
                            allow_remove_nan=True)
def values_eq_approx_remove_inf_nan(a, b):
    """
    Call `values_eq_approx` on `a` and `b` with both `allow_remove_inf`
    and `allow_remove_nan` set.

    """
    return values_eq_approx(a, b,
                            allow_remove_inf=True,
                            allow_remove_nan=True)
# Lookup table pairing the ndarray-specific comparison functions from
# theano.tensor.type with the same-named functions of this module.
EQ_MAP = dict([
    (theano.tensor.type.values_eq_approx,
     values_eq_approx),
    (theano.tensor.type.values_eq_approx_remove_inf,
     values_eq_approx_remove_inf),
    (theano.tensor.type.values_eq_approx_remove_nan,
     values_eq_approx_remove_nan),
    (theano.tensor.type.values_eq_approx_remove_inf_nan,
     values_eq_approx_remove_inf_nan),
])

# Register every pair in the opposite direction too.  The reversed
# entries are fully built before update() runs, so the dict is never
# modified while it is being iterated.
EQ_MAP.update({mapped: original for original, mapped in EQ_MAP.items()})
class _operators(_tensor_py_operators):
def _as_TensorVariable(self):
from .basic_ops import host_from_gpu
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论