提交 c2c34032 authored 作者: Frederic's avatar Frederic

pep8

上级 97fff534
""" Helper routines for generating gpu kernels for nvcc. """ Helper routines for generating gpu kernels for nvcc.
""" """
def nvcc_kernel(name, params, body): def nvcc_kernel(name, params, body):
"""Return the c code of a kernel function. """Return the c code of a kernel function.
:param params: the parameters to the function as one or more strings :param params: the parameters to the function as one or more strings
:param body: the [nested] list of statements for the body of the function. These will be :param body: the [nested] list of statements for the body of the
separated by ';' characters. function. These will be separated by ';' characters.
""" """
paramstr = ', '.join(params) paramstr = ', '.join(params)
def flatbody(): def flatbody():
for b in body: for b in body:
if isinstance(b, (list, tuple)): if isinstance(b, (list, tuple)):
...@@ -21,12 +25,14 @@ def nvcc_kernel(name, params, body): ...@@ -21,12 +25,14 @@ def nvcc_kernel(name, params, body):
{ {
%(bodystr)s; %(bodystr)s;
} }
""" %locals() """ % locals()
def code_version(version): def code_version(version):
"""decorator to support version-based cache mechanism""" """decorator to support version-based cache mechanism"""
if not isinstance(version, tuple): if not isinstance(version, tuple):
raise TypeError('version must be tuple', version) raise TypeError('version must be tuple', version)
def deco(f): def deco(f):
f.code_version = version f.code_version = version
return f return f
...@@ -34,31 +40,33 @@ def code_version(version): ...@@ -34,31 +40,33 @@ def code_version(version):
UNVERSIONED = () UNVERSIONED = ()
@code_version((1,)) @code_version((1,))
def inline_reduce(N, buf, pos, count, manner_fn): def inline_reduce(N, buf, pos, count, manner_fn):
""" """Return C++ code for a function that reduces a contiguous buffer.
Return C++ code for a function that reduces a contiguous buffer.
:param N: length of the buffer :param N: length of the buffer
:param buf: buffer pointer :param buf: buffer pointer
:param pos: index of executing thread :param pos: index of executing thread
:param count: number of executing threads :param count: number of executing threads
:param manner_fn: a function that accepts strings of arguments a and b, and returns c code
for their reduction. (Example: return "%(a)s + %(b)s" for a sum reduction). :param manner_fn: a function that accepts strings of arguments a
and b, and returns c code for their reduction. (Example:
return "%(a)s + %(b)s" for a sum reduction).
:postcondition: :postcondition:
This function leaves the answer in position 0 of the buffer. The rest of the buffer is This function leaves the answer in position 0 of the buffer. The
trashed by this function. rest of the buffer is trashed by this function.
:note: buf should be in gpu shared memory, we access it many times. :note: buf should be in gpu shared memory, we access it many times.
""" """
loop_line = manner_fn("%s[%s]"%(buf,pos), "%s[i]" %(buf)) loop_line = manner_fn("%s[%s]" % (buf, pos), "%s[i]" % (buf))
r_16 = manner_fn("%s[%s]" %(buf, pos), "%s[%s+16]" %(buf, pos)) r_16 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+16]" % (buf, pos))
r_8 = manner_fn("%s[%s]" %(buf, pos), "%s[%s+8]" %(buf, pos)) r_8 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+8]" % (buf, pos))
r_4 = manner_fn("%s[%s]" %(buf, pos), "%s[%s+4]" %(buf, pos)) r_4 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+4]" % (buf, pos))
r_2 = manner_fn("%s[%s]" %(buf, pos), "%s[%s+2]" %(buf, pos)) r_2 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+2]" % (buf, pos))
r_1 = manner_fn("%s[%s]" %(buf, pos), "%s[%s+1]" %(buf, pos)) r_1 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+1]" % (buf, pos))
return """ return """
{ {
...@@ -88,24 +96,33 @@ def inline_reduce(N, buf, pos, count, manner_fn): ...@@ -88,24 +96,33 @@ def inline_reduce(N, buf, pos, count, manner_fn):
} }
""" % locals() """ % locals()
@code_version(inline_reduce.code_version)
def inline_reduce_max(N, buf, pos, count):
    """Return C++ code that reduces a buffer by taking its maximum.

    Thin wrapper around `inline_reduce`: N, buf, pos and count are
    forwarded unchanged, and pairs of elements are combined with the C
    expression ``max(a, b)``.
    """
    return inline_reduce(N, buf, pos, count,
                         lambda a, b: "max(%s, %s)" % (a, b))
@code_version(inline_reduce.code_version)
def inline_reduce_sum(N, buf, pos, count):
    """Return C++ code that reduces a buffer by summing its elements.

    Thin wrapper around `inline_reduce`: N, buf, pos and count are
    forwarded unchanged, and pairs of elements are combined with the C
    expression ``a + b``.
    """
    return inline_reduce(N, buf, pos, count,
                         lambda a, b: "%s + %s" % (a, b))
@code_version(inline_reduce.code_version)
def inline_reduce_min(N, buf, pos, count):
    """Return C++ code that reduces a buffer by taking its minimum.

    Thin wrapper around `inline_reduce`: N, buf, pos and count are
    forwarded unchanged, and pairs of elements are combined with the C
    expression ``min(a, b)``.
    """
    return inline_reduce(N, buf, pos, count,
                         lambda a, b: "min(%s, %s)" % (a, b))
@code_version(inline_reduce.code_version)
def inline_reduce_prod(N, buf, pos, count):
    """Return C++ code that reduces a buffer by multiplying its elements.

    Thin wrapper around `inline_reduce`: N, buf, pos and count are
    forwarded unchanged, and pairs of elements are combined with the C
    expression ``a * b``.
    """
    return inline_reduce(N, buf, pos, count,
                         lambda a, b: "%s * %s" % (a, b))
@code_version((2,) + inline_reduce_max.code_version + inline_reduce_sum.code_version) @code_version((2,) + inline_reduce_max.code_version +
inline_reduce_sum.code_version)
def inline_softmax(N, buf, buf2, threadPos, threadCount): def inline_softmax(N, buf, buf2, threadPos, threadCount):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论