提交 55d20a08 authored 作者: delallea's avatar delallea

Merge pull request #65 from nouiz/gnumpy

Added the test in reverse order.
import numpy
import theano
from theano.misc.gnumpy_utils import gnumpy_available
if not gnumpy_available:
from nose.plugins.skip import SkipTest
raise SkipTest("gnumpy not installed. Skip test of theano op with pycuda code.")
raise SkipTest("gnumpy not installed. Skip test related to it.")
from theano.misc.gnumpy_utils import garray_to_cudandarray, cudandarray_to_garray
from theano.misc.gnumpy_utils import (garray_to_cudandarray,
cudandarray_to_garray)
import gnumpy
def test(shape=(3,4,5)):
def test(shape=(3, 4, 5)):
"""
Make sure that the gnumpy conversion is exact from garray to
CudaNdarray back to garray.
"""
Make sure that the gnumpy conversion is exact.
"""
gpu = theano.sandbox.cuda.basic_ops.gpu_from_host
U = gpu(theano.tensor.ftensor3('U'))
ii = theano.function([U], gpu(U+1))
ii = theano.function([U], gpu(U + 1))
A = gnumpy.rand(*shape)
A_cnd = garray_to_cudandarray(A)
assert A_cnd.shape == A.shape
# dtype always float32
# garray don't have strides
B_cnd = ii(A_cnd)
B = cudandarray_to_garray(B_cnd)
assert A_cnd.shape == A.shape
from numpy import array
B2 = array(B_cnd)
u = (A+1).asarray()
u = (A + 1).asarray()
v = B.asarray()
w = B2
assert abs(u-v).max() == 0
assert abs(u-w).max() == 0
w = array(B_cnd)
assert (u == v).all()
assert (u == w).all()
def test2(shape=(3, 4, 5)):
"""
Make sure that the gnumpy conversion is exact from CudaNdarray to
garray back to CudaNdarray.
"""
gpu = theano.sandbox.cuda.basic_ops.gpu_from_host
U = gpu(theano.tensor.ftensor3('U'))
ii = theano.function([U], gpu(U + 1))
A = numpy.random.rand(*shape).astype('float32')
A_cnd = theano.sandbox.cuda.CudaNdarray(A)
A_gar = cudandarray_to_garray(A_cnd)
assert A_cnd.shape == A_gar.shape
# dtype always float32
# garray don't have strides
B = garray_to_cudandarray(A_gar)
assert A_cnd.shape == B.shape
# dtype always float32
assert A_cnd._strides == B._strides
assert A_cnd.gpudata == B.gpudata
v = numpy.asarray(B)
assert (v == A).all()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论