提交 2a36222d authored 作者: fsavard's avatar fsavard

Added checks around call to CopyFromCudaNdarray in setitem, and added Equal…

Added checks around call to CopyFromCudaNdarray in setitem, and added Equal method for comparing CudaNdarrays in C code
上级 29b9b721
...@@ -251,7 +251,7 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self) ...@@ -251,7 +251,7 @@ PyObject * CudaNdarray_CreateArrayObj(CudaNdarray * self)
// declared as a static method // declared as a static method
// Based on _Copy and _dimshuffle // Based on _Copy and _dimshuffle
PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern) PyObject* CudaNdarray_Zeros(PyObject* dummy, PyObject* pattern)
{ {
if(!PySequence_Check(pattern)) if(!PySequence_Check(pattern))
{ {
...@@ -264,7 +264,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern) ...@@ -264,7 +264,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
if (patlen == 0) if (patlen == 0)
{ {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"CudaNdarray_NewWithPattern: empty pattern"); "CudaNdarray_Zeros: empty pattern");
return NULL; return NULL;
} }
...@@ -275,7 +275,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern) ...@@ -275,7 +275,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
if (!newdims) if (!newdims)
{ {
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"CudaNdarray_NewWithPattern: Failed to allocate temporary space"); "CudaNdarray_Zeros: Failed to allocate temporary space");
return NULL; return NULL;
} }
...@@ -291,7 +291,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern) ...@@ -291,7 +291,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
if(pat_el_obj == NULL) if(pat_el_obj == NULL)
{ {
// shouldn't happen since we checked length before... // shouldn't happen since we checked length before...
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_NewWithPattern: Index out of bound in sequence"); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: Index out of bound in sequence");
free(newdims); free(newdims);
return NULL; return NULL;
} }
...@@ -300,7 +300,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern) ...@@ -300,7 +300,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
if (pat_el == 0) if (pat_el == 0)
{ {
PyErr_SetString(PyExc_ValueError, "CudaNdarray_NewWithPattern: pattern must not contain 0 for size of a dimension"); PyErr_SetString(PyExc_ValueError, "CudaNdarray_Zeros: pattern must not contain 0 for size of a dimension");
free(newdims); free(newdims);
return NULL; return NULL;
} }
...@@ -328,14 +328,14 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern) ...@@ -328,14 +328,14 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null(); CudaNdarray* rval = (CudaNdarray*)CudaNdarray_new_null();
if (!rval) if (!rval)
{ {
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_NewWithPattern: call to new_null failed"); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: call to new_null failed");
free(newdims); free(newdims);
return NULL; return NULL;
} }
if (CudaNdarray_alloc_contiguous(rval, patlen, newdims)) if (CudaNdarray_alloc_contiguous(rval, patlen, newdims))
{ {
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_NewWithPattern: allocation failed."); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: allocation failed.");
free(newdims); free(newdims);
Py_DECREF(rval); Py_DECREF(rval);
return NULL; return NULL;
...@@ -361,7 +361,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern) ...@@ -361,7 +361,7 @@ PyObject* CudaNdarray_ZerosWithPattern(PyObject* dummy, PyObject* pattern)
if (cnda_copy_structure_to_device(rval)) if (cnda_copy_structure_to_device(rval))
{ {
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_NewWithPattern: syncing structure to device failed"); PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_Zeros: syncing structure to device failed");
free(newdims); free(newdims);
Py_DECREF(rval); Py_DECREF(rval);
return NULL; return NULL;
...@@ -708,8 +708,8 @@ static PyMethodDef CudaNdarray_methods[] = ...@@ -708,8 +708,8 @@ static PyMethodDef CudaNdarray_methods[] =
(PyCFunction)CudaNdarray_DeepCopy, METH_O, (PyCFunction)CudaNdarray_DeepCopy, METH_O,
"Create a copy of this object"}, "Create a copy of this object"},
{"zeros_with_pattern", {"zeros_with_pattern",
(PyCFunction)CudaNdarray_ZerosWithPattern, METH_STATIC, (PyCFunction)CudaNdarray_Zeros, METH_STATIC,
"Create a new CudaNdarray with specified shape and broadcastability, filled with zeros."}, "Create a new CudaNdarray with specified shape, filled with zeros."},
{"copy", {"copy",
(PyCFunction)CudaNdarray_Copy, METH_NOARGS, (PyCFunction)CudaNdarray_Copy, METH_NOARGS,
"Create a copy of this object"}, "Create a copy of this object"},
...@@ -1331,7 +1331,6 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v) ...@@ -1331,7 +1331,6 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
return -1; return -1;
} }
// Check that 'v' is compatible?
CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key); CudaNdarray* rval = (CudaNdarray*)CudaNdarray_Subscript(o, key);
if(rval == NULL) if(rval == NULL)
...@@ -1349,14 +1348,38 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v) ...@@ -1349,14 +1348,38 @@ CudaNdarray_setitem(PyObject *o, PyObject *key, PyObject *v)
Py_DECREF(rval); Py_DECREF(rval);
return -1; return -1;
} }
if (cnda_copy_structure_to_device(rval))
{
PyErr_SetString(PyExc_RuntimeError, "CudaNdarray_setitem: syncing structure to device failed");
Py_DECREF(rval);
return NULL;
}
CudaNdarray *viewCopyForComparison =
(CudaNdarray*)CudaNdarray_View(rval);
PyObject *baseSavedForComparison = rval->base;
if(!viewCopyForComparison)
{
PyErr_SetString(PyExc_RuntimeError, "__setitem__ could not allocate a view to verify copy results.");
Py_DECREF((PyObject*)rval);
return -1;
}
if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)v)) if(CudaNdarray_CopyFromCudaNdarray(rval, (CudaNdarray*)v))
{ {
Py_DECREF(rval); Py_DECREF(viewCopyForComparison);
Py_DECREF((PyObject*)rval);
return -1; return -1;
} }
// If it fails, deallocate memory (DECREF?) // Check that copy didn't modify shape or strides
assert (CudaNdarray_EqualAndIgnore(viewCopyForComparison, rval, 1, 1));
assert (rval->base == baseSavedForComparison);
assert (rval->dev_structure_fresh);
Py_DECREF((PyObject*)viewCopyForComparison);
return 0; return 0;
} }
......
...@@ -125,6 +125,61 @@ cnda_mark_dev_structure_dirty(CudaNdarray * self) ...@@ -125,6 +125,61 @@ cnda_mark_dev_structure_dirty(CudaNdarray * self)
{ {
self->dev_structure_fresh = 0; self->dev_structure_fresh = 0;
} }
int
CudaNdarray_EqualAndIgnore(CudaNdarray *cnda1, CudaNdarray *cnda2, int ignoreSync, int ignoreBase)
{
int verbose = 1;
if (!ignoreSync && cnda1->dev_structure_fresh != cnda2->dev_structure_fresh)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 1\n");
return 0;
}
if (cnda1->nd != cnda2->nd)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 2\n");
return 0;
}
for (int i=0; i < 2*cnda1->nd; i++)
{
if (cnda1->host_structure[i] != cnda2->host_structure[i])
{
if(verbose)
fprintf(stdout, "CUDANDARRAY_EQUAL : host_structure : %d, %d, %d\n", i, cnda1->host_structure[i], cnda2->host_structure[i]);
return 0;
}
}
if (!ignoreBase && cnda1->base != cnda2->base)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 4");
return 0;
}
else if (cnda1->data_allocated != cnda2->data_allocated)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 5");
return 0;
}
else if (cnda1->data_allocated && cnda1->devdata != cnda2->devdata)
{
if(verbose) fprintf(stdout, "CUDANDARRAY_EQUAL FAILED : 6");
// no need to check devdata if data is not allocated
return 0;
}
return 1;
}
// Default: do not ignore sync of dev and host structures in comparing, and do not ignore difference in base pointers
int
CudaNdarray_Equal(CudaNdarray *cnda1, CudaNdarray *cnda2)
{
return CudaNdarray_EqualAndIgnore(cnda1, cnda2, 0, 0);
}
/**** /****
* Set the idx'th dimension to value d. * Set the idx'th dimension to value d.
* *
......
...@@ -382,12 +382,10 @@ def test_setitem_rightvalue_ndarray_fails(): ...@@ -382,12 +382,10 @@ def test_setitem_rightvalue_ndarray_fails():
assert True assert True
'''
if __name__ == '__main__': if __name__ == '__main__':
test_setitem_matrixvector1() test_setitem_matrixvector1()
test_setitem_matrix_tensor3() test_setitem_matrix_tensor3()
test_setitem_broadcast_must_fail() test_setitem_broadcast_must_fail()
test_setitem_assign_to_slice() test_setitem_assign_to_slice()
test_setitem_rightvalue_ndarray_fails() test_setitem_rightvalue_ndarray_fails()
'''
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论