提交 cff861bf 作者: Arnaud Bergeron

Make GpuKernelBase use c_init_code_apply() and fix GpuEye to use the new API.

上级 6d851334
...@@ -52,6 +52,7 @@ class HideC(object): ...@@ -52,6 +52,7 @@ class HideC(object):
c_compile_args = __hide c_compile_args = __hide
c_no_compile_args = __hide c_no_compile_args = __hide
c_init_code = __hide c_init_code = __hide
c_init_code_apply = __hide
def c_code_cache_version(self): def c_code_cache_version(self):
return () return ()
...@@ -63,13 +64,13 @@ class HideC(object): ...@@ -63,13 +64,13 @@ class HideC(object):
class GpuKernelBase(object): class GpuKernelBase(object):
GpuKernelBase_version = 0 GpuKernelBase_version = 0
def c_kernel_code(self): def c_kernel_code(self, node):
""" """
Return the source code of the kernel. Return the source code of the kernel.
""" """
raise AttributeError("c_kernel_code", type(self)) raise AttributeError("c_kernel_code", type(self))
def c_kernel_params(self): def c_kernel_params(self, node):
""" """
Return the list of typecodes for kernel parameters. Return the list of typecodes for kernel parameters.
...@@ -83,7 +84,7 @@ class GpuKernelBase(object): ...@@ -83,7 +84,7 @@ class GpuKernelBase(object):
""" """
raise AttributeError("c_kernel_name", type(self)) raise AttributeError("c_kernel_name", type(self))
def c_kernel_flags(self): def c_kernel_flags(self, node):
""" """
Return a string representing the C flags for the kernel. Return a string representing the C flags for the kernel.
...@@ -95,11 +96,11 @@ class GpuKernelBase(object): ...@@ -95,11 +96,11 @@ class GpuKernelBase(object):
""" """
raise AttributeError("c_kernel_flags", type(self)) raise AttributeError("c_kernel_flags", type(self))
def c_kernel_codevar(self): def c_kernel_codevar(self, name):
return 'kcode_' + type(self).__name__ + '_' + hex(hash(self))[2:] return 'kcode_' + name
def c_kernel_obj(self): def c_kernel_obj(self, name):
return 'k_' + type(self).__name__ + '_' + hex(hash(self))[2:] return 'k_' + name
def _get_kernel_flags(self, *dtypes): def _get_kernel_flags(self, *dtypes):
dtypes = [numpy.dtype(d) for d in dtypes] dtypes = [numpy.dtype(d) for d in dtypes]
...@@ -113,35 +114,35 @@ class GpuKernelBase(object): ...@@ -113,35 +114,35 @@ class GpuKernelBase(object):
def c_headers(self): def c_headers(self):
return ['compyte/types.h'] return ['compyte/types.h']
def c_support_code(self): def c_support_code_apply(self, node, name):
kcode = self.c_kernel_code() kcode = self.c_kernel_code(node)
vname = self.c_kernel_codevar() vname = self.c_kernel_codevar(name)
kname = self.c_kernel_obj() kname = self.c_kernel_obj(name)
code = '\\n'.join(l for l in kcode.split('\n')) code = '\\n'.join(l for l in kcode.split('\n'))
return """static const char *%(vname)s = "%(code)s"; return """static const char *%(vname)s = "%(code)s";
static GpuKernel %(kname)s;""" % dict(vname=vname, kname=kname,code=code) static GpuKernel %(kname)s;""" % dict(vname=vname, kname=kname, code=code)
def c_init_code(self): def c_init_code_apply(self, node, name):
types = self.c_kernel_params() types = self.c_kernel_params(node)
numargs = len(types) numargs = len(types)
name = self.c_kernel_name() kname = self.c_kernel_name()
vname = self.c_kernel_codevar() vname = self.c_kernel_codevar(name)
kname = self.c_kernel_obj() oname = self.c_kernel_obj(name)
flags = self.c_kernel_flags() flags = self.c_kernel_flags(node)
# TODO: find a way to release the kernel once the module is unloaded # TODO: find a way to release the kernel once the module is unloaded
error_out = "" error_out = ""
if PY3: if PY3:
error_out = "NULL" error_out = "NULL"
return [""" return """
int types[%(numargs)u] = {%(types)s}; int types_%(name)s[%(numargs)u] = {%(types)s};
if (GpuKernel_init(&%(kname)s, pygpu_default_context()->ops, if (GpuKernel_init(&%(oname)s, pygpu_default_context()->ops,
pygpu_default_context()->ctx, 1, &%(vname)s, NULL, pygpu_default_context()->ctx, 1, &%(vname)s, NULL,
"%(name)s", %(numargs)s, types, %(flags)s) != GA_NO_ERROR) { "%(kname)s", %(numargs)s, types_%(name)s, %(flags)s) != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "Error initializing kernel"); PyErr_SetString(PyExc_RuntimeError, "Error initializing kernel");
return %(error_out)s; return %(error_out)s;
} }
""" % dict(types=','.join(types), numargs=numargs, kname=kname, name=name, """ % dict(types=','.join(types), numargs=numargs, kname=kname, oname=oname,
vname=vname, flags=flags, error_out=error_out)] vname=vname, flags=flags, error_out=error_out, name=name)
class HostFromGpu(Op): class HostFromGpu(Op):
...@@ -726,7 +727,7 @@ class GpuEye(GpuKernelBase, Op): ...@@ -726,7 +727,7 @@ class GpuEye(GpuKernelBase, Op):
def __hash__(self): def __hash__(self):
return hash(self.dtype) ^ hash(type(self)) return hash(self.dtype) ^ hash(type(self))
def c_kernel_code(self): def c_kernel_code(self, node):
return """ return """
KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ga_size nb = n < m ? n : m; ga_size nb = n < m ? n : m;
...@@ -735,13 +736,13 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -735,13 +736,13 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
} }
}""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype)) }""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype))
def c_kernel_params(self): def c_kernel_params(self, node):
return ["GA_BUFFER", "GA_SIZE", "GA_SIZE"] return ["GA_BUFFER", "GA_SIZE", "GA_SIZE"]
def c_kernel_name(self): def c_kernel_name(self):
return "k" return "k"
def c_kernel_flags(self): def c_kernel_flags(self, node):
return self._get_kernel_flags(self.dtype) return self._get_kernel_flags(self.dtype)
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
...@@ -750,7 +751,7 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -750,7 +751,7 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
fail = sub['fail'] fail = sub['fail']
typecode = pygpu.gpuarray.dtype_to_typecode(self.dtype) typecode = pygpu.gpuarray.dtype_to_typecode(self.dtype)
sync = bool(config.gpuarray.sync) sync = bool(config.gpuarray.sync)
kname = self.c_kernel_obj() kname = self.c_kernel_obj(name)
s = """ s = """
size_t dims[2] = {0, 0}; size_t dims[2] = {0, 0};
void *args[3]; void *args[3];
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论