提交 b3203e28 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed c bug

上级 0dd147a2
......@@ -12,6 +12,8 @@ import numpy
import sys
from scipy import weave
def inputs():
l1 = [[1.0, 2.0], [3.0, 4.0]]
......@@ -43,13 +45,23 @@ class _test_TensorOps(unittest.TestCase):
fn()
assert (e.data == numpy.array([[3, 3, 3], [7, 7, 7]]).T).all()
def test_2(self):
x, y, z = inputs()
x = x.data
y = weave.inline("""
PyObject* p = PyArray_Transpose(x_array, NULL);
return_val = p;
""", ['x'])
print y
# def test_0(self):
# x, y, z = inputs()
# e = transpose(x)
# g = env([x], [e])
# fn, (i, ), (o, ) = gof.cc.CLinker(g).make_thunk()
# i.data = [[1.0, 2.0], [3.0, 4.0]]
# # print sys.getrefcount(i.data)
# for blah in xrange(10000):
# i.data = numpy.ones((1000, 1000)) # [[1.0, 2.0], [3.0, 4.0]]
# fn()
# # print sys.getrefcount(i.data)
# # print sys.getrefcount(o.data)
......
......@@ -146,8 +146,8 @@ def struct_gen(args, struct_builders, blocks, sub):
this->__ERROR = __ERROR;
return 0;
%(struct_init_tail)s
%(storage_decref)s
%(do_return)s
return %(failure_var)s;
}
void cleanup(void) {
%(struct_cleanup)s
......
......@@ -70,6 +70,7 @@ class Tensor(ResultBase):
def c_extract(self):
return """
%(name)s = NULL;
if (py_%(name)s == Py_None) {
%(name)s = NULL;
}
......@@ -103,6 +104,7 @@ class Tensor(ResultBase):
else if ((void*)py_%(name)s != (void*)%(name)s) {
Py_XDECREF(py_%(name)s);
py_%(name)s = (PyObject*)%(name)s;
Py_XINCREF(py_%(name)s);
}
"""
......
......@@ -76,28 +76,28 @@ class BinaryTensorOp(TensorOp):
nin = 2
class Transpose(UnaryTensorOp):
# class Transpose(UnaryTensorOp):
def propagate_broadcastable(self, x):
x2 = copy(x)
x2.reverse()
return [x2]
# def propagate_broadcastable(self, x):
# x2 = copy(x)
# x2.reverse()
# return [x2]
def impl(self, x):
return x.T
# def impl(self, x):
# return x.T
def c_impl(self, x, z):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
//if (PyArray_REFCOUNT(transposed) == 1) {
// printf("lala\\n");
//}
//if (%(z)s) {
// Py_XDECREF(%(z)s);
//}
%(z)s = transposed;
Py_XINCREF(%(z)s);
"""
# def c_impl(self, x, z):
# return """
# PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
# //if (PyArray_REFCOUNT(transposed) == 1) {
# // printf("lala\\n");
# //}
# //if (%(z)s) {
# // Py_XDECREF(%(z)s);
# //}
# %(z)s = transposed;
# Py_XINCREF(%(z)s);
# """
......@@ -465,87 +465,97 @@ class Transpose(TensorOp, Viewer):
rval = list(x)
rval.reverse()
return [rval]
def c_impl(self, (x, ), (xt, )):
return """
const int l = x->nd;
// The user must ensure that all references to
//xt->data go through xt, or there's going to be trouble..
int refcheck = 0;
if (x == xt)
{
return -1;
}
if (refcheck)
{
int refcnt = PyArray_REFCOUNT(xt);
if ((refcnt > 2) // you might think this should be 1.. but this works
//|| (xt->base != NULL)
|| (xt->weakreflist != NULL))
{
PyErr_SetString(PyExc_ValueError,
"cannot resize an array that has "\\
"been referenced or is referencing\\n"\\
"another array in this way. Use the "\\
"resize function");
return -2;
}
}
if (xt->nd != x->nd)
{
// this technique comes from PyArray_Resize()
npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
if (!dimptr)
{
PyErr_NoMemory();
return 1;
}
xt->nd = x->nd;
xt->dimensions = dimptr;
xt->strides = dimptr + x->nd;
}
//copy x's dimensions and strides
for (int i = 0; i < l; ++i)
{
xt->dimensions[i] = x->dimensions[l-i-1];
xt->strides[i] = x->strides[l-i-1];
}
// point directly at b's type descriptor
Py_INCREF(x->descr);
Py_DECREF(xt->descr);
xt->descr = x->descr;
// name x as a base of xt, increment its refcount
if ( xt->base != (PyObject*)x)
{
Py_INCREF(x);
if ((xt->base) && (xt->base != Py_None))
{
Py_DECREF(xt->base);
}
xt->base = (PyObject*)x;
}
// mark xt as not owning its data
if (PyArray_CHKFLAGS(xt, NPY_OWNDATA))
{
PyDataMem_FREE(xt->data);
xt->flags &= ~NPY_OWNDATA;
def c_impl(self, x, z):
return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
if (%(z)s) {
Py_XDECREF(%(z)s);
}
xt->data = x->data;
// this function is described in
// ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
/*
TODO
What should be done with the weakreflist ?
*/
%(z)s = transposed;
"""
# def c_impl(self, (x, ), (xt, )):
# return """
# const int l = x->nd;
# // The user must ensure that all references to
# //xt->data go through xt, or there's going to be trouble..
# int refcheck = 0;
# if (x == xt)
# {
# return -1;
# }
# if (refcheck)
# {
# int refcnt = PyArray_REFCOUNT(xt);
# if ((refcnt > 2) // you might think this should be 1.. but this works
# //|| (xt->base != NULL)
# || (xt->weakreflist != NULL))
# {
# PyErr_SetString(PyExc_ValueError,
# "cannot resize an array that has "\\
# "been referenced or is referencing\\n"\\
# "another array in this way. Use the "\\
# "resize function");
# return -2;
# }
# }
# if (xt->nd != x->nd)
# {
# // this technique comes from PyArray_Resize()
# npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
# if (!dimptr)
# {
# PyErr_NoMemory();
# return 1;
# }
# xt->nd = x->nd;
# xt->dimensions = dimptr;
# xt->strides = dimptr + x->nd;
# }
# //copy x's dimensions and strides
# for (int i = 0; i < l; ++i)
# {
# xt->dimensions[i] = x->dimensions[l-i-1];
# xt->strides[i] = x->strides[l-i-1];
# }
# // point directly at b's type descriptor
# Py_INCREF(x->descr);
# Py_DECREF(xt->descr);
# xt->descr = x->descr;
# // name x as a base of xt, increment its refcount
# if ( xt->base != (PyObject*)x)
# {
# Py_INCREF(x);
# if ((xt->base) && (xt->base != Py_None))
# {
# Py_DECREF(xt->base);
# }
# xt->base = (PyObject*)x;
# }
# // mark xt as not owning its data
# if (PyArray_CHKFLAGS(xt, NPY_OWNDATA))
# {
# PyDataMem_FREE(xt->data);
# xt->flags &= ~NPY_OWNDATA;
# }
# xt->data = x->data;
# // this function is described in
# // ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
# PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
# /*
# TODO
# What should be done with the weakreflist ?
# */
# """
def transpose_copy(x):
return array_copy(transpose(x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论