提交 07e95861 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix GpuSubtensor so that it covers all the cases properly.

上级 776448dc
...@@ -40,68 +40,115 @@ class GpuSubtensor(HideC, Subtensor): ...@@ -40,68 +40,115 @@ class GpuSubtensor(HideC, Subtensor):
out[0] = x.__getitem__(cdata) out[0] = x.__getitem__(cdata)
def c_support_code(self):
    """Return the C helper source used by the generated indexing code.

    The returned string defines ``fix_indices``, a C function that
    normalizes one (start, stop, step) slice triple in place against a
    dimension of length ``len``:

    * ``start_n`` / ``stop_n`` / ``step_n`` are flags; when non-zero the
      corresponding value is treated as "not given" and replaced by its
      default (step 1; start/stop chosen from the sign of the step).
    * negative start/stop values are offset by ``len`` and then clamped.
    * a zero step sets a Python ``ValueError`` and returns -1; otherwise
      the function returns 0.

    All comparisons are done on a signed copy of ``len``.  Comparing the
    ``ssize_t`` indices directly with the unsigned ``size_t len`` promotes
    a clamped ``-1`` to a huge unsigned value, so the ``>= len`` test fired
    and rewrote it to ``len-1`` (e.g. ``x[:-100:-1]`` dropped element 0).

    NOTE(review): compiled modules may be cached by
    ``c_code_cache_version``; that version probably needs a bump for this
    change to reach existing caches -- confirm.
    """
    return """
    static int fix_indices(ssize_t *start, ssize_t *stop, ssize_t *step,
                           int start_n, int stop_n, int step_n,
                           size_t len) {
        /* Signed copy: keeps every comparison below in signed arithmetic. */
        ssize_t slen = (ssize_t)len;
        if (step_n) *step = 1;
        if (*step == 0) {
            PyErr_SetString(PyExc_ValueError, "slice step cannot be zero");
            return -1;
        }
        if (start_n) *start = (*step < 0) ? slen-1 : 0;
        else {
            if (*start < 0) *start += slen;
            if (*start < 0) *start = (*step < 0) ? -1 : 0;
            if (*start >= slen) *start = (*step < 0) ? slen-1 : slen;
        }
        if (stop_n) *stop = (*step < 0) ? -1 : slen;
        else {
            if (*stop < 0) *stop += slen;
            if (*stop < 0) *stop = (*step < 0) ? -1 : 0;
            if (*stop >= slen) *stop = (*step < 0) ? slen-1 : slen;
        }
        if (*stop < *start && *step > 0)
            *stop = *start;
        return 0;
    }
    """
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
view_ndim = node.outputs[0].ndim inp_ndim = node.inputs[0].ndim
inp = inputs[0]
indices = inputs[1:] indices = inputs[1:]
# pad out the index list to the same dimension as the input
idx_list = self.idx_list + \
((slice(None),) * (inp_ndim - len(self.idx_list)))
sio = StringIO.StringIO() sio = StringIO.StringIO()
print >> sio, """ print >> sio, """
ssize_t %(name)s_starts[%(sz)s]; ssize_t starts[%(sz)s];
ssize_t %(name)s_stops[%(sz)s]; ssize_t stops[%(sz)s];
ssize_t %(name)s_steps[%(sz)s]; ssize_t steps[%(sz)s];
""" % dict(name=name, sz=len(self.idx_list)) ssize_t cur;
int err;
ndim = 0
for i, idx in enumerate(self.idx_list): if (%(inp)s->ga.nd != %(sz)s) {
if isinstance(idx, gof.Type): PyErr_SetString(PyExc_IndexError, "invalid index");
# Index by an input number %(fail)s
print >>sio, """ }
%(name)s_starts[%(i)s] = %(start)s; """ % dict(sz=len(idx_list), inp=inp, fail=sub['fail'])
%(name)s_steps[%(i)s] = %(step)s;
""" % dict(name=name, i=i, start=indices.pop(), step=0) def fix_idx(idx):
elif isinstance(idx, slice): if idx is None:
# index by a fixed slice return "0", 1
step = idx.step elif isinstance(idx, (numpy.integer, int)):
if step is None: return str(idx), 0
step = 1 elif isinstance(idx, gof.Type):
stop = idx.stop return indices.pop(), 0
if stop is None: else:
#TODO find what is needed assert 0, idx
raise NotImplementedError("This case is not yet implemented!")
for i, idx in enumerate(idx_list):
if isinstance(idx, slice):
start, start_n = fix_idx(idx.start)
stop, stop_n = fix_idx(idx.stop)
step, step_n = fix_idx(idx.step)
print >>sio, """ print >>sio, """
%(name)s_starts[%(i)s] = %(start)s; starts[%(i)s] = %(start)s;
%(name)s_stops[%(i)s] = %(stop)s; stops[%(i)s] = %(stop)s;
%(name)s_steps[%(i)s] = %(step)s; steps[%(i)s] = %(step)s;
""" % dict(i=i, name=name, start=idx.start, stop=stop, if (fix_indices(&starts[%(i)s], &stops[%(i)s], &steps[%(i)s],
step=step) %(start_n)s, %(stop_n)s, %(step_n)s,
ndim += 1 %(inp)s->ga.dimensions[%(i)s]) == -1) {
%(fail)s
}
""" % dict(i=i, start=start, stop=stop, step=step,
start_n=start_n, stop_n=stop_n, step_n=step_n,
fail=sub['fail'], inp=inp)
else: else:
# Index by a fixed number if isinstance(idx, gof.Type):
start = indices.pop()
elif isinstance(idx, (numpy.integer, int)):
start = idx
else:
assert 0, idx
print >>sio, """ print >>sio, """
%(name)s_starts[%(i)s] = %(start)s; cur = %(start)s;
%(name)s_steps[%(i)s] = %(step)s; if (cur < 0)
""" % dict(name=name, i=i, start=idx, step=0) cur += %(inp)s->ga.dimensions[%(i)s];
starts[%(i)s] = cur;
steps[%(i)s] = 0;
""" % dict(i=i, start=start, fail=sub['fail'], inp=inp)
print >>sio, """ print >>sio, """
if (%(out)s) { Py_XDECREF(%(out)s);
// Try to reuse the python object. %(out)s = new_GpuArray((PyObject *)&PyGpuArrayType, pygpu_default_context(), Py_None);
GpuArray_clear(&%(out)s->ga);
} else {
%(out)s = new_GpuArray((PyObject *)&PyGpuArrayType, pygpu_default_context(), Py_None);
}
if (!%(out)s) { %(fail)s } if (!%(out)s) { %(fail)s }
int %(name)s_err; if ((err = GpuArray_index(&%(out)s->ga, &%(inp)s->ga, starts, stops, steps))) {
%(name)s_err = GpuArray_index(&%(out)s->ga, &%(inp)s->ga,
%(name)s_starts, %(name)s_steps,
%(name)s_stops);
if (%(name)s_err != GA_NO_ERROR) {
Py_DECREF(%(out)s); %(out)s = NULL; Py_DECREF(%(out)s); %(out)s = NULL;
PyErr_SetString(PyExc_RuntimeError, "Error during index"); if (err == GA_VALUE_ERROR)
PyErr_SetString(PyExc_IndexError, "index out of bounds");
else
PyErr_SetString(PyExc_RuntimeError, "index failed");
%(fail)s %(fail)s
} }
""" % dict(name=name, fail=sub['fail'], inp=inputs[0], out=outputs[0]) """ % dict(name=name, fail=sub['fail'], inp=inp, out=outputs[0])
return sio.getvalue() return sio.getvalue()
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (3,)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论