提交 6e74a08b 作者: Arnaud Bergeron

Add c_code() to GpuAlloc().

上级 4c2e0344
@@ -478,4 +478,50 @@ class GpuAlloc(HideC, Alloc):
if config.gpuarray.sync: if config.gpuarray.sync:
out[0].sync() out[0].sync()
def c_code(self, node, name, inp, out, sub):
    """Return the C source implementing this op.

    The generated code reads each requested dimension from the first
    element of the corresponding shape input, keeps the existing output
    array when its rank and every dimension already match (otherwise it
    allocates a new one with ``pygpu_empty``), and then writes the value
    input into the output with ``GpuArray_setarray``.  When
    ``config.gpuarray.sync`` is set, a ``GpuArray_sync`` call on the
    output is appended.

    :param node: the node being compiled (not used by this method).
    :param name: name used to prefix the local C shape array.
    :param inp: C variable names of the inputs: the value first, followed
        by one variable per output dimension.
    :param out: one-element sequence holding the output C variable name.
    :param sub: substitution dict; ``sub['fail']`` is inserted wherever
        the generated code must bail out.
    :return: the C code as a string.
    """
    vv = inp[0]
    shape_vars = inp[1:]
    ndim = len(shape_vars)
    zz, = out

    # A zero-length array is not valid standard C/C++, so the local shape
    # buffer is declared with at least one slot even when there are no
    # shape inputs.  The generated loop below still only indexes the first
    # ``ndim`` entries.
    code = """
int i;
size_t %(name)s_shape[%(ndim)s];
""" % dict(name=name, ndim=max(ndim, 1))

    # One assignment per dimension: take element 0 of each shape input.
    for i, shp_i in enumerate(shape_vars):
        code += """
%(name)s_shape[%(i)s] = ((dtype_%(shp_i)s *)PyArray_DATA(%(shp_i)s))[0];
""" % dict(name=name, i=i, shp_i=shp_i)

    # Reuse the output buffer when possible, then copy the value into it.
    code += """
int need_new_out = (NULL == %(zz)s || %(zz)s->ga.nd != %(ndim)s);
if (!need_new_out)
for (i = 0; i < %(ndim)s; i++)
need_new_out |= %(zz)s->ga.dimensions[i] != %(name)s_shape[i];
if (need_new_out) {
Py_XDECREF(%(zz)s);
%(zz)s = pygpu_empty(%(ndim)s, %(name)s_shape,
%(vv)s->ga.typecode, GA_C_ORDER,
pygpu_default_context(), Py_None);
if (!%(zz)s) {
%(fail)s
}
}
if (GpuArray_setarray(&%(zz)s->ga, &%(vv)s->ga) != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "setarray failed");
%(fail)s
}
""" % dict(name=name, ndim=ndim, zz=zz, vv=vv, fail=sub['fail'])

    if config.gpuarray.sync:
        code += "GpuArray_sync(&%(zz)s->ga);" % dict(zz=zz)

    return code
def c_code_cache_version(self):
    """Return the version tuple for the C code generated by this op."""
    # The tuple has to be returned explicitly; as a bare expression it is
    # discarded and the method returns None.
    return (0,)
# Module-level instance of the GpuAlloc op.
gpu_alloc = GpuAlloc()
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论