提交 333b0887 authored 作者: fsavard's avatar fsavard

Changed ZerosWithPattern to Zeros/zeros (had started that in last commit too, but fixed a few bugs)

上级 2a36222d
...@@ -1465,13 +1465,12 @@ class GpuJoin(tensor.Join): ...@@ -1465,13 +1465,12 @@ class GpuJoin(tensor.Join):
final_shape = list(cndas[0].shape) final_shape = list(cndas[0].shape)
final_shape[axis] = width_sum final_shape[axis] = width_sum
# just to be explicit, set -1 for broadcastable # just to be explicit, check that dim=1 for broadcastable
# dimensions # dimensions
for i, val in enumerate(node.outputs[0].type.broadcastable): for i, bcastable in enumerate(node.outputs[0].type.broadcastable):
if val: assert not bcastable or final_shape[i] == 1, "Broadcastable dimension but dim != 1, this is invalid"
final_shape[i] = -1
rval = cuda_ndarray.cuda_ndarray.CudaNdarray.zeros_with_pattern(final_shape) rval = cuda_ndarray.cuda_ndarray.CudaNdarray.zeros(final_shape)
curpos = 0 curpos = 0
......
...@@ -249,28 +249,28 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self) ...@@ -249,28 +249,28 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self)
} }
// declared as a static method // declared as a static method (hence "dummy" is not used)
// Based on _Copy and _dimshuffle // Based on _Copy and _dimshuffle
PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern) PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* shape)
{ {
if(!PySequence_Check(pattern)) if(!PySequence_Check(shape))
{ {
PyErr_SetString(PyExc_TypeError, "pattern argument must be a sequence"); PyErr_SetString(PyExc_TypeError, "shape argument must be a sequence");
return NULL; return NULL;
} }
int patlen = PySequence_Length(pattern); int shplen = PySequence_Length(shape);
if (patlen == 0) if (shplen == 0)
{ {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"CudaNdarray_Zeros: empty pattern"); "CudaNdarray_Zeros: empty shape not allowed");
return NULL; return NULL;
} }
//fprintf(stdout, "Pattern length: %d\n", patlen); //fprintf(stdout, "Pattern length: %d\n", shplen);
int* newdims = (int *)malloc(sizeof(int) * 2 * patlen); int* newdims = (int *)malloc(sizeof(int) * 2 * shplen);
if (!newdims) if (!newdims)
{ {
...@@ -279,16 +279,16 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern) ...@@ -279,16 +279,16 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern)
return NULL; return NULL;
} }
int* newstrides = newdims + patlen; int* newstrides = newdims + shplen;
// strides are in number of floats, not bytes // strides are in number of floats, not bytes
int cur_stride = 1; int cur_stride = 1;
// start from the end to compute strides // start from the end to compute strides
for (int i = patlen-1; i >= 0; --i) for (int i = shplen-1; i >= 0; --i)
{ {
PyObject* pat_el_obj = PySequence_GetItem(pattern, i); PyObject* shp_el_obj = PySequence_GetItem(shape, i);
if(pat_el_obj == NULL) if(shp_el_obj == NULL)
{ {
// shouldn't happen since we checked length before... // shouldn't happen since we checked length before...
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: Index out of bound in sequence"); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: Index out of bound in sequence");
...@@ -296,18 +296,18 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern) ...@@ -296,18 +296,18 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern)
return NULL; return NULL;
} }
int pat_el = PyInt_AsLong(pat_el_obj); int shp_el = PyInt_AsLong(shp_el_obj);
if (pat_el == 0) if (shp_el <= 0)
{ {
PyErr_SetString(PyExc_ValueError, "CudaNdarray_Zeros: pattern must not contain 0 for size of a dimension"); PyErr_SetString(PyExc_ValueError, "CudaNdarray_Zeros: shape must not contain 0 (or negative value) for size of a dimension");
free(newdims); free(newdims);
return NULL; return NULL;
} }
// apparently, from looking at alloc_contiguous, we set // based on alloc_contiguous, we set
// stride=0 if the dim == 1 // stride=0 if the dim == 1
if (pat_el < 0 || pat_el == 1) if (shp_el == 1)
{ {
// broadcast // broadcast
newdims[i] = 1; newdims[i] = 1;
...@@ -315,7 +315,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern) ...@@ -315,7 +315,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern)
} }
else else
{ {
newdims[i] = pat_el; newdims[i] = shp_el;
newstrides[i] = cur_stride; newstrides[i] = cur_stride;
} }
...@@ -333,7 +333,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern) ...@@ -333,7 +333,7 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern)
return NULL; return NULL;
} }
if (CudaNdarray_alloc_contiguous(rval, patlen, newdims)) if (CudaNdarray_alloc_contiguous(rval, shplen, newdims))
{ {
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: allocation failed."); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: allocation failed.");
free(newdims); free(newdims);
...@@ -352,13 +352,6 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern) ...@@ -352,13 +352,6 @@ PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern)
return NULL; return NULL;
} }
// change the strides to account for broadcastability
// (not necessary as alloc_contiguous sets stride=0 for dim=1)
//for (int i = 0; i < patlen; ++i)
//{
// CudaNdarray_set_stride(rval, i, newstrides[i]);
//}
if (cnda_copy_structure_to_device(rval)) if (cnda_copy_structure_to_device(rval))
{ {
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: syncing structure to device failed"); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: syncing structure to device failed");
...@@ -707,7 +700,7 @@ static PyMethodDef CudaNdarray_methods[] = ...@@ -707,7 +700,7 @@ static PyMethodDef CudaNdarray_methods[] =
{"__deepcopy__", {"__deepcopy__",
(PyCFunction)CudaNdarray_DeepCopy, METH_O, (PyCFunction)CudaNdarray_DeepCopy, METH_O,
"Create a copy of this object"}, "Create a copy of this object"},
{"zeros_with_pattern", {"zeros",
(PyCFunction)CudaNdarray_Zeros, METH_STATIC, (PyCFunction)CudaNdarray_Zeros, METH_STATIC,
"Create a new CudaNdarray with specified shape, filled with zeros."}, "Create a new CudaNdarray with specified shape, filled with zeros."},
{"copy", {"copy",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论