提交 ee51be36 authored 作者: nouiz's avatar nouiz

Merge pull request #802 from jaberg/reshape_c_code

ENH: c_code for reshape
......@@ -5256,6 +5256,7 @@ class Reshape(Op):
raise ValueError('Cannot reshape input of shape %s to shape %s' %
(x.shape, shp))
def grad(self, inp, grads):
x, shp = inp
g_out, = grads
......@@ -5298,6 +5299,35 @@ class Reshape(Op):
oshape.append(os_i)
return [tuple(oshape)]
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
if isinstance(node.inputs[0], TensorVariable):
x, shp = inputs
z, = outputs
new_ndim = self.ndim
fail = sub['fail']
return """
assert (%(shp)s->nd == 1);
npy_intp new_dims[%(new_ndim)s];
PyArray_Dims newshape;
newshape.ptr = new_dims;
newshape.len = %(new_ndim)s;
for (int ii = 0; ii < %(new_ndim)s; ++ii)
{
// -- We do not want an explicit cast here. the shp can be any
// -- int* dtype. The compiler will explicitly upcast it, but
// -- will err if this will downcast. This could happen if the
// -- user pass an int64 dtype, but npy_intp endup being int32.
new_dims[ii] = ((dtype_%(shp)s*)(
%(shp)s->data + ii * %(shp)s->strides[0]))[0];
}
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_ANYORDER);
""" % locals()
else:
return Op.c_code(self, node, name, inputs, outputs, sub)
def reshape(x, newshape, ndim=None, name=None):
if ndim is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论