提交 7eea950a authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #599 from lamblin/cudandarray_0_stride

Set stride to 0 on length-1 dimension in subtensor
...@@ -1542,7 +1542,8 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1542,7 +1542,8 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
return NULL; return NULL;
} }
//initialize dimension 0 of rval //initialize dimension 0 of rval
CudaNdarray_set_stride(rval, 0, step * CudaNdarray_HOST_STRIDES(self)[0]); CudaNdarray_set_stride(rval, 0,
(slen == 1) ? 0 : step * CudaNdarray_HOST_STRIDES(self)[0]);
CudaNdarray_set_dim(rval, 0, slen); CudaNdarray_set_dim(rval, 0, slen);
if (verbose) std::cerr << "rval stride " << CudaNdarray_HOST_STRIDES(rval)[0] << "\n"; if (verbose) std::cerr << "rval stride " << CudaNdarray_HOST_STRIDES(rval)[0] << "\n";
// initialize dimensions > 0 of rval // initialize dimensions > 0 of rval
...@@ -1614,7 +1615,8 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key) ...@@ -1614,7 +1615,8 @@ CudaNdarray_Subscript(PyObject * py_self, PyObject * key)
return NULL; return NULL;
} }
rval->devdata += start * CudaNdarray_HOST_STRIDES(self)[d]; rval->devdata += start * CudaNdarray_HOST_STRIDES(self)[d];
CudaNdarray_set_stride(rval, rval_d, step * CudaNdarray_HOST_STRIDES(self)[d]); CudaNdarray_set_stride(rval, rval_d,
(slen == 1) ? 0 : step * CudaNdarray_HOST_STRIDES(self)[d]);
CudaNdarray_set_dim(rval, rval_d, slen); CudaNdarray_set_dim(rval, rval_d, slen);
if (0) if (0)
{ {
......
...@@ -463,6 +463,15 @@ def test_stride_manipulation(): ...@@ -463,6 +463,15 @@ def test_stride_manipulation():
assert numpy.all(c == [[5, 4, 3], [2, 1, 0]]) assert numpy.all(c == [[5, 4, 3], [2, 1, 0]])
def test_subtensor_broadcastable():
a = numpy.zeros((2, 7), dtype='float32')
cuda_a = cuda_ndarray.CudaNdarray(a)
# Will have shape (1, 7), so the stride in the first dim should be 0
sub_a = cuda_a[1:]
assert sub_a.shape == (1, 7)
assert sub_a._strides[0] == 0
def test_copy_subtensor0(): def test_copy_subtensor0():
sizeof_float=4 sizeof_float=4
a = theano._asarray(numpy.random.rand(30,20,5,5), dtype='float32') a = theano._asarray(numpy.random.rand(30,20,5,5), dtype='float32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论