提交 e0701ccd authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add c_code to GpuFromHost too.

上级 c2a36918
......@@ -44,10 +44,7 @@ class HostFromGpu(Op):
z[0] = numpy.asarray(x)
def c_code(self, node, name, inputs, outputs, sub):
inp = inputs[0]
out = outputs[0]
fail = sub['fail']
return """{
return """
GpuArray %(name)s_ga_s;
GpuArray *%(name)s_ga;
int %(name)serr;
......@@ -80,7 +77,8 @@ class HostFromGpu(Op):
if (%(name)serr != GA_NO_ERROR) {
%(fail)s
}
}""" % {'name': name, 'fail': sub['fail'], 'inp': inp, 'out': out}
""" % {'name': name, 'fail': sub['fail'], 'inp': inputs[0],
'out': outputs[0]}
def grad(self, inputs, grads):
gz, = grads
......@@ -133,6 +131,40 @@ class GpuFromHost(Op):
def infer_shape(self, node, xshp):
return xshp
def c_code(self, node, name, inputs, outputs, sub):
type = node.outputs[0].type
return """
PyArrayObject *%(name)s_tmp;
int %(name)serr;
%(name)s_tmp = PyArray_GETCONTIGUOUS(%(inp)s);
if (%(name)s_tmp == NULL) {
%(fail)s
}
Py_DECREF(%(inp)s);
%(inp)s = %(name)s_tmp;
%(out)s = new_GpuArray((PyObject *)&GpuArrayType);
if (%(out)s == NULL) {
%(fail)s
}
%(name)serr = GpuArray_empty(&%(out)s->ga, compyte_get_ops("%(kind)s"),
(void *)%(ctx)s, %(typecode)s,
PyArray_NDIM(%(inp)s),
(size_t *)PyArray_DIMS(%(inp)s),
GA_C_ORDER);
if (%(name)serr != GA_NO_ERROR) {
Py_DECREF(%(out)s);
%(fail)s
}
%(name)serr = GpuArray_write(&%(out)s->ga, PyArray_DATA(%(inp)s),
PyArray_NBYTES(%(inp)s));
if (%(name)serr != GA_NO_ERROR) {
Py_DECREF(%(out)s);
%(fail)s
}
""" % {'name': name, 'kind': type.kind, 'ctx': hex(type.context),
'inp': inputs[0], 'out': outputs[0], 'fail': sub['fail'],
'typecode': type.typecode}
gpu_from_host = GpuFromHost()
......@@ -157,6 +189,8 @@ class GpuFromCuda(Op):
base = x
while hasattr(base, 'base') and base.base is not None:
base = base.base
# TODO: I know how to do this in C, but I don't know about python.
# Is perform() actually required to work?
raise NotImplementedError("How are we going to get a gpudata pointer from here")
x[0] = gpuarray.from_gpudata(b, 0, x.dtype, x.shape,
base=base, kind=globals.kind,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论