提交 bebbb962 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Some fixes to CGpuKernelBase.

上级 7509e8a5
...@@ -224,19 +224,32 @@ class Kernel(object): ...@@ -224,19 +224,32 @@ class Kernel(object):
if self.flags.get('have_complex', False): if self.flags.get('have_complex', False):
res.append('GA_USE_COMPLEX') res.append('GA_USE_COMPLEX')
if self.flags.get('have_half', False): if self.flags.get('have_half', False):
res.append('GA_USE_SMALL') res.append('GA_USE_HALF')
return '|'.join(res) return '|'.join(res)
def _get_py_flags(self):
res = dict(self.flags)
cflags = res.pop('cflags', '')
for fl in cflags.split('|'):
fl = fl.strip()
if fl == 'GA_USE_CLUDA':
res['cluda'] = True
if fl == 'GA_USE_DOUBLE':
res['have_double'] = True
if fl == 'GA_USE_SMALL':
res['have_small'] = True
if fl == 'GA_USE_COMPLEX':
res['have_complex'] = True
if fl == 'GA_USE_HALF':
res['have_half'] = True
return res
def _get_c_types(self): def _get_c_types(self):
if not self.flags.get('ctypes', False): def m(t):
def m(t): if t == gpuarray.GpuArray:
if t == gpuarray.GpuArray: return "GA_BUFFER"
return "GA_BUFFER" else:
else: return str(gpuarray.dtype_to_typecode(t))
return str(gpuarray.dtype_to_typecode(t))
else:
def m(t):
return t
return ', '.join(m(t) for t in self.params) return ', '.join(m(t) for t in self.params)
...@@ -267,7 +280,7 @@ class GpuKernelBase(object): ...@@ -267,7 +280,7 @@ class GpuKernelBase(object):
def _generate_kernel_bin(self, k, ctx): def _generate_kernel_bin(self, k, ctx):
gk = gpuarray.GpuKernel(k.code, k.name, k.params, context=ctx, gk = gpuarray.GpuKernel(k.code, k.name, k.params, context=ctx,
**k.flags) **k._get_py_flags())
bin = gk._binary bin = gk._binary
bcode = ','.join(hex(c) for c in iterbytes(bin)) bcode = ','.join(hex(c) for c in iterbytes(bin))
return ("""static const char %(bname)s[] = { %(bcode)s };""" % return ("""static const char %(bname)s[] = { %(bcode)s };""" %
...@@ -377,6 +390,17 @@ def forward_string_meth(name): ...@@ -377,6 +390,17 @@ def forward_string_meth(name):
return f return f
def get_dtype(s):
if s == '*':
return gpuarray.GpuArray
if s == 'size':
return gpuarray.SIZE
if s == 'ssize':
return gpuarray.SSIZE
else:
return numpy.dtype(s)
class CGpuKernelBase(COp, GpuKernelBase): class CGpuKernelBase(COp, GpuKernelBase):
""" """
Class to combine GpuKernelBase and COp. Class to combine GpuKernelBase and COp.
...@@ -396,26 +420,28 @@ class CGpuKernelBase(COp, GpuKernelBase): ...@@ -396,26 +420,28 @@ class CGpuKernelBase(COp, GpuKernelBase):
c_cleanup_code_struct = forward_string_meth('c_cleanup_code_struct') c_cleanup_code_struct = forward_string_meth('c_cleanup_code_struct')
def _type_macros(self, node): def _type_macros(self, node):
define_template = "#define %s %s" define_template = "#define %s %s\n"
undef_template = "#undef %s" undef_template = "#undef %s\n"
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
for i, v in enumerate(node.inputs): for i, v in enumerate(node.inputs):
macro_name = "DTYPE_i%d" % (i,) if isinstance(v.type, GpuArrayType):
macro_value = pygpu.gpuarray.dtype_to_ctype(v.dtype) macro_name = "DTYPE_i%d" % (i,)
define_macros.append( macro_value = pygpu.gpuarray.dtype_to_ctype(v.dtype)
define_template % define_macros.append(
(macro_name, macro_value)) define_template %
undef_macros.append(undef_template % macro_name) (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
for i, v in enumerate(node.outputs): for i, v in enumerate(node.outputs):
macro_name = "DTYPE_o%d" % (i,) if isinstance(v.type, GpuArrayType):
macro_value = pygpu.gpuarray.dtype_to_ctype(v.dtype) macro_name = "DTYPE_o%d" % (i,)
define_macros.append( macro_value = pygpu.gpuarray.dtype_to_ctype(v.dtype)
define_template % define_macros.append(
(macro_name, macro_value)) define_template %
undef_macros.append(undef_template % macro_name) (macro_name, macro_value))
undef_macros.append(undef_template % macro_name)
return '\n'.join(define_macros), '\n'.join(undef_macros) return ''.join(define_macros), ''.join(undef_macros)
def gpu_kernels(self, node, name): def gpu_kernels(self, node, name):
if hasattr(self, '_cached_kernels'): if hasattr(self, '_cached_kernels'):
...@@ -436,12 +462,11 @@ class CGpuKernelBase(COp, GpuKernelBase): ...@@ -436,12 +462,11 @@ class CGpuKernelBase(COp, GpuKernelBase):
if len(splt2) != 3: if len(splt2) != 3:
raise ValueError("Bad kernel spec: %s" % (kspec,)) raise ValueError("Bad kernel spec: %s" % (kspec,))
kname = splt2[0].strip() kname = splt2[0].strip()
ktypes = [s.strip() for s in splt2[1].split(',')] ktypes = [get_dtype(s.strip()) for s in splt2[1].split(',')]
kflags = splt2[2].strip() kflags = splt2[2].strip()
kcode = def_macros + '\n' + kcode + '\n' + undef_macros kcode = def_macros + '\n' + kcode + '\n' + undef_macros
res.append(Kernel(kcode, ktypes, kname, res.append(Kernel(kcode, ktypes, kname,
flags=dict(ctypes=True, cluda=True, flags=dict(cluda=True, cflags=kflags)))
cflags=kflags)))
n += 2 n += 2
self._cached_kernels = res self._cached_kernels = res
return res return res
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论