提交 ebc740ac authored 作者: James Bergstra's avatar James Bergstra

Merge pull request #813 from lamblin/fix_reshape

Fix C code of Reshape when input is not c-contiguous
...@@ -5306,7 +5306,7 @@ class Reshape(Op): ...@@ -5306,7 +5306,7 @@ class Reshape(Op):
return [tuple(oshape)] return [tuple(oshape)]
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
if isinstance(node.inputs[0], TensorVariable): if isinstance(node.inputs[0], TensorVariable):
...@@ -5330,11 +5330,18 @@ class Reshape(Op): ...@@ -5330,11 +5330,18 @@ class Reshape(Op):
%(shp)s->data + ii * %(shp)s->strides[0]))[0]; %(shp)s->data + ii * %(shp)s->strides[0]))[0];
} }
Py_XDECREF(%(z)s); Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_ANYORDER); %(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, PyArray_CORDER);
if (!%(z)s)
{
PyErr_Format(PyExc_ValueError,
"Could not reshape array.");
%(fail)s;
}
""" % locals() """ % locals()
else: else:
return Op.c_code(self, node, name, inputs, outputs, sub) return Op.c_code(self, node, name, inputs, outputs, sub)
def reshape(x, newshape, ndim=None, name=None): def reshape(x, newshape, ndim=None, name=None):
if ndim is None: if ndim is None:
ndim = get_vector_length(newshape) ndim = get_vector_length(newshape)
......
...@@ -4285,7 +4285,16 @@ class T_reshape(unittest.TestCase): ...@@ -4285,7 +4285,16 @@ class T_reshape(unittest.TestCase):
#basic to 1 dim(without list) #basic to 1 dim(without list)
c = reshape(b, as_tensor_variable(6), ndim=1) c = reshape(b, as_tensor_variable(6), ndim=1)
f = inplace_func([b], c) f = inplace_func([b], c)
assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([0,1,2,3,4,5]))
b_val1 = numpy.asarray([[0,1,2],[3,4,5]])
c_val1 = numpy.asarray([0,1,2,3,4,5])
b_val2 = b_val1.T
c_val2 = numpy.asarray([0,3,1,4,2,5])
f_out1 = f(b_val1)
f_out2 = f(b_val2)
assert numpy.all(f_out1 == c_val1), (f_out1, c_val1)
assert numpy.all(f_out2 == c_val2), (f_out2, c_val2)
#print f.maker.fgraph.toposort() #print f.maker.fgraph.toposort()
#check that we remove the useless reshape #check that we remove the useless reshape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论