tightened up c implementation of transpose

上级 a0a9ab02
......@@ -1082,6 +1082,29 @@ class transpose(omega_op, view):
def c_impl((x, ), (xt, )):
return """
const int l = x->nd;
// The user must ensure that all references to
//xt->data go through xt, or there's going to be trouble..
int refcheck = 0;
if (x == xt)
{
return -1;
}
if (refcheck)
{
int refcnt = PyArray_REFCOUNT(xt);
if ((refcnt > 2) // you might think this should be 1.. but this works
//|| (xt->base != NULL)
|| (xt->weakreflist != NULL))
{
PyErr_SetString(PyExc_ValueError,
"cannot resize an array that has "\\
"been referenced or is referencing\\n"\\
"another array in this way. Use the "\\
"resize function");
return -2;
}
}
if (xt->nd != x->nd)
{
......@@ -1089,38 +1112,52 @@ class transpose(omega_op, view):
npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
if (!dimptr)
{
fprintf(stderr, "%i: %p\\n", __LINE__, dimptr);
assert(!"dammit");
PyErr_NoMemory();
return 1;
}
xt->nd = x->nd;
xt->dimensions = dimptr;
xt->strides = dimptr + x->nd;
//fprintf(stderr, "transpose: %i %i %i %i\\n", x->dimensions[0], x->dimensions[1], x->strides[0], x->strides[1]);
}
//fprintf(stderr, "%s %i %p\\n", __FILE__, __LINE__, xt->base);
if ( xt->base != (PyObject*)x)
{
//fprintf(stderr, "%i: %p\\n", __LINE__, xt->base);
if ((xt->base) and (xt->base != Py_None)) Py_DECREF(xt->base);
Py_INCREF(x);
xt->base = (PyObject*)x;
}
xt->data = x->data;
//copy x's dimensions and strides
for (int i = 0; i < l; ++i)
{
xt->dimensions[i] = x->dimensions[l-i-1];
xt->strides[i] = x->strides[l-i-1];
//fprintf(stderr, "%li\t", x->dimensions[i]);
}
//fprintf(stderr, "\\n");
xt->flags &= ~NPY_OWNDATA;
PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
//this function is described in
// point directly at b's type descriptor
Py_INCREF(x->descr);
Py_DECREF(xt->descr);
xt->descr = x->descr;
// name x as a base of xt, increment its refcount
if ( xt->base != (PyObject*)x)
{
Py_INCREF(x);
if ((xt->base) && (xt->base != Py_None))
{
Py_DECREF(xt->base);
}
xt->base = (PyObject*)x;
}
// mark xt as not owning its data
if (PyArray_CHKFLAGS(xt, NPY_OWNDATA))
{
PyDataMem_FREE(xt->data);
xt->flags &= ~NPY_OWNDATA;
}
xt->data = x->data;
// this function is described in
// ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
return 0;
PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
/*
TODO
What should be done with the weakreflist ?
*/
"""
def transpose_copy(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论