提交 11d58402 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Un-bump code version and document new parameters.

上级 afb65d0b
......@@ -121,7 +121,7 @@ def inline_reduce_prod(N, buf, pos, count):
lambda a, b: "%s * %s" % (a, b))
@code_version((3,) + inline_reduce_max.code_version +
@code_version((2,) + inline_reduce_max.code_version +
inline_reduce_sum.code_version)
def inline_softmax(N, buf, buf2, threadPos, threadCount, dtype="float32"):
"""
......@@ -173,10 +173,14 @@ def inline_reduce_fixed_shared(N, buf, x, stride_x, load_x, pos, count,
:param N: length of the buffer
:param buf: buffer pointer of size warpSize * sizeof(dtype)
:param x: input data
:param stride_x: input data stride
:param load_x: wrapper to read from x
:param pos: index of executing thread
:param count: number of executing threads
:param b: Optional, pointer to the bias
:param stride_b: Optional, the stride of b if b is provided
:param load_b: Optional, wrapper to read from b if b is provided
:param dtype: Optional, the dtype of the output
:param manner_fn: a function that accepts strings of arguments a
......@@ -270,12 +274,15 @@ def inline_softmax_fixed_shared(N, buf, x, stride_x, load_x,
:param buf: a shared memory buffer of size warpSize * sizeof(dtype)
:param x: a ptr to the gpu memory where the row is stored
:param stride_x: the stride between each element in x
:param load_x: wrapper to read from x
:param sm: a ptr to the gpu memory to store the result
:param sm_stride: the stride between eash sm element
:param write_sm: wrapper before writing to sm
:param threadPos: index of executing thread
:param threadCount: number of executing threads
:param b: Optional, pointer to the bias
:param stride_b: Optional, the stride of b if b is provided
:param load_b: Optional, wrapper to read from b if b is provided
:param dtype: Optional, the dtype of the softmax's output if not float32
:Precondition: buf is empty
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论