提交 ca20b477 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix C code of Reshape when input is not c-contiguous

Also add test that would have been able to catch it.
上级 221ba717
......@@ -5300,7 +5300,7 @@ class Reshape(Op):
return [tuple(oshape)]
def c_code_cache_version(self):
return (1,)
return (2,)
def c_code(self, node, name, inputs, outputs, sub):
if isinstance(node.inputs[0], TensorVariable):
......@@ -5324,11 +5324,18 @@ class Reshape(Op):
%(shp)s->data + ii * %(shp)s->strides[0]))[0];
}
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, NPY_ANYORDER);
%(z)s = (PyArrayObject *) PyArray_Newshape(%(x)s, &newshape, PyArray_CORDER);
if (!%(z)s)
{
PyErr_Format(PyExc_ValueError,
"Could not reshape array.");
%(fail)s;
}
""" % locals()
else:
return Op.c_code(self, node, name, inputs, outputs, sub)
def reshape(x, newshape, ndim=None, name=None):
if ndim is None:
ndim = get_vector_length(newshape)
......
......@@ -4284,7 +4284,16 @@ class T_reshape(unittest.TestCase):
#basic to 1 dim(without list)
c = reshape(b, as_tensor_variable(6), ndim=1)
f = inplace_func([b], c)
assert numpy.all(f(numpy.asarray([[0,1,2],[3,4,5]])) == numpy.asarray([0,1,2,3,4,5]))
b_val1 = numpy.asarray([[0,1,2],[3,4,5]])
c_val1 = numpy.asarray([0,1,2,3,4,5])
b_val2 = b_val1.T
c_val2 = numpy.asarray([0,3,1,4,2,5])
f_out1 = f(b_val1)
f_out2 = f(b_val2)
assert numpy.all(f_out1 == c_val1), (f_out1, c_val1)
assert numpy.all(f_out2 == c_val2), (f_out2, c_val2)
#print f.maker.fgraph.toposort()
#check that we remove the useless reshape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论