提交 9c0a9989 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/misc/tests/test_pycuda_utils.py

上级 1e87be5d
from __future__ import absolute_import, print_function, division
import numpy
import numpy as np
import theano.sandbox.cuda as cuda
import theano.misc.pycuda_init
......@@ -22,30 +22,30 @@ def test_to_gpuarray():
px = to_gpuarray(cx)
assert isinstance(px, pycuda.gpuarray.GPUArray)
cx[0, 0] = numpy.asarray(1, dtype="float32")
cx[0, 0] = np.asarray(1, dtype="float32")
# Check that they share the same memory space
assert px.gpudata == cx.gpudata
assert numpy.asarray(cx[0, 0]) == 1
assert np.asarray(cx[0, 0]) == 1
assert numpy.allclose(numpy.asarray(cx), px.get())
assert np.allclose(np.asarray(cx), px.get())
assert px.dtype == cx.dtype
assert px.shape == cx.shape
assert all(numpy.asarray(cx._strides) * 4 == px.strides)
assert all(np.asarray(cx._strides) * 4 == px.strides)
# Test when the CudaNdarray is strided
cx = cx[::2, ::]
px = to_gpuarray(cx, copyif=True)
assert isinstance(px, pycuda.gpuarray.GPUArray)
cx[0, 0] = numpy.asarray(2, dtype="float32")
cx[0, 0] = np.asarray(2, dtype="float32")
# Check that they do not share the same memory space
assert px.gpudata != cx.gpudata
assert numpy.asarray(cx[0, 0]) == 2
assert not numpy.allclose(numpy.asarray(cx), px.get())
assert np.asarray(cx[0, 0]) == 2
assert not np.allclose(np.asarray(cx), px.get())
assert px.dtype == cx.dtype
assert px.shape == cx.shape
assert not all(numpy.asarray(cx._strides) * 4 == px.strides)
assert not all(np.asarray(cx._strides) * 4 == px.strides)
# Test that we return an error
try:
......@@ -59,11 +59,11 @@ def test_to_cudandarray():
px = pycuda.gpuarray.zeros((3, 4, 5), 'float32')
cx = to_cudandarray(px)
assert isinstance(cx, cuda.CudaNdarray)
assert numpy.allclose(px.get(),
numpy.asarray(cx))
assert np.allclose(px.get(),
np.asarray(cx))
assert px.dtype == cx.dtype
assert px.shape == cx.shape
assert all(numpy.asarray(cx._strides) * 4 == px.strides)
assert all(np.asarray(cx._strides) * 4 == px.strides)
try:
px = pycuda.gpuarray.zeros((3, 4, 5), 'float64')
......@@ -73,7 +73,7 @@ def test_to_cudandarray():
pass
try:
to_cudandarray(numpy.zeros(4))
to_cudandarray(np.zeros(4))
assert False
except ValueError:
pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论