提交 ed61205c authored 作者: Frederic Bastien's avatar Frederic Bastien

added a as_cuda_ndarray function.

上级 a3f7b08d
...@@ -144,7 +144,7 @@ outdated!""") ...@@ -144,7 +144,7 @@ outdated!""")
GpuJoin, fscalar, fvector, fmatrix, frow, fcol, GpuJoin, fscalar, fvector, fmatrix, frow, fcol,
ftensor3, ftensor4, scalar, vector, matrix, row, col, ftensor3, ftensor4, scalar, vector, matrix, row, col,
tensor3, tensor4) tensor3, tensor4)
from basic_ops import host_from_gpu, gpu_from_host from basic_ops import host_from_gpu, gpu_from_host, as_cuda_array
import opt import opt
import cuda_ndarray import cuda_ndarray
......
...@@ -31,6 +31,14 @@ def as_cuda_ndarray_variable(x): ...@@ -31,6 +31,14 @@ def as_cuda_ndarray_variable(x):
tensor_x = tensor.as_tensor_variable(x) tensor_x = tensor.as_tensor_variable(x)
return gpu_from_host(tensor_x) return gpu_from_host(tensor_x)
def as_cuda_array(obj):
if isinstance(obj, numpy.ndarray):
return cuda_ndarray.cuda_ndarray.CudaNdarray(obj)
elif isinstance(obj, cuda_ndarray.cuda_ndarray.CudaNdarray):
return obj
else:
raise TypeError("Don't know how to cast to a CudaNdarray object")
class HostFromGpu(Op): class HostFromGpu(Op):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论