提交 484ee1e0 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

use safe_to_cpu to transfer output results in op_lifter, since some ops return…

use safe_to_cpu to transfer output results in op_lifter, since some ops return some results already on the cpu.
上级 2a95f6a5
......@@ -55,6 +55,20 @@ def register_opt(*tags, **kwargs):
register_opt()(theano.tensor.opt.local_track_shape_i)
def safe_to_gpu(x):
if isinstance(x.type, tensor.TensorType):
return gpu_from_host(x)
else:
return x
def safe_to_cpu(x):
if isinstance(x.type, GpuArrayType):
return host_from_gpu(x)
else:
return x
def op_lifter(OP):
"""
OP(..., host_from_gpu(), ...) -> host_from_gpu(GpuOP(...))
......@@ -74,10 +88,10 @@ def op_lifter(OP):
# This is needed as sometimes new_op inherit from OP.
if new_op and new_op != node.op:
if isinstance(new_op, theano.Op):
return [host_from_gpu(o) for o in
return [safe_to_cpu(o) for o in
new_op(*node.inputs, return_list=True)]
elif isinstance(new_op, (tuple, list)):
return [host_from_gpu(o) for o in new_op]
return [safe_to_cpu(o) for o in new_op]
else: # suppose it is a variable on the GPU
return [host_from_gpu(new_op)]
return False
......@@ -463,20 +477,6 @@ def tensor_to_gpu(x):
return x
def safe_to_gpu(x):
if isinstance(x.type, tensor.TensorType):
return gpu_from_host(x)
else:
return x
def safe_to_cpu(x):
if isinstance(x.type, GpuArrayType):
return host_from_gpu(x)
else:
return x
def gpu_safe_new(x, tag=''):
"""
Internal function that constructs a new variable from x with the same
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论