提交 bbd6a10e authored 作者: Frederic's avatar Frederic

Faster GpuSoftmax for big row.

上级 f3fc1f0f
...@@ -183,9 +183,10 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count, ...@@ -183,9 +183,10 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
:note: buf should be in gpu shared memory, we access it many times. :note: buf should be in gpu shared memory, we access it many times.
""" """
init = manner_init("%(x)s[tx * %(stride_x)s]" % locals()) init = manner_init("%(x)s[%(pos)s * %(stride_x)s]" % locals())
loop_line = manner_fn("%s[%s]" % (buf, pos), loop_line = manner_fn("red", manner_init("%s[i]" % x))
manner_init("%s[i * %s]" % (x, stride_x))) loop_line2 = manner_fn("%s[%s]" % (buf, pos),
"%s[i]" % buf)
r_16 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+16]" % (buf, pos)) r_16 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+16]" % (buf, pos))
r_8 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+8]" % (buf, pos)) r_8 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+8]" % (buf, pos))
r_4 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+4]" % (buf, pos)) r_4 = manner_fn("%s[%s]" % (buf, pos), "%s[%s+4]" % (buf, pos))
...@@ -194,21 +195,23 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count, ...@@ -194,21 +195,23 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
return """ return """
{ {
// This function trashes buf[1..warpSize], leaving the reduction result in buf[0]. // This function trashes buf[1..n_threads], leaving the reduction result in buf[0].
float red = %(init)s;
for (int tx = %(pos)s; tx<warpSize; tx += %(count)s){ #pragma unroll 16
%(buf)s[tx] = %(init)s; for (int i = %(pos)s + %(count)s; i<%(N)s; i += %(count)s){
red = %(loop_line)s;
} }
buf[%(pos)s] = red;
__syncthreads(); __syncthreads();
if (%(pos)s < warpSize) if (%(pos)s < warpSize)
{ {
for (int i = %(pos)s + warpSize; i < %(N)s; i += warpSize) for (int i = %(pos)s + warpSize; i < %(count)s; i += warpSize)
{ {
%(buf)s[%(pos)s] = %(loop_line)s; %(buf)s[%(pos)s] = %(loop_line2)s;
} }
if (%(pos)s < 16) if (%(pos)s < 16)
{ {
//reduce so that %(pos)s 0 has the sum of everything //reduce so that %(pos)s 0 has the reduction of everything
if(%(pos)s + 16 < %(N)s) if(%(pos)s + 16 < %(N)s)
%(buf)s[%(pos)s] = %(r_16)s; %(buf)s[%(pos)s] = %(r_16)s;
if(%(pos)s + 8 < %(N)s) if(%(pos)s + 8 < %(N)s)
......
...@@ -351,7 +351,7 @@ class GpuSoftmax (GpuOp): ...@@ -351,7 +351,7 @@ class GpuSoftmax (GpuOp):
return shape return shape
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) + inline_softmax.code_version return (9,) + inline_softmax.code_version
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
x, = inp x, = inp
...@@ -409,7 +409,7 @@ class GpuSoftmax (GpuOp): ...@@ -409,7 +409,7 @@ class GpuSoftmax (GpuOp):
<<< <<<
n_blocks, n_blocks,
n_threads, n_threads,
32 * sizeof(float) n_threads * sizeof(float)
>>>( >>>(
CudaNdarray_HOST_DIMS(%(x)s)[0], CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1], CudaNdarray_HOST_DIMS(%(x)s)[1],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论