提交 658cd71e authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2574 from sisp/flatten_c_code

Flatten C code
......@@ -4012,6 +4012,8 @@ class Flatten(Op):
"""
view_map = {0: [0]}
check_input = False
def __init__(self, outdim=1):
self.outdim = int(outdim)
......@@ -4078,6 +4080,74 @@ class Flatten(Op):
return [None]
return self.make_node(*eval_points).outputs
def c_code_cache_version(self):
return (1, 1)
def c_code(self, node, name, inputs, outputs, sub):
x, = inputs
out, = outputs
outdim = self.outdim
fail = sub['fail']
return """
if (%(outdim)s == PyArray_NDIM(%(x)s))
{
Py_XDECREF(%(out)s);
Py_XINCREF(%(x)s);
%(out)s = %(x)s;
}
else
{
Py_XDECREF(%(out)s);
if (%(outdim)s == 1)
{
npy_intp size = PyArray_SIZE(%(x)s);
PyArray_Dims newshape;
newshape.ptr = &size;
newshape.len = 1;
%(out)s = (PyArrayObject*)PyArray_Newshape(%(x)s,
&newshape,
NPY_CORDER);
}
else
{
npy_intp *oldshape = PyArray_DIMS(%(x)s);
npy_intp newshape_dims[%(outdim)s];
int i;
for (i = 0; i < %(outdim)s - 1; ++i)
newshape_dims[i] = oldshape[i];
newshape_dims[i] = 1;
for (int j = %(outdim)s - 1; j < PyArray_NDIM(%(x)s); ++j)
newshape_dims[i] *= oldshape[j];
PyArray_Dims newshape;
newshape.ptr = newshape_dims;
newshape.len = %(outdim)s;
%(out)s = (PyArrayObject*)PyArray_Newshape(%(x)s,
&newshape,
NPY_CORDER);
}
}
if (!%(out)s)
{
//The error message should have been set by
// PyArray_Newshape
%(fail)s;
}
if (!PyArray_ISALIGNED(%(out)s)) {
PyErr_Format(
PyExc_RuntimeError,
"PyArray_Newshape returned an object that isn't"
" aligned! NumPy versions 1.6.2, 1.7.0 and 1.7.1 have"
" this problem for some input shape/new shape"
" combinations. Use another NumPy version.");
%(fail)s;
}
""" % locals()
def flatten(x, outdim=1):
return Flatten(outdim)(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论