提交 2ea6dbac authored 作者: Chiheb Trabelsi's avatar Chiheb Trabelsi

kernet_codegen.py has been modified in order to respect the flake8 style.

kernel_codegen.py do not contain long lines.
上级 e4ebf518
......@@ -76,7 +76,7 @@ def inline_reduce(N, buf, pos, count, manner_fn):
rest of the buffer is trashed by this function.
Notes
-----
-----
buf should be in gpu shared memory, we access it many times.
"""
......@@ -167,29 +167,26 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount):
We use __i as an int variable in a loop.
"""
return [
# get max of buf (trashing all but buf[0])
inline_reduce_max(N, buf, threadPos, threadCount),
'__syncthreads()',
'float row_max = ' + buf + '[0]',
'__syncthreads()',
'for(int __i=' + threadPos + '; __i<' + N +
'; __i+=' + threadCount + '){',
buf + '[__i] = exp(' + buf2 + '[__i] - row_max)',
buf2 + '[__i] = ' + buf + '[__i]',
'}',
'__syncthreads()',
inline_reduce_sum(N, buf, threadPos, threadCount),
'__syncthreads()',
'float row_sum = ' + buf + '[0]',
'__syncthreads()',
# divide each exp() result by the sum to complete the job.
'for(int __i=' + threadPos + '; __i<' + N +
'; __i+=' + threadCount + '){',
buf + '[__i] = ' + buf2 + '[__i] / row_sum',
'}',
'__syncthreads()',
]
return [ # get max of buf (trashing all but buf[0])
inline_reduce_max(N, buf, threadPos, threadCount),
'__syncthreads()',
'float row_max = ' + buf + '[0]',
'__syncthreads()',
'for(int __i=' + threadPos + '; __i<' + N + '; __i+=' +
threadCount + '){',
buf + '[__i] = exp(' + buf2 + '[__i] - row_max)',
buf2 + '[__i] = ' + buf + '[__i]', '}',
'__syncthreads()',
inline_reduce_sum(N, buf, threadPos, threadCount),
'__syncthreads()',
'float row_sum = ' + buf + '[0]',
'__syncthreads()',
# divide each exp() result by the sum to complete the job.
'for(int __i=' + threadPos + '; __i<' + N +
'; __i+=' + threadCount + '){',
buf + '[__i] = ' + buf2 + '[__i] / row_sum', '}',
'__syncthreads()',
]
@code_version((1,))
......@@ -241,8 +238,7 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
init = manner_init("%(x)s[%(pos)s * %(stride_x)s]" % locals())
loop_line = manner_fn("red", manner_init("%(x)s[i * %(stride_x)s]" %
locals()))
loop_line2 = manner_fn("%s[%s]" % (buf, pos),
"%s[i]" % buf)
loop_line2 = manner_fn("%s[%s]" % (buf, pos), "%s[i]" % buf)
r_16 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+16]" % (buf, pos))
r_8 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+8]" % (buf, pos))
r_4 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+4]" % (buf, pos))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论