提交 578573fc authored 作者: Frederic Bastien's avatar Frederic Bastien

Move code to keep the special values_eq_approx method more frequently.

上级 e3490009
......@@ -576,9 +576,13 @@ class HostFromGpu(Op):
def make_node(self, x):
if not isinstance(x.type, GpuArrayType):
raise TypeError(x)
return Apply(self, [x],
[tensor.TensorType(dtype=x.dtype,
broadcastable=x.broadcastable)()])
out_var = tensor.TensorType(dtype=x.dtype,
broadcastable=x.broadcastable)()
# Keep the special comparison if there is one.
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
out_var.tag.values_eq_approx = values_eq_approx
return Apply(self, [x], [out_var])
def perform(self, node, inp, out):
x, = inp
......@@ -664,9 +668,14 @@ class GpuFromHost(Op):
raise TypeError(x)
if "complex" in x.dtype:
raise TypeError("complex not supported in the new gpuarray back-end.", x)
return Apply(self, [x], [GpuArrayType(broadcastable=x.broadcastable,
context_name=self.context_name,
dtype=x.dtype)()])
out_var = GpuArrayType(broadcastable=x.broadcastable,
context_name=self.context_name,
dtype=x.dtype)()
# Keep the special comparison if there is one.
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
out_var.tag.values_eq_approx = values_eq_approx
return Apply(self, [x], [out_var])
def get_params(self, node):
return get_context(self.context_name)
......
......@@ -183,22 +183,14 @@ gpu_optimizer.register('local_remove_all_assert',
# in order to avoid introducin new CPU Ops, or useless ones.
def safe_to_gpu(x, ctx_name):
if isinstance(x.type, tensor.TensorType):
ret = GpuFromHost(ctx_name)(x)
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
ret.tag.values_eq_approx = values_eq_approx
return ret
return GpuFromHost(ctx_name)(x)
else:
return x
def safe_to_cpu(x):
if isinstance(x.type, GpuArrayType):
ret = x.transfer('cpu')
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
ret.tag.values_eq_approx = values_eq_approx
return ret
return x.transfer('cpu')
else:
return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论