提交 6fbae155 authored 作者: AndroidCloud's avatar AndroidCloud

adding updates for GpuReshape

上级 d92a95d2
...@@ -1129,7 +1129,7 @@ class GpuReshape(HideC, tensor.Reshape): ...@@ -1129,7 +1129,7 @@ class GpuReshape(HideC, tensor.Reshape):
context_name=ctx_name) context_name=ctx_name)
return Apply(self, [x, shp], [otype()]) return Apply(self, [x, shp], [otype()])
def perform(self, node, inp, out_): def perform(self, node, inp, out_, params):
x, shp = inp x, shp = inp
out, = out_ out, = out_
if (len(shp) != self.ndim): if (len(shp) != self.ndim):
...@@ -1165,25 +1165,26 @@ class GpuReshape(HideC, tensor.Reshape): ...@@ -1165,25 +1165,26 @@ class GpuReshape(HideC, tensor.Reshape):
new_ndim = self.ndim new_ndim = self.ndim
sdtype = node.inputs[1].type.dtype_specs()[1] sdtype = node.inputs[1].type.dtype_specs()[1]
fail = sub['fail'] fail = sub['fail']
params = sub['params']
return """ return """
size_t old_size = 1, new_size = 1; size_t old_size = 1, new_size = 1;
size_t new_dims[%(new_ndim)s]; size_t new_dims[%(params)s->ndim];
int compute_axis = -1; int compute_axis = -1;
assert (PyArray_NDIM(%(shape)s) == 1); assert (PyArray_NDIM(%(shape)s) == 1);
if (PyArray_DIM(%(shape)s, 0) != %(new_ndim)s) if (PyArray_DIM(%(shape)s, 0) != %(params)s->ndim)
{ {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"GpuReshape: given shape is of incorrect " "GpuReshape: given shape is of incorrect "
"length (%%d should be %%d).", "length (%%d should be %%d).",
PyArray_DIM(%(shape)s, 0), %(new_ndim)s); PyArray_DIM(%(shape)s, 0), %(params)s->ndim);
%(fail)s; %(fail)s;
} }
for (size_t i = 0; i < %(x)s->ga.nd; ++i) for (size_t i = 0; i < %(x)s->ga.nd; ++i)
old_size *= %(x)s->ga.dimensions[i]; old_size *= %(x)s->ga.dimensions[i];
for (size_t i = 0; i < %(new_ndim)s; ++i) for (size_t i = 0; i < %(params)s->ndim; ++i)
{ {
new_dims[i] = ((%(sdtype)s*)( new_dims[i] = ((%(sdtype)s*)(
PyArray_BYTES(%(shape)s) + PyArray_BYTES(%(shape)s) +
...@@ -1224,7 +1225,7 @@ class GpuReshape(HideC, tensor.Reshape): ...@@ -1224,7 +1225,7 @@ class GpuReshape(HideC, tensor.Reshape):
} }
Py_XDECREF(%(output)s); Py_XDECREF(%(output)s);
%(output)s = pygpu_reshape(%(x)s, %(new_ndim)s, new_dims, %(output)s = pygpu_reshape(%(x)s, %(params)s->ndim, new_dims,
GA_C_ORDER, 0, compute_axis); GA_C_ORDER, 0, compute_axis);
if (%(output)s == NULL) if (%(output)s == NULL)
{ {
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论