提交 8f21a69a authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

Fix some compilation crash of the new GpuSubtensor

上级 d24e99ff
import StringIO
import numpy
import theano
from theano import tensor
from theano import tensor, gof
from theano.tensor.subtensor import Subtensor, get_idx_list
from theano.gof.python25 import all, any
......@@ -21,7 +23,7 @@ class GpuSubtensor(HideC, Subtensor):
otype = GpuArrayType(dtype=rval.outputs[0].type.dtype,
broadcastable=rval.outputs[0].type.broadcastable)
x = as_gpuarray_variable(x)
return Apply(self, [x] + rval.inputs[1:], [otype()])
return gof.Apply(self, [x] + rval.inputs[1:], [otype()])
def perform(self, node, inputs, out_):
out, = out_
......@@ -42,15 +44,16 @@ class GpuSubtensor(HideC, Subtensor):
view_ndim = node.outputs[0].ndim
indices = inputs[1:]
sio = StringIO.StringIO("""
sio = StringIO.StringIO()
print >> sio, """
ssize_t %(name)s_starts[%(sz)s];
ssize_t %(name)s_stops[%(sz)s];
ssize_t %(name)s_steps[%(sz)s];
""" % dict(name=name, sz=len(self.idx_list)))
""" % dict(name=name, sz=len(self.idx_list))
ndim = 0
for i, idx in enumerate(self.idx_list):
if isinstance(idx, Type):
if isinstance(idx, gof.Type):
# Index by an input number
print >>sio, """
%(name)s_starts[%(i)s] = %(start)s;
......@@ -58,12 +61,19 @@ class GpuSubtensor(HideC, Subtensor):
""" % dict(name=name, i=i, start=indices.pop(), step=0)
elif isinstance(idx, slice):
# index by a fixed slice
step = idx.step
if step is None:
step = 1
stop = idx.stop
if stop is None:
#TODO find what is needed
raise NotImplementedError("This case is not yet implemented!")
print >>sio, """
%(name)s_starts[%(i)s] = %(start)s;
%(name)s_stops[%(i)s] = %(stop)s;
%(name)s_steps[%(i)s] = %(step)s;
""" % dict(i=i, name=name, start=idx.start, stop=idx.stop,
step=idx.step)
""" % dict(i=i, name=name, start=idx.start, stop=stop,
step=step)
ndim += 1
else:
# Index by a fixed number
......@@ -77,13 +87,13 @@ class GpuSubtensor(HideC, Subtensor):
// Try to reuse the python object.
GpuArray_clear(&%(out)s->ga);
} else {
%(out)s = new_GpuArray((PyObject *)&PyGpuArrayType, pygpu_default_context, Py_None);
%(out)s = new_GpuArray((PyObject *)&PyGpuArrayType, pygpu_default_context(), Py_None);
}
if (!%(out)s) { %(fail)s }
int %(name)s_err;
%(name)s_err = GpuArray_index(&%(out)s->ga, &%(inp)s->ga,
%(name)s_starts, %(name)s_steps,
%(name)s_stops)
%(name)s_stops);
if (%(name)s_err != GA_NO_ERROR) {
Py_DECREF(%(out)s); %(out)s = NULL;
PyErr_SetString(PyExc_RuntimeError, "Error during index");
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论