提交 9c0a9989 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/misc/tests/test_pycuda_utils.py

上级 1e87be5d
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import numpy import numpy as np
import theano.sandbox.cuda as cuda import theano.sandbox.cuda as cuda
import theano.misc.pycuda_init import theano.misc.pycuda_init
...@@ -22,30 +22,30 @@ def test_to_gpuarray(): ...@@ -22,30 +22,30 @@ def test_to_gpuarray():
px = to_gpuarray(cx) px = to_gpuarray(cx)
assert isinstance(px, pycuda.gpuarray.GPUArray) assert isinstance(px, pycuda.gpuarray.GPUArray)
cx[0, 0] = numpy.asarray(1, dtype="float32") cx[0, 0] = np.asarray(1, dtype="float32")
# Check that they share the same memory space # Check that they share the same memory space
assert px.gpudata == cx.gpudata assert px.gpudata == cx.gpudata
assert numpy.asarray(cx[0, 0]) == 1 assert np.asarray(cx[0, 0]) == 1
assert numpy.allclose(numpy.asarray(cx), px.get()) assert np.allclose(np.asarray(cx), px.get())
assert px.dtype == cx.dtype assert px.dtype == cx.dtype
assert px.shape == cx.shape assert px.shape == cx.shape
assert all(numpy.asarray(cx._strides) * 4 == px.strides) assert all(np.asarray(cx._strides) * 4 == px.strides)
# Test when the CudaNdarray is strided # Test when the CudaNdarray is strided
cx = cx[::2, ::] cx = cx[::2, ::]
px = to_gpuarray(cx, copyif=True) px = to_gpuarray(cx, copyif=True)
assert isinstance(px, pycuda.gpuarray.GPUArray) assert isinstance(px, pycuda.gpuarray.GPUArray)
cx[0, 0] = numpy.asarray(2, dtype="float32") cx[0, 0] = np.asarray(2, dtype="float32")
# Check that they do not share the same memory space # Check that they do not share the same memory space
assert px.gpudata != cx.gpudata assert px.gpudata != cx.gpudata
assert numpy.asarray(cx[0, 0]) == 2 assert np.asarray(cx[0, 0]) == 2
assert not numpy.allclose(numpy.asarray(cx), px.get()) assert not np.allclose(np.asarray(cx), px.get())
assert px.dtype == cx.dtype assert px.dtype == cx.dtype
assert px.shape == cx.shape assert px.shape == cx.shape
assert not all(numpy.asarray(cx._strides) * 4 == px.strides) assert not all(np.asarray(cx._strides) * 4 == px.strides)
# Test that we return an error # Test that we return an error
try: try:
...@@ -59,11 +59,11 @@ def test_to_cudandarray(): ...@@ -59,11 +59,11 @@ def test_to_cudandarray():
px = pycuda.gpuarray.zeros((3, 4, 5), 'float32') px = pycuda.gpuarray.zeros((3, 4, 5), 'float32')
cx = to_cudandarray(px) cx = to_cudandarray(px)
assert isinstance(cx, cuda.CudaNdarray) assert isinstance(cx, cuda.CudaNdarray)
assert numpy.allclose(px.get(), assert np.allclose(px.get(),
numpy.asarray(cx)) np.asarray(cx))
assert px.dtype == cx.dtype assert px.dtype == cx.dtype
assert px.shape == cx.shape assert px.shape == cx.shape
assert all(numpy.asarray(cx._strides) * 4 == px.strides) assert all(np.asarray(cx._strides) * 4 == px.strides)
try: try:
px = pycuda.gpuarray.zeros((3, 4, 5), 'float64') px = pycuda.gpuarray.zeros((3, 4, 5), 'float64')
...@@ -73,7 +73,7 @@ def test_to_cudandarray(): ...@@ -73,7 +73,7 @@ def test_to_cudandarray():
pass pass
try: try:
to_cudandarray(numpy.zeros(4)) to_cudandarray(np.zeros(4))
assert False assert False
except ValueError: except ValueError:
pass pass
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论