提交 26496654 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make GpuSoftmax and GpuSoftmaxWithBias work with f16

上级 4a2e513e
......@@ -121,7 +121,7 @@ def inline_reduce_prod(N, buf, pos, count):
lambda a, b: "%s * %s" % (a, b))
@code_version((2,) + inline_reduce_max.code_version +
@code_version((3,) + inline_reduce_max.code_version +
inline_reduce_sum.code_version)
def inline_softmax(N, buf, buf2, threadPos, threadCount, dtype="float32"):
"""
......@@ -165,10 +165,10 @@ def inline_softmax(N, buf, buf2, threadPos, threadCount, dtype="float32"):
]
@code_version((1,))
def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
@code_version((2,))
def inline_reduce_fixed_shared(N, buf, x, stride_x, load_x, pos, count,
manner_fn, manner_init,
b='', stride_b='', dtype='float32'):
b='', stride_b='', load_b='', dtype='float32'):
"""Return C++ code for a function that reduces a contiguous buffer.
:param N: length of the buffer
......@@ -193,15 +193,15 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
"""
if b:
init = manner_init("%(x)s[%(pos)s * %(stride_x)s] +"
" %(b)s[%(pos)s * %(stride_b)s]" % locals())
init = manner_init("%(load_x)s(%(x)s[%(pos)s * %(stride_x)s]) +"
" %(load_b)s(%(b)s[%(pos)s * %(stride_b)s])" % locals())
loop_line = manner_fn("red",
manner_init("%(x)s[i * %(stride_x)s] + "
"%(b)s[i * %(stride_b)s]" %
manner_init("%(load_x)s(%(x)s[i * %(stride_x)s]) + "
"%(load_b)s(%(b)s[i * %(stride_b)s])" %
locals()))
else:
init = manner_init("%(x)s[%(pos)s * %(stride_x)s]" % locals())
loop_line = manner_fn("red", manner_init("%(x)s[i * %(stride_x)s]" %
init = manner_init("%(load_x)s(%(x)s[%(pos)s * %(stride_x)s])" % locals())
loop_line = manner_fn("red", manner_init("%(load_x)s(%(x)s[i * %(stride_x)s])" %
locals()))
loop_line2 = manner_fn("%s[%s]" % (buf, pos),
"%s[i]" % buf)
......@@ -248,20 +248,22 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
@code_version(inline_reduce_fixed_shared.code_version)
def inline_reduce_fixed_shared_max(N, buf, x, stride_x, pos, count,
b='', stride_b='', dtype='float32'):
return inline_reduce_fixed_shared(N, buf, x, stride_x, pos, count,
def inline_reduce_fixed_shared_max(N, buf, x, stride_x, load_x, pos, count,
b='', stride_b='', load_b='',
dtype='float32'):
return inline_reduce_fixed_shared(N, buf, x, stride_x, load_x, pos, count,
lambda a, b: "max(%s, %s)" % (a, b),
lambda a: a,
b, stride_b, dtype)
b, stride_b, load_b, dtype)
@code_version((1,) + inline_reduce_max.code_version +
@code_version((2,) + inline_reduce_max.code_version +
inline_reduce_sum.code_version)
def inline_softmax_fixed_shared(N, buf, x, stride_x,
sm, sm_stride,
def inline_softmax_fixed_shared(N, buf, x, stride_x, load_x,
sm, sm_stride, write_sm,
threadPos, threadCount,
b='', stride_b='', dtype="float32"):
b='', stride_b='', load_b='',
dtype="float32"):
"""
:param N: length of the buffer, atleast waprSize(32).
......@@ -286,16 +288,18 @@ def inline_softmax_fixed_shared(N, buf, x, stride_x,
"""
ret = [
# get max of buf (trashing all but buf[0])
inline_reduce_fixed_shared_max(N, buf, x, stride_x,
threadPos, threadCount, b, stride_b,
inline_reduce_fixed_shared_max(N, buf, x, stride_x, load_x,
threadPos, threadCount,
b, stride_b, load_b,
dtype),
'__syncthreads()',
('npy_%s row_max = ' + buf + '[0]') % dtype,
'__syncthreads()',
inline_reduce_fixed_shared(N, buf, x, stride_x, threadPos, threadCount,
inline_reduce_fixed_shared(N, buf, x, stride_x, load_x,
threadPos, threadCount,
lambda a, b: "%s + %s" % (a, b),
lambda a: "exp(%s - row_max)" % a,
b, stride_b, dtype),
b, stride_b, load_b, dtype),
'__syncthreads()',
('npy_%s row_sum = ' + buf + '[0]') % dtype,
'__syncthreads()',
......@@ -305,13 +309,14 @@ def inline_softmax_fixed_shared(N, buf, x, stride_x,
if b:
ret += [
"%(sm)s[tx * %(sm_stride)s] = "
" exp(%(x)s[tx * %(stride_x)s] +"
" %(b)s[tx * %(stride_b)s] - row_max)"
" / row_sum" % locals()]
" %(write_sm)s(exp(%(load_x)s(%(x)s[tx * %(stride_x)s]) +"
" %(load_b)s(%(b)s[tx * %(stride_b)s]) - row_max)"
" / row_sum)" % locals()]
else:
ret += [
"%(sm)s[tx * %(sm_stride)s] = "
"exp(%(x)s[tx * %(stride_x)s] - row_max) / row_sum" % locals()]
"%(write_sm)s(exp(%(load_x)s(%(x)s[tx * %(stride_x)s]) - row_max)"
" / row_sum)" % locals()]
ret += [
"}",
'__syncthreads()',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论