提交 b3203e28 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed c bug

上级 0dd147a2
...@@ -12,6 +12,8 @@ import numpy ...@@ -12,6 +12,8 @@ import numpy
import sys import sys
from scipy import weave
def inputs(): def inputs():
l1 = [[1.0, 2.0], [3.0, 4.0]] l1 = [[1.0, 2.0], [3.0, 4.0]]
...@@ -43,14 +45,24 @@ class _test_TensorOps(unittest.TestCase): ...@@ -43,14 +45,24 @@ class _test_TensorOps(unittest.TestCase):
fn() fn()
assert (e.data == numpy.array([[3, 3, 3], [7, 7, 7]]).T).all() assert (e.data == numpy.array([[3, 3, 3], [7, 7, 7]]).T).all()
def test_2(self):
x, y, z = inputs()
x = x.data
y = weave.inline("""
PyObject* p = PyArray_Transpose(x_array, NULL);
return_val = p;
""", ['x'])
print y
# def test_0(self): # def test_0(self):
# x, y, z = inputs() # x, y, z = inputs()
# e = transpose(x) # e = transpose(x)
# g = env([x], [e]) # g = env([x], [e])
# fn, (i, ), (o, ) = gof.cc.CLinker(g).make_thunk() # fn, (i, ), (o, ) = gof.cc.CLinker(g).make_thunk()
# i.data = [[1.0, 2.0], [3.0, 4.0]]
# # print sys.getrefcount(i.data) # # print sys.getrefcount(i.data)
# fn() # for blah in xrange(10000):
# i.data = numpy.ones((1000, 1000)) # [[1.0, 2.0], [3.0, 4.0]]
# fn()
# # print sys.getrefcount(i.data) # # print sys.getrefcount(i.data)
# # print sys.getrefcount(o.data) # # print sys.getrefcount(o.data)
# print o.data # print o.data
......
...@@ -146,8 +146,8 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -146,8 +146,8 @@ def struct_gen(args, struct_builders, blocks, sub):
this->__ERROR = __ERROR; this->__ERROR = __ERROR;
return 0; return 0;
%(struct_init_tail)s %(struct_init_tail)s
%(storage_decref)s
%(do_return)s %(do_return)s
return %(failure_var)s;
} }
void cleanup(void) { void cleanup(void) {
%(struct_cleanup)s %(struct_cleanup)s
......
...@@ -70,6 +70,7 @@ class Tensor(ResultBase): ...@@ -70,6 +70,7 @@ class Tensor(ResultBase):
def c_extract(self): def c_extract(self):
return """ return """
%(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
%(name)s = NULL; %(name)s = NULL;
} }
...@@ -103,6 +104,7 @@ class Tensor(ResultBase): ...@@ -103,6 +104,7 @@ class Tensor(ResultBase):
else if ((void*)py_%(name)s != (void*)%(name)s) { else if ((void*)py_%(name)s != (void*)%(name)s) {
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
py_%(name)s = (PyObject*)%(name)s; py_%(name)s = (PyObject*)%(name)s;
Py_XINCREF(py_%(name)s);
} }
""" """
......
...@@ -76,28 +76,28 @@ class BinaryTensorOp(TensorOp): ...@@ -76,28 +76,28 @@ class BinaryTensorOp(TensorOp):
nin = 2 nin = 2
class Transpose(UnaryTensorOp): # class Transpose(UnaryTensorOp):
def propagate_broadcastable(self, x): # def propagate_broadcastable(self, x):
x2 = copy(x) # x2 = copy(x)
x2.reverse() # x2.reverse()
return [x2] # return [x2]
def impl(self, x): # def impl(self, x):
return x.T # return x.T
def c_impl(self, x, z): # def c_impl(self, x, z):
return """ # return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL); # PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
//if (PyArray_REFCOUNT(transposed) == 1) { # //if (PyArray_REFCOUNT(transposed) == 1) {
// printf("lala\\n"); # // printf("lala\\n");
//} # //}
//if (%(z)s) { # //if (%(z)s) {
// Py_XDECREF(%(z)s); # // Py_XDECREF(%(z)s);
//} # //}
%(z)s = transposed; # %(z)s = transposed;
Py_XINCREF(%(z)s); # Py_XINCREF(%(z)s);
""" # """
...@@ -465,86 +465,96 @@ class Transpose(TensorOp, Viewer): ...@@ -465,86 +465,96 @@ class Transpose(TensorOp, Viewer):
rval = list(x) rval = list(x)
rval.reverse() rval.reverse()
return [rval] return [rval]
def c_impl(self, (x, ), (xt, )):
def c_impl(self, x, z):
return """ return """
const int l = x->nd; PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
// The user must ensure that all references to if (%(z)s) {
//xt->data go through xt, or there's going to be trouble.. Py_XDECREF(%(z)s);
int refcheck = 0;
if (x == xt)
{
return -1;
}
if (refcheck)
{
int refcnt = PyArray_REFCOUNT(xt);
if ((refcnt > 2) // you might think this should be 1.. but this works
//|| (xt->base != NULL)
|| (xt->weakreflist != NULL))
{
PyErr_SetString(PyExc_ValueError,
"cannot resize an array that has "\\
"been referenced or is referencing\\n"\\
"another array in this way. Use the "\\
"resize function");
return -2;
}
}
if (xt->nd != x->nd)
{
// this technique comes from PyArray_Resize()
npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
if (!dimptr)
{
PyErr_NoMemory();
return 1;
}
xt->nd = x->nd;
xt->dimensions = dimptr;
xt->strides = dimptr + x->nd;
}
//copy x's dimensions and strides
for (int i = 0; i < l; ++i)
{
xt->dimensions[i] = x->dimensions[l-i-1];
xt->strides[i] = x->strides[l-i-1];
} }
%(z)s = transposed;
"""
// point directly at b's type descriptor # def c_impl(self, (x, ), (xt, )):
Py_INCREF(x->descr); # return """
Py_DECREF(xt->descr); # const int l = x->nd;
xt->descr = x->descr; # // The user must ensure that all references to
# //xt->data go through xt, or there's going to be trouble..
// name x as a base of xt, increment its refcount # int refcheck = 0;
if ( xt->base != (PyObject*)x)
{ # if (x == xt)
Py_INCREF(x); # {
if ((xt->base) && (xt->base != Py_None)) # return -1;
{ # }
Py_DECREF(xt->base); # if (refcheck)
} # {
xt->base = (PyObject*)x; # int refcnt = PyArray_REFCOUNT(xt);
} # if ((refcnt > 2) // you might think this should be 1.. but this works
# //|| (xt->base != NULL)
# || (xt->weakreflist != NULL))
# {
# PyErr_SetString(PyExc_ValueError,
# "cannot resize an array that has "\\
# "been referenced or is referencing\\n"\\
# "another array in this way. Use the "\\
# "resize function");
# return -2;
# }
# }
# if (xt->nd != x->nd)
# {
# // this technique comes from PyArray_Resize()
# npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
# if (!dimptr)
# {
# PyErr_NoMemory();
# return 1;
# }
# xt->nd = x->nd;
# xt->dimensions = dimptr;
# xt->strides = dimptr + x->nd;
# }
# //copy x's dimensions and strides
# for (int i = 0; i < l; ++i)
# {
# xt->dimensions[i] = x->dimensions[l-i-1];
# xt->strides[i] = x->strides[l-i-1];
# }
# // point directly at b's type descriptor
# Py_INCREF(x->descr);
# Py_DECREF(xt->descr);
# xt->descr = x->descr;
# // name x as a base of xt, increment its refcount
# if ( xt->base != (PyObject*)x)
# {
# Py_INCREF(x);
# if ((xt->base) && (xt->base != Py_None))
# {
# Py_DECREF(xt->base);
# }
# xt->base = (PyObject*)x;
# }
// mark xt as not owning its data # // mark xt as not owning its data
if (PyArray_CHKFLAGS(xt, NPY_OWNDATA)) # if (PyArray_CHKFLAGS(xt, NPY_OWNDATA))
{ # {
PyDataMem_FREE(xt->data); # PyDataMem_FREE(xt->data);
xt->flags &= ~NPY_OWNDATA; # xt->flags &= ~NPY_OWNDATA;
} # }
xt->data = x->data; # xt->data = x->data;
// this function is described in # // this function is described in
// ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890 # // ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE); # PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
/* # /*
TODO # TODO
What should be done with the weakreflist ? # What should be done with the weakreflist ?
*/ # */
""" # """
def transpose_copy(x): def transpose_copy(x):
return array_copy(transpose(x)) return array_copy(transpose(x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论