提交 bbd6a10e authored 作者: Frederic's avatar Frederic

Faster GpuSoftmax for big row.

上级 f3fc1f0f
......@@ -183,9 +183,10 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
:note: buf should be in gpu shared memory, we access it many times.
"""
init = manner_init("%(x)s[tx * %(stride_x)s]" % locals())
loop_line = manner_fn("%s[%s]" % (buf, pos),
manner_init("%s[i * %s]" % (x, stride_x)))
init = manner_init("%(x)s[%(pos)s * %(stride_x)s]" % locals())
loop_line = manner_fn("red", manner_init("%s[i]" % x))
loop_line2 = manner_fn("%s[%s]" % (buf, pos),
"%s[i]" % buf)
r_16 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+16]" % (buf, pos))
r_8 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+8]" % (buf, pos))
r_4 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+4]" % (buf, pos))
......@@ -194,21 +195,23 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
return """
{
// This function trashes buf[1..warpSize], leaving the reduction result in buf[0].
for (int tx = %(pos)s; tx<warpSize; tx += %(count)s){
%(buf)s[tx] = %(init)s;
// This function trashes buf[1..n_threads], leaving the reduction result in buf[0].
float red = %(init)s;
#pragma unroll 16
for (int i = %(pos)s + %(count)s; i<%(N)s; i += %(count)s){
red = %(loop_line)s;
}
buf[%(pos)s] = red;
__syncthreads();
if (%(pos)s < warpSize)
{
for (int i = %(pos)s + warpSize; i < %(N)s; i += warpSize)
for (int i = %(pos)s + warpSize; i < %(count)s; i += warpSize)
{
%(buf)s[%(pos)s] = %(loop_line)s;
%(buf)s[%(pos)s] = %(loop_line2)s;
}
if (%(pos)s < 16)
{
//reduce so that %(pos)s 0 has the sum of everything
//reduce so that %(pos)s 0 has the reduction of everything
if(%(pos)s + 16 < %(N)s)
%(buf)s[%(pos)s] = %(r_16)s;
if(%(pos)s + 8 < %(N)s)
......
......@@ -351,7 +351,7 @@ class GpuSoftmax (GpuOp):
return shape
def c_code_cache_version(self):
return (8,) + inline_softmax.code_version
return (9,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub):
x, = inp
......@@ -409,7 +409,7 @@ class GpuSoftmax (GpuOp):
<<<
n_blocks,
n_threads,
32 * sizeof(float)
n_threads * sizeof(float)
>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论