tightened up c implementation of transpose

上级 a0a9ab02
...@@ -1082,6 +1082,29 @@ class transpose(omega_op, view): ...@@ -1082,6 +1082,29 @@ class transpose(omega_op, view):
def c_impl((x, ), (xt, )): def c_impl((x, ), (xt, )):
return """ return """
const int l = x->nd; const int l = x->nd;
// The user must ensure that all references to
//xt->data go through xt, or there's going to be trouble..
int refcheck = 0;
if (x == xt)
{
return -1;
}
if (refcheck)
{
int refcnt = PyArray_REFCOUNT(xt);
if ((refcnt > 2) // you might think this should be 1.. but this works
//|| (xt->base != NULL)
|| (xt->weakreflist != NULL))
{
PyErr_SetString(PyExc_ValueError,
"cannot resize an array that has "\\
"been referenced or is referencing\\n"\\
"another array in this way. Use the "\\
"resize function");
return -2;
}
}
if (xt->nd != x->nd) if (xt->nd != x->nd)
{ {
...@@ -1089,38 +1112,52 @@ class transpose(omega_op, view): ...@@ -1089,38 +1112,52 @@ class transpose(omega_op, view):
npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd); npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
if (!dimptr) if (!dimptr)
{ {
fprintf(stderr, "%i: %p\\n", __LINE__, dimptr); PyErr_NoMemory();
assert(!"dammit"); return 1;
} }
xt->nd = x->nd; xt->nd = x->nd;
xt->dimensions = dimptr; xt->dimensions = dimptr;
xt->strides = dimptr + x->nd; xt->strides = dimptr + x->nd;
//fprintf(stderr, "transpose: %i %i %i %i\\n", x->dimensions[0], x->dimensions[1], x->strides[0], x->strides[1]);
} }
//fprintf(stderr, "%s %i %p\\n", __FILE__, __LINE__, xt->base); //copy x's dimensions and strides
if ( xt->base != (PyObject*)x)
{
//fprintf(stderr, "%i: %p\\n", __LINE__, xt->base);
if ((xt->base) and (xt->base != Py_None)) Py_DECREF(xt->base);
Py_INCREF(x);
xt->base = (PyObject*)x;
}
xt->data = x->data;
for (int i = 0; i < l; ++i) for (int i = 0; i < l; ++i)
{ {
xt->dimensions[i] = x->dimensions[l-i-1]; xt->dimensions[i] = x->dimensions[l-i-1];
xt->strides[i] = x->strides[l-i-1]; xt->strides[i] = x->strides[l-i-1];
//fprintf(stderr, "%li\t", x->dimensions[i]);
} }
//fprintf(stderr, "\\n");
xt->flags &= ~NPY_OWNDATA; // point directly at b's type descriptor
PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE); Py_INCREF(x->descr);
//this function is described in Py_DECREF(xt->descr);
xt->descr = x->descr;
// name x as a base of xt, increment its refcount
if ( xt->base != (PyObject*)x)
{
Py_INCREF(x);
if ((xt->base) && (xt->base != Py_None))
{
Py_DECREF(xt->base);
}
xt->base = (PyObject*)x;
}
// mark xt as not owning its data
if (PyArray_CHKFLAGS(xt, NPY_OWNDATA))
{
PyDataMem_FREE(xt->data);
xt->flags &= ~NPY_OWNDATA;
}
xt->data = x->data;
// this function is described in
// ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890 // ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
return 0; PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
/*
TODO
What should be done with the weakreflist ?
*/
""" """
def transpose_copy(x): def transpose_copy(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论