提交 40235973 authored 作者: Frederic Bastien's avatar Frederic Bastien

'Added more test for cuda_ndarray.zeros() and make it raise an appropriate error…

'Added more test for cuda_ndarray.zeros() and make it raise an appropriate error when it receive no input.'
上级 bf01580a
......@@ -387,6 +387,11 @@ PyObject* CudaNdarray_ZEROS(int n, int * dims)
// Based on _Copy and _dimshuffle
PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
{
if(!shape)
{
PyErr_SetString(PyExc_TypeError, "CudaNdarray_Zeros: function takes at least 1 argument (0 given)");
return NULL;
}
if(!PySequence_Check(shape))
{
PyErr_SetString(PyExc_TypeError, "shape argument must be a sequence");
......
......@@ -813,13 +813,26 @@ def test_setitem_rightvalue_ndarray_fails():
#print e
assert True
def test_zeros_basic_3d_tensor():
_a = cuda_ndarray.CudaNdarray.zeros((3,4,5))
assert numpy.allclose(numpy.asarray(_a), numpy.zeros((3,4,5)))
def test_zeros_basic_vector():
_a = cuda_ndarray.CudaNdarray.zeros((300,))
assert numpy.allclose(numpy.asarray(_a), numpy.zeros((300,)))
def test_zeros_basic():
# 3d
for shp in [(3,4,5), (300,), ()]:
_a = cuda_ndarray.CudaNdarray.zeros(shp)
_n = numpy.zeros(shp, dtype="float32")
assert numpy.allclose(numpy.asarray(_a), _n)
assert _a.shape == _n.shape
assert all(_a._strides == numpy.asarray(_n.strides)/4)
try:
_n = numpy.zeros()
except TypeError:
pass
else:
raise Exception("An error was expected!")
try:
_a = cuda_ndarray.CudaNdarray.zeros()
except TypeError:
pass
else:
raise Exception("An error was expected!")
def test_base():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论