提交 e3a00698 authored 作者: james@mackie's avatar james@mackie

merged

...@@ -12,6 +12,8 @@ import numpy ...@@ -12,6 +12,8 @@ import numpy
import sys import sys
from scipy import weave
def inputs(): def inputs():
l1 = [[1.0, 2.0], [3.0, 4.0]] l1 = [[1.0, 2.0], [3.0, 4.0]]
...@@ -43,13 +45,23 @@ class _test_TensorOps(unittest.TestCase): ...@@ -43,13 +45,23 @@ class _test_TensorOps(unittest.TestCase):
fn() fn()
assert (e.data == numpy.array([[3, 3, 3], [7, 7, 7]]).T).all() assert (e.data == numpy.array([[3, 3, 3], [7, 7, 7]]).T).all()
def test_2(self):
x, y, z = inputs()
x = x.data
y = weave.inline("""
PyObject* p = PyArray_Transpose(x_array, NULL);
return_val = p;
""", ['x'])
print y
# def test_0(self): # def test_0(self):
# x, y, z = inputs() # x, y, z = inputs()
# e = transpose(x) # e = transpose(x)
# g = env([x], [e]) # g = env([x], [e])
# fn, (i, ), (o, ) = gof.cc.CLinker(g).make_thunk() # fn, (i, ), (o, ) = gof.cc.CLinker(g).make_thunk()
# i.data = [[1.0, 2.0], [3.0, 4.0]]
# # print sys.getrefcount(i.data) # # print sys.getrefcount(i.data)
# for blah in xrange(10000):
# i.data = numpy.ones((1000, 1000)) # [[1.0, 2.0], [3.0, 4.0]]
# fn() # fn()
# # print sys.getrefcount(i.data) # # print sys.getrefcount(i.data)
# # print sys.getrefcount(o.data) # # print sys.getrefcount(o.data)
......
...@@ -146,8 +146,8 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -146,8 +146,8 @@ def struct_gen(args, struct_builders, blocks, sub):
this->__ERROR = __ERROR; this->__ERROR = __ERROR;
return 0; return 0;
%(struct_init_tail)s %(struct_init_tail)s
%(storage_decref)s
%(do_return)s %(do_return)s
return %(failure_var)s;
} }
void cleanup(void) { void cleanup(void) {
%(struct_cleanup)s %(struct_cleanup)s
......
...@@ -70,6 +70,7 @@ class Tensor(ResultBase): ...@@ -70,6 +70,7 @@ class Tensor(ResultBase):
def c_extract(self): def c_extract(self):
return """ return """
%(name)s = NULL;
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
%(name)s = NULL; %(name)s = NULL;
} }
...@@ -103,6 +104,7 @@ class Tensor(ResultBase): ...@@ -103,6 +104,7 @@ class Tensor(ResultBase):
else if ((void*)py_%(name)s != (void*)%(name)s) { else if ((void*)py_%(name)s != (void*)%(name)s) {
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
py_%(name)s = (PyObject*)%(name)s; py_%(name)s = (PyObject*)%(name)s;
Py_XINCREF(py_%(name)s);
} }
""" """
......
...@@ -85,28 +85,28 @@ class BinaryTensorOp(_TensorOp): ...@@ -85,28 +85,28 @@ class BinaryTensorOp(_TensorOp):
nin = 2 nin = 2
class Transpose(UnaryTensorOp): # class Transpose(UnaryTensorOp):
def propagate_broadcastable(self, x): # def propagate_broadcastable(self, x):
x2 = copy(x) # x2 = copy(x)
x2.reverse() # x2.reverse()
return [x2] # return [x2]
def impl(self, x): # def impl(self, x):
return x.T # return x.T
def c_impl(self, x, z): # def c_impl(self, x, z):
return """ # return """
PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL); # PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
//if (PyArray_REFCOUNT(transposed) == 1) { # //if (PyArray_REFCOUNT(transposed) == 1) {
// printf("lala\\n"); # // printf("lala\\n");
//} # //}
//if (%(z)s) { # //if (%(z)s) {
// Py_XDECREF(%(z)s); # // Py_XDECREF(%(z)s);
//} # //}
%(z)s = transposed; # %(z)s = transposed;
Py_XINCREF(%(z)s); # Py_XINCREF(%(z)s);
""" # """
...@@ -474,87 +474,97 @@ class Transpose(_TensorOp, Viewer): ...@@ -474,87 +474,97 @@ class Transpose(_TensorOp, Viewer):
rval = list(x) rval = list(x)
rval.reverse() rval.reverse()
return [rval] return [rval]
def c_impl(self, (x, ), (xt, )):
return """
const int l = x->nd;
// The user must ensure that all references to
//xt->data go through xt, or there's going to be trouble..
int refcheck = 0;
if (x == xt)
{
return -1;
}
if (refcheck)
{
int refcnt = PyArray_REFCOUNT(xt);
if ((refcnt > 2) // you might think this should be 1.. but this works
//|| (xt->base != NULL)
|| (xt->weakreflist != NULL))
{
PyErr_SetString(PyExc_ValueError,
"cannot resize an array that has "\\
"been referenced or is referencing\\n"\\
"another array in this way. Use the "\\
"resize function");
return -2;
}
}
if (xt->nd != x->nd)
{
// this technique comes from PyArray_Resize()
npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
if (!dimptr)
{
PyErr_NoMemory();
return 1;
}
xt->nd = x->nd;
xt->dimensions = dimptr;
xt->strides = dimptr + x->nd;
}
//copy x's dimensions and strides
for (int i = 0; i < l; ++i)
{
xt->dimensions[i] = x->dimensions[l-i-1];
xt->strides[i] = x->strides[l-i-1];
}
// point directly at b's type descriptor
Py_INCREF(x->descr);
Py_DECREF(xt->descr);
xt->descr = x->descr;
// name x as a base of xt, increment its refcount
if ( xt->base != (PyObject*)x)
{
Py_INCREF(x);
if ((xt->base) && (xt->base != Py_None))
{
Py_DECREF(xt->base);
}
xt->base = (PyObject*)x;
}
// mark xt as not owning its data def c_impl(self, x, z):
if (PyArray_CHKFLAGS(xt, NPY_OWNDATA)) return """
{ PyArrayObject* transposed = (PyArrayObject*)PyArray_Transpose(%(x)s, NULL);
PyDataMem_FREE(xt->data); if (%(z)s) {
xt->flags &= ~NPY_OWNDATA; Py_XDECREF(%(z)s);
} }
xt->data = x->data; %(z)s = transposed;
// this function is described in
// ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
/*
TODO
What should be done with the weakreflist ?
*/
""" """
# def c_impl(self, (x, ), (xt, )):
# return """
# const int l = x->nd;
# // The user must ensure that all references to
# //xt->data go through xt, or there's going to be trouble..
# int refcheck = 0;
# if (x == xt)
# {
# return -1;
# }
# if (refcheck)
# {
# int refcnt = PyArray_REFCOUNT(xt);
# if ((refcnt > 2) // you might think this should be 1.. but this works
# //|| (xt->base != NULL)
# || (xt->weakreflist != NULL))
# {
# PyErr_SetString(PyExc_ValueError,
# "cannot resize an array that has "\\
# "been referenced or is referencing\\n"\\
# "another array in this way. Use the "\\
# "resize function");
# return -2;
# }
# }
# if (xt->nd != x->nd)
# {
# // this technique comes from PyArray_Resize()
# npy_intp * dimptr = (npy_intp*)PyDimMem_RENEW(xt->dimensions, 2 * x->nd);
# if (!dimptr)
# {
# PyErr_NoMemory();
# return 1;
# }
# xt->nd = x->nd;
# xt->dimensions = dimptr;
# xt->strides = dimptr + x->nd;
# }
# //copy x's dimensions and strides
# for (int i = 0; i < l; ++i)
# {
# xt->dimensions[i] = x->dimensions[l-i-1];
# xt->strides[i] = x->strides[l-i-1];
# }
# // point directly at b's type descriptor
# Py_INCREF(x->descr);
# Py_DECREF(xt->descr);
# xt->descr = x->descr;
# // name x as a base of xt, increment its refcount
# if ( xt->base != (PyObject*)x)
# {
# Py_INCREF(x);
# if ((xt->base) && (xt->base != Py_None))
# {
# Py_DECREF(xt->base);
# }
# xt->base = (PyObject*)x;
# }
# // mark xt as not owning its data
# if (PyArray_CHKFLAGS(xt, NPY_OWNDATA))
# {
# PyDataMem_FREE(xt->data);
# xt->flags &= ~NPY_OWNDATA;
# }
# xt->data = x->data;
# // this function is described in
# // ~/zzz.NOBACKUP/pub/src/numpy-1.0.3.1/numpy/core/src/arrayobject.c:1890
# PyArray_UpdateFlags(xt, NPY_CONTIGUOUS|NPY_FORTRAN|NPY_ALIGNED|NPY_WRITEABLE);
# /*
# TODO
# What should be done with the weakreflist ?
# */
# """
def transpose_copy(x): def transpose_copy(x):
return array_copy(transpose(x)) return array_copy(transpose(x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论