提交 e2a09110 authored 作者: James Bergstra's avatar James Bergstra

ENH: c_code for reshape

上级 555af254
......@@ -5256,6 +5256,7 @@ class Reshape(Op):
raise ValueError('Cannot reshape input of shape %s to shape %s' %
(x.shape, shp))
def grad(self, inp, grads):
x, shp = inp
g_out, = grads
......@@ -5298,6 +5299,30 @@ class Reshape(Op):
oshape.append(os_i)
return [tuple(oshape)]
def c_code_cache_version(self):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
if isinstance(node.inputs[0], TensorVariable):
x, shp = inputs
z, = outputs
new_ndim = self.ndim
fail = sub['fail']
return """
assert (%(shp)s->nd == 1);
npy_intp new_dims[%(new_ndim)s];
PyArray_Dims newshape;
newshape.ptr = new_dims;
newshape.len = %(new_ndim)s;
for (int ii = 0; ii < %(new_ndim)s; ++ii)
{
new_dims[ii] = ((dtype_%(shp)s*)(%(shp)s->data + ii * %(shp)s->strides[0]))[0];
}
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_ANYORDER);
""" % locals()
else:
return Op.c_code(self, node, name, inputs, outputs, sub)
def reshape(x, newshape, ndim=None, name=None):
if ndim is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论