提交 62b0bddb authored 作者: notoraptor's avatar notoraptor

Fix code for GpuReshape.

上级 82284b21
......@@ -1151,17 +1151,21 @@ class GpuReshape(HideC, tensor.Reshape):
out[0] = x.reshape(tuple(shp))
def c_code_cache_version(self):
return (2,)
return (3,)
def c_code(self, node, name, inputs, outputs, sub):
x, shape = inputs
output, = outputs
sdtype = node.inputs[1].type.dtype_specs()[1]
fail = sub['fail']
just_fail = sub['fail']
fail = """{
free(new_dims);
%(just_fail)s
}""" % dict(just_fail=just_fail)
params = sub['params']
return """
size_t old_size = 1, new_size = 1;
size_t new_dims[%(params)s->ndim];
size_t* new_dims = NULL;
int compute_axis = -1;
assert (PyArray_NDIM(%(shape)s) == 1);
......@@ -1171,7 +1175,13 @@ class GpuReshape(HideC, tensor.Reshape):
"GpuReshape: given shape is of incorrect "
"length (%%d should be %%d).",
PyArray_DIM(%(shape)s, 0), %(params)s->ndim);
%(fail)s;
%(just_fail)s;
}
new_dims = (size_t*) malloc(sizeof(size_t) * %(params)s->ndim);
if (new_dims == NULL) {
PyErr_NoMemory();
%(just_fail)s
}
for (size_t i = 0; i < %(x)s->ga.nd; ++i)
......@@ -1220,9 +1230,10 @@ class GpuReshape(HideC, tensor.Reshape):
Py_XDECREF(%(output)s);
%(output)s = pygpu_reshape(%(x)s, %(params)s->ndim, new_dims,
GA_C_ORDER, 0, compute_axis);
free(new_dims);
if (%(output)s == NULL)
{
%(fail)s;
%(just_fail)s;
}
""" % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论