提交 578573fc authored 作者: Frederic Bastien's avatar Frederic Bastien

Move code to keep the special values_eq_approx method more frequently.

上级 e3490009
...@@ -576,9 +576,13 @@ class HostFromGpu(Op): ...@@ -576,9 +576,13 @@ class HostFromGpu(Op):
def make_node(self, x): def make_node(self, x):
if not isinstance(x.type, GpuArrayType): if not isinstance(x.type, GpuArrayType):
raise TypeError(x) raise TypeError(x)
return Apply(self, [x], out_var = tensor.TensorType(dtype=x.dtype,
[tensor.TensorType(dtype=x.dtype, broadcastable=x.broadcastable)()
broadcastable=x.broadcastable)()]) # Keep the special comparison if there is one.
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
out_var.tag.values_eq_approx = values_eq_approx
return Apply(self, [x], [out_var])
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, = inp x, = inp
...@@ -664,9 +668,14 @@ class GpuFromHost(Op): ...@@ -664,9 +668,14 @@ class GpuFromHost(Op):
raise TypeError(x) raise TypeError(x)
if "complex" in x.dtype: if "complex" in x.dtype:
raise TypeError("complex not supported in the new gpuarray back-end.", x) raise TypeError("complex not supported in the new gpuarray back-end.", x)
return Apply(self, [x], [GpuArrayType(broadcastable=x.broadcastable, out_var = GpuArrayType(broadcastable=x.broadcastable,
context_name=self.context_name, context_name=self.context_name,
dtype=x.dtype)()]) dtype=x.dtype)()
# Keep the special comparison if there is one.
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
out_var.tag.values_eq_approx = values_eq_approx
return Apply(self, [x], [out_var])
def get_params(self, node): def get_params(self, node):
return get_context(self.context_name) return get_context(self.context_name)
......
...@@ -183,22 +183,14 @@ gpu_optimizer.register('local_remove_all_assert', ...@@ -183,22 +183,14 @@ gpu_optimizer.register('local_remove_all_assert',
# in order to avoid introducin new CPU Ops, or useless ones. # in order to avoid introducin new CPU Ops, or useless ones.
def safe_to_gpu(x, ctx_name): def safe_to_gpu(x, ctx_name):
if isinstance(x.type, tensor.TensorType): if isinstance(x.type, tensor.TensorType):
ret = GpuFromHost(ctx_name)(x) return GpuFromHost(ctx_name)(x)
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
ret.tag.values_eq_approx = values_eq_approx
return ret
else: else:
return x return x
def safe_to_cpu(x): def safe_to_cpu(x):
if isinstance(x.type, GpuArrayType): if isinstance(x.type, GpuArrayType):
ret = x.transfer('cpu') return x.transfer('cpu')
values_eq_approx = getattr(x.tag, 'values_eq_approx', None)
if values_eq_approx:
ret.tag.values_eq_approx = values_eq_approx
return ret
else: else:
return x return x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论