Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9eb54e01
提交
9eb54e01
authored
6月 06, 2017
作者:
xiaoqie
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Port Softmax to OpenCL
上级
608e9aef
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
387 行增加
和
192 行删除
+387
-192
nnet.py
theano/gpuarray/nnet.py
+387
-192
没有找到文件。
theano/gpuarray/nnet.py
浏览文件 @
9eb54e01
...
...
@@ -65,41 +65,47 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
type_y_idx
=
gpuarray
.
dtype_to_ctype
(
dtype_y_idx
)
kname
=
"k_xent_sm_1hot_bias"
k_var
=
"k_xent_sm_1hot_bias_"
+
nodename
f
=
''
if
dtype_x
==
'float64'
else
'f'
if
node
.
inputs
[
0
]
.
type
.
context
.
kind
!=
b
'cuda'
:
f
=
''
else
:
f
=
''
if
dtype_x
==
'float64'
else
'f'
params
=
[
gpuarray
.
SIZE
,
gpuarray
.
SIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
]
sio
=
StringIO
()
print
(
"""
KERNEL void
%(kname)
s(const ga_size M, const ga_size N,
const
%(type_x)
s* x_data, const ga_size offset_x,
const ga_ssize xs0, const ga_ssize xs1,
const
%(type_b)
s* b, const ga_size offset_b,
const ga_ssize bs0,
const
%(type_y_idx)
s* y_idx_data, const ga_size offset_y_idx,
const ga_ssize y_idxs0,
%(type_x)
s* nll_data, const ga_size offset_nll,
const ga_ssize nlls0,
%(type_x)
s* sm_data, const ga_size offset_sm,
const ga_ssize sms0, const ga_ssize sms1,
%(type_y_idx)
s* am_data, const ga_size offset_am,
const ga_ssize ams0)
GLOBAL_MEM const
%(type_x)
s* x_data, const ga_size offset_x, const ga_ssize xs0, const ga_ssize xs1,
GLOBAL_MEM const
%(type_b)
s* b, const ga_size offset_b, const ga_ssize bs0,
GLOBAL_MEM const
%(type_y_idx)
s* y_idx_data, const ga_size offset_y_idx, const ga_ssize y_idxs0,
GLOBAL_MEM
%(type_x)
s* nll_data, const ga_size offset_nll, const ga_ssize nlls0,
GLOBAL_MEM
%(type_x)
s* sm_data, const ga_size offset_sm, const ga_ssize sms0, const ga_ssize sms1,
GLOBAL_MEM
%(type_y_idx)
s* am_data, const ga_size offset_am, const ga_ssize ams0 GA_DECL_SHARED_PARAM(
%(work_x)
s, per_thread_values))
{
x_data = (
const
%(type_x)
s *)(((
char *)x_data)+offset_x);
b = (
const
%(type_b)
s *)(((
char *)b)+offset_b);
y_idx_data = (
const
%(type_y_idx)
s *)(((
char *)y_idx_data)+offset_y_idx);
nll_data = (
%(type_x)
s *)(((
char *)nll_data)+offset_nll);
sm_data = (
%(type_x)
s *)(((
char *)sm_data)+offset_sm);
am_data = (
%(type_y_idx)
s *)(((
char *)am_data)+offset_am);
x_data = (
GLOBAL_MEM const
%(type_x)
s *)(((GLOBAL_MEM
char *)x_data)+offset_x);
b = (
GLOBAL_MEM const
%(type_b)
s *)(((GLOBAL_MEM
char *)b)+offset_b);
y_idx_data = (
GLOBAL_MEM const
%(type_y_idx)
s *)(((GLOBAL_MEM
char *)y_idx_data)+offset_y_idx);
nll_data = (
GLOBAL_MEM
%(type_x)
s *)(((GLOBAL_MEM
char *)nll_data)+offset_nll);
sm_data = (
GLOBAL_MEM
%(type_x)
s *)(((GLOBAL_MEM
char *)sm_data)+offset_sm);
am_data = (
GLOBAL_MEM
%(type_y_idx)
s *)(((GLOBAL_MEM
char *)am_data)+offset_am);
for (
int row = blockIdx.x; row < M; row += gridDim.x
){
for (
ga_int row = GID_0; row < M; row += GDIM_0
){
const
%(type_x)
s* x = x_data + xs0 * row;
%(type_x)
s* sm = sm_data + sms0 * row;
GLOBAL_MEM
const
%(type_x)
s* x = x_data + xs0 * row;
GLOBAL_MEM
%(type_x)
s* sm = sm_data + sms0 * row;
extern LOCAL_MEM
%(work_x)
s per_thread_values[]
;
GA_DECL_SHARED_BODY(
%(work_x)
s, per_thread_values)
;
LOCAL_MEM
%(work_x)
s row_max, sum, sum_inv;
LOCAL_MEM int row_max_threadIdx;
LOCAL_MEM
ga_
int row_max_threadIdx;
%(work_x)
s per_thread_row_max, per_thread_sum;
int per_thread_row_max_j;
ga_
int per_thread_row_max_j;
// COMPUTE ROW MAX AND ARGMAX
...
...
@@ -107,20 +113,20 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
per_thread_row_max = NAN;
per_thread_row_max_j = 0;
for (
int j = threadIdx.x; j < N; j += blockDim.x
)
for (
ga_int j = LID_0; j < N; j += LDIM_0
)
{
%(work_x)
s row_ij =
%(load_x)
s(x[j * xs1]) +
%(load_b)
s(b[j * bs0]);
per_thread_row_max_j = (row_ij > per_thread_row_max) ? j : per_thread_row_max_j;
per_thread_row_max = fmax
%(f)
s(row_ij, per_thread_row_max);
}
per_thread_values[
threadIdx.x
] = per_thread_row_max;
per_thread_values[
LID_0
] = per_thread_row_max;
local_barrier();
if (
threadIdx.x
== 0) {
if (
LID_0
== 0) {
row_max = NAN;
row_max_threadIdx = 0;
for (
int j = 0; j < blockDim.x
; j++)
for (
ga_int j = 0; j < LDIM_0
; j++)
{
%(work_x)
s per_thread_max = per_thread_values[j];
row_max_threadIdx = (per_thread_max > row_max) ? j : row_max_threadIdx;
...
...
@@ -132,11 +138,11 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
// The thread with the higest max writes out which of its
// values was the winner.
if (
threadIdx.x
== row_max_threadIdx) am_data[row * ams0] = per_thread_row_max_j;
if (
LID_0
== row_max_threadIdx) am_data[row * ams0] = per_thread_row_max_j;
// COMPUTE SOFTMAX
per_thread_sum = 0.0;
for (
int j = threadIdx.x; j < N; j += blockDim.x
)
for (
ga_int j = LID_0; j < N; j += LDIM_0
)
{
%(work_x)
s row_ij =
%(load_x)
s(x[j * xs1]) +
%(load_b)
s(b[j * bs0]);
%(work_x)
s sm_ij = exp
%(f)
s(row_ij - row_max);
...
...
@@ -144,13 +150,13 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
sm[j * sms1] =
%(write_x)
s(sm_ij);
}
per_thread_values[
threadIdx.x
] = per_thread_sum;
per_thread_values[
LID_0
] = per_thread_sum;
local_barrier();
if (
threadIdx.x
== 0) {
if (
LID_0
== 0) {
sum = 0.0;
for (
int j = 0; j < blockDim.x
; j++) {
for (
ga_int j = 0; j < LDIM_0
; j++) {
sum += per_thread_values[j];
}
sum_inv = 1.0 / sum;
...
...
@@ -158,12 +164,12 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
local_barrier();
for (
int j = threadIdx.x; j < N; j += blockDim.x
) {
for (
ga_int j = LID_0; j < N; j += LDIM_0
) {
sm[j * sms1] =
%(write_x)
s(
%(load_x)
s(sm[j * sms1]) * sum_inv);
}
if (
threadIdx.x
== 0) {
const
%(type_y_idx)
s y_idx = (int)y_idx_data[row * y_idxs0];
if (
LID_0
== 0) {
const
%(type_y_idx)
s y_idx = (
ga_
int)y_idx_data[row * y_idxs0];
if ((y_idx >= N || y_idx < 0)) {
// raise some suspicion.
nll_data[row * nlls0] =
%(write_x)
s(0.0);
...
...
@@ -177,21 +183,11 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
}
}
"""
%
locals
(),
file
=
sio
)
params
=
[
'uintp'
,
'uintp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
]
return
[
Kernel
(
code
=
sio
.
getvalue
(),
name
=
kname
,
params
=
params
,
flags
=
flags
,
objvar
=
k_var
)]
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
if
node
.
inputs
[
0
]
.
type
.
context
.
kind
!=
b
'cuda'
:
raise
NotImplementedError
(
'cuda only'
)
itemsize_x
=
np
.
dtype
(
node
.
inputs
[
0
]
.
dtype
)
.
itemsize
worksize_x
=
np
.
dtype
(
work_dtype
(
node
.
inputs
[
0
]
.
dtype
))
.
itemsize
itemsize_b
=
np
.
dtype
(
node
.
inputs
[
1
]
.
dtype
)
.
itemsize
...
...
@@ -266,7 +262,7 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
return
sio
.
getvalue
()
def
c_code_cache_version
(
self
):
return
(
1
2
,)
return
(
1
3
,)
gpu_crossentropy_softmax_argmax_1hot_with_bias
=
GpuCrossentropySoftmaxArgmax1HotWithBias
()
...
...
@@ -292,14 +288,12 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
return
Apply
(
self
,
[
dnll
,
sm
,
y_idx
],
[
sm
.
type
()])
def
c_code_cache_version
(
self
):
return
(
1
2
,)
return
(
1
3
,)
def
c_headers
(
self
):
return
[
'<numpy_compat.h>'
,
'<gpuarray/types.h>'
]
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
if
node
.
inputs
[
0
]
.
type
.
context
.
kind
!=
b
'cuda'
:
raise
NotImplementedError
(
"cuda only"
)
typecode_dx
=
pygpu
.
gpuarray
.
dtype_to_typecode
(
node
.
outputs
[
0
]
.
dtype
)
itemsize_dnll
=
np
.
dtype
(
node
.
inputs
[
0
]
.
dtype
)
.
itemsize
itemsize_sm
=
np
.
dtype
(
node
.
inputs
[
1
]
.
dtype
)
.
itemsize
...
...
@@ -429,30 +423,33 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
type_dx
=
gpuarray
.
dtype_to_ctype
(
dtype_dx
)
kname
=
"kCrossEntropySoftmax1HotWithBiasDx"
k_var
=
"kCrossEntropySoftmax1HotWithBiasDx_"
+
nodename
params
=
[
gpuarray
.
SIZE
,
gpuarray
.
SIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
,
]
sio
=
StringIO
()
print
(
"""
KERNEL void
%(kname)
s(
const ga_size N, const ga_size K,
const
%(type_dnll)
s* dnll, const ga_size offset_dnll,
const ga_ssize dnll_s0,
const
%(type_sm)
s* sm, const ga_size offset_sm,
const ga_ssize sm_s0, const ga_ssize sm_s1,
const
%(type_y_idx)
s* y_idx, const ga_size offset_y_idx,
const ga_ssize y_idx_s0,
%(type_dx)
s* dx, const ga_size offset_dx,
const ga_ssize dx_s0, const ga_ssize dx_s1)
GLOBAL_MEM const
%(type_dnll)
s* dnll, const ga_size offset_dnll, const ga_ssize dnll_s0,
GLOBAL_MEM const
%(type_sm)
s* sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1,
GLOBAL_MEM const
%(type_y_idx)
s* y_idx, const ga_size offset_y_idx, const ga_ssize y_idx_s0,
GLOBAL_MEM
%(type_dx)
s* dx, const ga_size offset_dx, const ga_ssize dx_s0, const ga_ssize dx_s1)
{
dnll = (
const
%(type_dnll)
s *)(((
char *)dnll)+offset_dnll);
sm = (
const
%(type_sm)
s *)(((
char *)sm)+offset_sm);
y_idx = (
const
%(type_y_idx)
s *)(((
char *)y_idx)+offset_y_idx);
dx = (
%(type_dx)
s *)(((
char *)dx)+offset_dx);
dnll = (
GLOBAL_MEM const
%(type_dnll)
s *)(((GLOBAL_MEM
char *)dnll)+offset_dnll);
sm = (
GLOBAL_MEM const
%(type_sm)
s *)(((GLOBAL_MEM
char *)sm)+offset_sm);
y_idx = (
GLOBAL_MEM const
%(type_y_idx)
s *)(((GLOBAL_MEM
char *)y_idx)+offset_y_idx);
dx = (
GLOBAL_MEM
%(type_dx)
s *)(((GLOBAL_MEM
char *)dx)+offset_dx);
for (
int i = blockIdx.x; i < N; i += gridDim.x
)
for (
ga_int i = GID_0; i < N; i += GDIM_0
)
{
%(wtype_dnll)
s dnll_i =
%(load_dnll)
s(dnll[i * dnll_s0]);
%(type_y_idx)
s y_i = y_idx[i * y_idx_s0];
for (
int j = threadIdx.x; j < K; j += blockDim.x
)
for (
ga_int j = LID_0; j < K; j += LDIM_0
)
{
if (y_i == j)
{
...
...
@@ -470,13 +467,6 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
}
}
"""
%
locals
(),
file
=
sio
)
params
=
[
'uintp'
,
'uintp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
]
return
[
Kernel
(
code
=
sio
.
getvalue
(),
name
=
kname
,
params
=
params
,
flags
=
flags
,
objvar
=
k_var
)]
...
...
@@ -499,14 +489,12 @@ class GpuSoftmax(GpuKernelBase, Op):
return
shape
def
c_code_cache_version
(
self
):
return
(
1
5
,)
+
inline_softmax
.
code_version
return
(
1
6
,)
+
inline_softmax
.
code_version
def
c_headers
(
self
):
return
[
'<numpy_compat.h>'
,
'<gpuarray/types.h>'
]
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
if
node
.
inputs
[
0
]
.
type
.
context
.
kind
!=
b
'cuda'
:
raise
NotImplementedError
(
"cuda only"
)
dtype_x
=
node
.
inputs
[
0
]
.
dtype
work_x
=
work_dtype
(
dtype_x
)
dtype_z
=
node
.
outputs
[
0
]
.
dtype
...
...
@@ -607,67 +595,169 @@ class GpuSoftmax(GpuKernelBase, Op):
type_x
=
gpuarray
.
dtype_to_ctype
(
dtype_x
)
type_sm
=
gpuarray
.
dtype_to_ctype
(
dtype_sm
)
type_acc
=
gpuarray
.
dtype_to_ctype
(
work_sm
)
ctype
=
gpuarray
.
dtype_to_ctype
(
dtype_sm
)
params
=
[
'uintp'
,
'uintp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
]
gpuarray
.
SIZE
,
gpuarray
.
SIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
]
kernels
=
[]
kname
=
"kSoftmax"
k_var
=
"kSoftmax_"
+
nodename
code
=
nvcc_kernel
(
kname
,
params
=
[
'const ga_size M'
,
'const ga_size N'
,
'const
%
s * x'
%
type_x
,
'const ga_size offset_x'
,
'const ga_ssize sx0'
,
'const ga_ssize sx1'
,
'
%
s * sm'
%
type_sm
,
'const ga_size offset_sm'
,
'const ga_ssize sm_s0'
,
'const ga_ssize sm_s1'
],
body
=
[
"extern __shared__
%
s buf[]"
%
type_acc
,
"
%
s * buf2 = buf + N"
%
type_acc
,
"x = (const
%
s *)(((char *)x)+offset_x)"
%
type_x
,
"sm = (
%
s *)(((char *)sm)+offset_sm)"
%
type_sm
,
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){"
,
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){"
,
"buf[tx] =
%
s(x[blockIDX * sx0 + tx * sx1])"
%
load_x
,
"buf2[tx] = buf[tx]"
,
"}"
,
"__syncthreads()"
,
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
,
dtype
=
work_sm
),
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){"
,
# This set all value correctly
"sm[blockIDX * sm_s0 + tx * sm_s1] =
%
s(buf[tx])"
%
write_sm
,
"}"
,
"__syncthreads()"
,
"}"
,
])
code
=
"""
KERNEL void
%(kname)
s (const ga_size M, const ga_size N, GLOBAL_MEM const
%(type_x)
s * x, const ga_size offset_x,
const ga_ssize sx0, const ga_ssize sx1, GLOBAL_MEM
%(type_sm)
s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(
%(type_acc)
s, buf))
{
GA_DECL_SHARED_BODY(
%(type_acc)
s, buf);
LOCAL_MEM
%(type_acc)
s * buf2 = buf + N;
x = (GLOBAL_MEM const
%(type_x)
s *)(((GLOBAL_MEM char *)x)+offset_x);
sm = (GLOBAL_MEM
%(type_sm)
s *)(((GLOBAL_MEM char *)sm)+offset_sm);
for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0) {
for (ga_int tx = LID_0; tx< N; tx += LDIM_0) {
buf[tx] =
%(load_x)
s(x[blockIDX * sx0 + tx * sx1]);
buf2[tx] = buf[tx];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
buf[LID_0] = max(buf[LID_0], buf[i]);
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
local_barrier();
}
}
local_barrier();
%(ctype)
s row_max = buf[0];
local_barrier();
for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){
buf[__i] = exp(buf2[__i] - row_max);
buf2[__i] = buf[__i];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
buf[LID_0] = buf[LID_0] + buf[i];
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
local_barrier();
}
}
local_barrier();
%(ctype)
s row_sum = buf[0];
local_barrier();
for(ga_int __i=LID_0; __i<N; __i+=LDIM_0) {
buf[__i] = buf2[__i] / row_sum;
}
local_barrier();
for (ga_int tx = LID_0; tx< N; tx += LDIM_0) {
sm[blockIDX * sm_s0 + tx * sm_s1] =
%(write_sm)
s(buf[tx]);
}
local_barrier();
}
}
"""
%
locals
()
kernels
.
append
(
Kernel
(
code
=
code
,
name
=
kname
,
params
=
params
,
flags
=
flags
,
objvar
=
k_var
))
kname
=
"kSoftmax_fixed_shared"
k_var
=
"kSoftmax_fixed_shared"
+
nodename
code
=
nvcc_kernel
(
kname
,
params
=
[
'const ga_size M'
,
'const ga_size N'
,
'const
%
s * x'
%
type_x
,
'const ga_size offset_x'
,
'const ga_ssize sx0'
,
'const ga_ssize sx1'
,
'
%
s * sm'
%
type_sm
,
'const ga_size offset_sm'
,
'const ga_ssize sm_s0'
,
'const ga_ssize sm_s1'
],
body
=
[
"extern __shared__
%
s buf[]"
%
type_acc
,
"x = (const
%
s *)(((char *)x)+offset_x)"
%
type_x
,
"sm = (
%
s *)(((char *)sm)+offset_sm)"
%
type_sm
,
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){"
,
"const
%
s *x_ptr = &x[blockIDX * sx0]"
%
type_x
,
"
%
s *sm_ptr = &sm[blockIDX * sm_s0]"
%
type_sm
,
inline_softmax_fixed_shared
(
'N'
,
'buf'
,
'x_ptr'
,
'sx1'
,
load_x
,
'sm_ptr'
,
'sm_s1'
,
write_sm
,
'threadIdx.x'
,
'blockDim.x'
,
dtype
=
work_sm
),
"__syncthreads()"
,
"}"
,
])
code
=
"""
KERNEL void
%(kname)
s (const ga_size M, const ga_size N, GLOBAL_MEM const
%(type_x)
s * x, const ga_size offset_x, const ga_ssize sx0, const ga_ssize sx1,
GLOBAL_MEM
%(type_sm)
s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(
%(type_acc)
s, buf))
{
GA_DECL_SHARED_BODY(
%(type_acc)
s, buf);
x = (GLOBAL_MEM const
%(type_x)
s *)(((GLOBAL_MEM char *)x)+offset_x);
sm = (GLOBAL_MEM
%(type_sm)
s *)(((GLOBAL_MEM char *)sm)+offset_sm);
for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
GLOBAL_MEM const
%(type_x)
s *x_ptr = &x[blockIDX * sx0];
GLOBAL_MEM
%(type_sm)
s *sm_ptr = &sm[blockIDX * sm_s0];
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
%(ctype)
s red =
%(load_x)
s(x_ptr[LID_0 * sx1]);
#pragma unroll 16
for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
red = max(red,
%(load_x)
s(x_ptr[i * sx1]));
}
buf[LID_0] = red;
local_barrier();
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
buf[LID_0] = max(buf[LID_0], buf[i]);
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
local_barrier();
}
}
local_barrier();
%(ctype)
s row_max = buf[0];
local_barrier();
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
%(ctype)
s red = exp(
%(load_x)
s(x_ptr[LID_0 * sx1]) - row_max);
#pragma unroll 16
for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
red = red + exp(
%(load_x)
s(x_ptr[i * sx1]) - row_max);
}
buf[LID_0] = red;
local_barrier();
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
buf[LID_0] = buf[LID_0] + buf[i];
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
local_barrier();
}
}
local_barrier();
%(ctype)
s row_sum = buf[0];
local_barrier();
for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
sm_ptr[tx * sm_s1] =
%(write_sm)
s(exp(
%(load_x)
s(x_ptr[tx * sx1]) - row_max) / row_sum);
}
local_barrier();
local_barrier();
}
}
"""
%
locals
()
kernels
.
append
(
Kernel
(
code
=
code
,
name
=
kname
,
params
=
params
,
flags
=
flags
,
objvar
=
k_var
))
return
kernels
...
...
@@ -695,14 +785,12 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
return
[
shape
[
0
]]
def
c_code_cache_version
(
self
):
return
(
1
4
,)
+
inline_softmax
.
code_version
return
(
1
5
,)
+
inline_softmax
.
code_version
def
c_headers
(
self
):
return
[
'<numpy_compat.h>'
,
'<gpuarray/types.h>'
]
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
if
node
.
inputs
[
0
]
.
type
.
context
.
kind
!=
b
'cuda'
:
raise
NotImplementedError
(
'cuda only'
)
dtype_x
=
node
.
inputs
[
0
]
.
dtype
dtype_b
=
node
.
inputs
[
1
]
.
dtype
dtype_z
=
node
.
outputs
[
0
]
.
dtype
...
...
@@ -821,74 +909,181 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
type_b
=
gpuarray
.
dtype_to_ctype
(
dtype_b
)
type_sm
=
gpuarray
.
dtype_to_ctype
(
dtype_sm
)
type_acc
=
gpuarray
.
dtype_to_ctype
(
work_sm
)
ctype
=
gpuarray
.
dtype_to_ctype
(
dtype_sm
)
params
=
[
'uintp'
,
'uintp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
gpuarray
.
GpuArray
,
'uintp'
,
'intp'
,
'intp'
]
gpuarray
.
SIZE
,
gpuarray
.
SIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
GpuArray
,
gpuarray
.
SIZE
,
gpuarray
.
SSIZE
,
gpuarray
.
SSIZE
,
]
kernels
=
[]
kname
=
"kSoftmaxWithBias"
k_var
=
"kSoftmaxWithBias_"
+
nodename
code
=
nvcc_kernel
(
kname
,
params
=
[
'const ga_size M'
,
'const ga_size N'
,
'const
%
s * x'
%
type_x
,
'const ga_size offset_x'
,
'const ga_ssize sx0'
,
'const ga_ssize sx1'
,
'const
%
s * b'
%
type_b
,
'const ga_size offset_b'
,
'const ga_ssize sb0'
,
'
%
s * sm'
%
type_sm
,
'const ga_size offset_sm'
,
'const ga_ssize sm_s0'
,
'const ga_ssize sm_s1'
],
body
=
[
"extern __shared__
%
s buf[]"
%
type_acc
,
"
%
s * buf2 = buf + N"
%
type_acc
,
"x = (const
%
s *)(((char *)x)+offset_x)"
%
type_x
,
"b = (const
%
s *)(((char *)b)+offset_b)"
%
type_b
,
"sm = (
%
s *)(((char *)sm)+offset_sm)"
%
type_sm
,
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){"
,
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){"
,
"buf[tx] =
%
s(x[blockIDX * sx0 + tx * sx1])"
%
load_x
,
"buf[tx] +=
%
s(b[tx * sb0])"
%
load_b
,
"buf2[tx] = buf[tx]"
,
"}"
,
"__syncthreads()"
,
inline_softmax
(
'N'
,
'buf'
,
'buf2'
,
'threadIdx.x'
,
'blockDim.x'
,
work_sm
),
"for (int tx = threadIdx.x; tx< N; tx += blockDim.x){"
,
"sm[blockIDX * sm_s0 + tx * sm_s1] =
%
s(buf[tx])"
%
write_sm
,
"}"
,
"__syncthreads()"
,
"}"
,
])
code
=
"""
KERNEL void
%(kname)
s (const ga_size M, const ga_size N,
GLOBAL_MEM const
%(type_x)
s * x, const ga_size offset_x, const ga_ssize sx0, const ga_ssize sx1,
GLOBAL_MEM const
%(type_b)
s * b, const ga_size offset_b, const ga_ssize sb0,
GLOBAL_MEM
%(type_sm)
s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(
%(type_acc)
s, buf))
{
GA_DECL_SHARED_BODY(
%(type_acc)
s, buf);
LOCAL_MEM
%(type_acc)
s * buf2 = buf + N;
x = (GLOBAL_MEM const
%(type_x)
s *)(((GLOBAL_MEM char *)x)+offset_x);
b = (GLOBAL_MEM const
%(type_b)
s *)(((GLOBAL_MEM char *)b)+offset_b);
sm = (GLOBAL_MEM
%(type_sm)
s *)(((GLOBAL_MEM char *)sm)+offset_sm);
for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
buf[tx] =
%(load_x)
s(x[blockIDX * sx0 + tx * sx1]);
buf[tx] +=
%(load_b)
s(b[tx * sb0]);
buf2[tx] = buf[tx];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
buf[LID_0] = max(buf[LID_0], buf[i]);
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
local_barrier();
}
}
local_barrier();
%(ctype)
s row_max = buf[0];
local_barrier();
for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){;
buf[__i] = exp(buf2[__i] - row_max);
buf2[__i] = buf[__i];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
buf[LID_0] = buf[LID_0] + buf[i];
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
local_barrier();
}
}
local_barrier();
%(ctype)
s row_sum = buf[0];
local_barrier();
for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){
buf[__i] = buf2[__i] / row_sum;
}
local_barrier();
for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
sm[blockIDX * sm_s0 + tx * sm_s1] =
%(write_sm)
s(buf[tx]);
}
local_barrier();
}
}
"""
%
locals
()
kernels
.
append
(
Kernel
(
code
=
code
,
name
=
kname
,
params
=
params
,
flags
=
flags
,
objvar
=
k_var
))
kname
=
"kSoftmaxWithBias_fixed_shared"
k_var
=
"kSoftmaxWithBias_fixed_shared"
+
nodename
code
=
nvcc_kernel
(
kname
,
params
=
[
'const ga_size M'
,
'const ga_size N'
,
'const
%
s * x'
%
type_x
,
'const ga_size offset_x'
,
'const ga_ssize sx0'
,
'const ga_ssize sx1'
,
'const
%
s * b'
%
type_b
,
'const ga_size offset_b'
,
'const ga_ssize sb0'
,
'
%
s * sm'
%
type_sm
,
'const ga_size offset_sm'
,
'const ga_ssize sm_s0'
,
'const ga_ssize sm_s1'
],
body
=
[
"extern __shared__
%
s buf[]"
%
type_acc
,
"x = (const
%
s *)(((char *)x)+offset_x)"
%
type_x
,
"b = (const
%
s *)(((char *)b)+offset_b)"
%
type_b
,
"sm = (
%
s *)(((char *)sm)+offset_sm)"
%
type_sm
,
"for (int blockIDX = blockIdx.x; blockIDX < M;"
" blockIDX += gridDim.x){"
,
"const
%
s *x_ptr = &x[blockIDX * sx0]"
%
type_x
,
"
%
s *sm_ptr = &sm[blockIDX * sm_s0]"
%
type_sm
,
inline_softmax_fixed_shared
(
'N'
,
'buf'
,
'x_ptr'
,
'sx1'
,
load_x
,
'sm_ptr'
,
'sm_s1'
,
write_sm
,
'threadIdx.x'
,
'blockDim.x'
,
'b'
,
'sb0'
,
load_b
,
work_sm
),
"__syncthreads()"
,
"}"
,
])
code
=
"""
KERNEL void
%(kname)
s (const ga_size M, const ga_size N,
GLOBAL_MEM const
%(type_x)
s * x, const ga_size offset_x, const ga_ssize sx0, const ga_ssize sx1,
GLOBAL_MEM const
%(type_b)
s * b, const ga_size offset_b, const ga_ssize sb0,
GLOBAL_MEM
%(type_sm)
s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(
%(type_acc)
s, buf))
{
GA_DECL_SHARED_BODY(
%(type_acc)
s, buf);
x = (GLOBAL_MEM const
%(type_x)
s *)(((GLOBAL_MEM char *)x)+offset_x);
b = (GLOBAL_MEM const
%(type_b)
s *)(((GLOBAL_MEM char *)b)+offset_b);
sm = (GLOBAL_MEM
%(type_sm)
s *)(((GLOBAL_MEM char *)sm)+offset_sm);
for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
GLOBAL_MEM const
%(type_x)
s *x_ptr = &x[blockIDX * sx0];
GLOBAL_MEM
%(type_sm)
s *sm_ptr = &sm[blockIDX * sm_s0];
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
%(ctype)
s red =
%(load_x)
s(x_ptr[LID_0 * sx1]) +
%(load_b)
s(b[LID_0 * sb0]);
#pragma unroll 16
for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
red = max(red,
%(load_x)
s(x_ptr[i * sx1]) +
%(load_b)
s(b[i * sb0]));
}
buf[LID_0] = red;
local_barrier();
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
buf[LID_0] = max(buf[LID_0], buf[i]);
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
local_barrier();
}
}
local_barrier();
%(ctype)
s row_max = buf[0];
local_barrier();
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
%(ctype)
s red = exp(
%(load_x)
s(x_ptr[LID_0 * sx1]) +
%(load_b)
s(b[LID_0 * sb0]) - row_max);
#pragma unroll 16
for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
red = red + exp(
%(load_x)
s(x_ptr[i * sx1]) +
%(load_b)
s(b[i * sb0]) - row_max);
}
buf[LID_0] = red;
local_barrier();
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
buf[LID_0] = buf[LID_0] + buf[i];
}
}
local_barrier();
//reduce so that LID_0 0 has the reduction of everything
for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
if (LID_0 < _n && LID_0 + _n < N)
buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
local_barrier();
}
}
local_barrier();
%(ctype)
s row_sum = buf[0];
local_barrier();
for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
sm_ptr[tx * sm_s1] =
%(write_sm)
s(exp(
%(load_x)
s(x_ptr[tx * sx1]) +
%(load_b)
s(b[tx * sb0]) - row_max) / row_sum);
}
local_barrier();
local_barrier();
}
}
"""
%
locals
()
kernels
.
append
(
Kernel
(
code
=
code
,
name
=
kname
,
params
=
params
,
flags
=
flags
,
objvar
=
k_var
))
return
kernels
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论