提交 fc36eefb authored 作者: xiaoqie's avatar xiaoqie

cuda fix

All tests in test_nnet.py pass with CUDA. Only fp32 tests in test_nnet.py pass with OpenCL. GpuFromHost doesn't work with fp16 or fp64. Larger work item size doesn't improve performance. Add 2 local_barrier(), it's strange that AMD card doesn't need these local_barrier(), but they are necessary for NVIDIA cards.
上级 2c6d7e6e
......@@ -91,25 +91,18 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
nll_data = (GLOBAL_MEM %(type_x)s *)(((GLOBAL_MEM char *)nll_data)+offset_nll);
sm_data = (GLOBAL_MEM %(type_x)s *)(((GLOBAL_MEM char *)sm_data)+offset_sm);
am_data = (GLOBAL_MEM %(type_y_idx)s *)(((GLOBAL_MEM char *)am_data)+offset_am);
for (ga_int row = GID_0; row < M; row += GDIM_0){
GLOBAL_MEM const %(type_x)s* x = x_data + xs0 * row;
GLOBAL_MEM %(type_x)s* sm = sm_data + sms0 * row;
GA_DECL_SHARED_BODY(%(work_x)s, per_thread_values);
LOCAL_MEM %(work_x)s row_max, sum, sum_inv;
LOCAL_MEM ga_int row_max_threadIdx;
%(work_x)s per_thread_row_max, per_thread_sum;
ga_int per_thread_row_max_j;
// COMPUTE ROW MAX AND ARGMAX
// compute separate per-thread maximums and argmaxes
per_thread_row_max = NAN;
per_thread_row_max_j = 0;
for (ga_int j = LID_0; j < N; j += LDIM_0)
{
%(work_x)s row_ij = %(load_x)s(x[j * xs1]) + %(load_b)s(b[j * bs0]);
......@@ -117,9 +110,7 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
per_thread_row_max = fmax%(f)s(row_ij, per_thread_row_max);
}
per_thread_values[LID_0] = per_thread_row_max;
local_barrier();
if (LID_0 == 0) {
row_max = NAN;
row_max_threadIdx = 0;
......@@ -130,13 +121,10 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
row_max = fmax%(f)s(per_thread_max, row_max);
}
}
local_barrier();
// The thread with the higest max writes out which of its
// values was the winner.
if (LID_0 == row_max_threadIdx) am_data[row * ams0] = per_thread_row_max_j;
// COMPUTE SOFTMAX
per_thread_sum = 0.0;
for (ga_int j = LID_0; j < N; j += LDIM_0)
......@@ -146,11 +134,8 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
per_thread_sum += sm_ij;
sm[j * sms1] = %(write_x)s(sm_ij);
}
per_thread_values[LID_0] = per_thread_sum;
local_barrier();
if (LID_0 == 0) {
sum = 0.0;
for (ga_int j = 0; j < LDIM_0; j++) {
......@@ -158,13 +143,10 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
}
sum_inv = 1.0 / sum;
}
local_barrier();
for (ga_int j = LID_0; j < N; j += LDIM_0) {
sm[j * sms1] = %(write_x)s(%(load_x)s(sm[j * sms1]) * sum_inv);
}
if (LID_0 == 0) {
const %(type_y_idx)s y_idx = (ga_int)y_idx_data[row * y_idxs0];
if ((y_idx >= N || y_idx < 0)) {
......@@ -325,13 +307,11 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
const ssize_t %(dnll)s_dims0 = (PyGpuArray_NDIM(%(dnll)s) > 0 ?
PyGpuArray_DIMS(%(dnll)s)[0] :
(ssize_t) 0);
// Get `dnll.strides[0]` and set it to zero if `dnll` is a scalar
// or a vector with just one element.
const ssize_t %(dnll)s_strides0 = (%(dnll)s_dims0 > 1 ?
PyGpuArray_STRIDES(%(dnll)s)[0] :
(ssize_t) 0);
if ((PyGpuArray_NDIM(%(dnll)s) > 1)
|| (PyGpuArray_NDIM(%(sm)s) != 2)
|| (PyGpuArray_NDIM(%(y_idx)s) != 1))
......@@ -440,12 +420,10 @@ class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
sm = (GLOBAL_MEM const %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
y_idx = (GLOBAL_MEM const %(type_y_idx)s *)(((GLOBAL_MEM char *)y_idx)+offset_y_idx);
dx = (GLOBAL_MEM %(type_dx)s *)(((GLOBAL_MEM char *)dx)+offset_dx);
for (ga_int i = GID_0; i < N; i += GDIM_0)
{
%(wtype_dnll)s dnll_i = %(load_dnll)s(dnll[i * dnll_s0]);
%(type_y_idx)s y_i = y_idx[i * y_idx_s0];
for (ga_int j = LID_0; j < K; j += LDIM_0)
{
if (y_i == j)
......@@ -610,8 +588,7 @@ class GpuSoftmax(GpuKernelBase, Op):
GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
{
GA_DECL_SHARED_BODY(%(type_acc)s, buf);
LOCAL_MEM %(type_acc)s * buf2 = buf + N;
LOCAL_MEM_ARG %(type_acc)s * buf2 = buf + N;
x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0) {
......@@ -620,11 +597,9 @@ class GpuSoftmax(GpuKernelBase, Op):
buf2[tx] = buf[tx];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
......@@ -646,11 +621,9 @@ class GpuSoftmax(GpuKernelBase, Op):
buf2[__i] = buf[__i];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
......@@ -666,6 +639,7 @@ class GpuSoftmax(GpuKernelBase, Op):
}
}
%(ctype)s row_sum = buf[0];
local_barrier();
for(ga_int __i=LID_0; __i<N; __i+=LDIM_0) {
buf[__i] = buf2[__i] / row_sum;
}
......@@ -687,13 +661,11 @@ class GpuSoftmax(GpuKernelBase, Op):
GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
{
GA_DECL_SHARED_BODY(%(type_acc)s, buf);
x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
GLOBAL_MEM const %(type_x)s *x_ptr = &x[blockIDX * sx0];
GLOBAL_MEM %(type_sm)s *sm_ptr = &sm[blockIDX * sm_s0];
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
......@@ -719,7 +691,6 @@ class GpuSoftmax(GpuKernelBase, Op):
}
%(ctype)s row_max = buf[0];
local_barrier();
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
......@@ -743,8 +714,8 @@ class GpuSoftmax(GpuKernelBase, Op):
local_barrier();
}
}
%(ctype)s row_sum = buf[0];
local_barrier();
for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
sm_ptr[tx * sm_s1] = %(write_sm)s(exp(%(load_x)s(x_ptr[tx * sx1]) - row_max) / row_sum);
}
......@@ -923,8 +894,7 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
{
GA_DECL_SHARED_BODY(%(type_acc)s, buf);
LOCAL_MEM %(type_acc)s * buf2 = buf + N;
LOCAL_MEM_ARG %(type_acc)s * buf2 = buf + N;
x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
b = (GLOBAL_MEM const %(type_b)s *)(((GLOBAL_MEM char *)b)+offset_b);
sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
......@@ -935,11 +905,9 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
buf2[tx] = buf[tx];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
......@@ -954,7 +922,6 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
local_barrier();
}
}
%(ctype)s row_max = buf[0];
local_barrier();
for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){;
......@@ -962,11 +929,9 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
buf2[__i] = buf[__i];
}
local_barrier();
{
// This function trashes buf[1..GA_WARP_SIZE],
// leaving the reduction result in buf[0].
if (LID_0 < GA_WARP_SIZE) {
for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
{
......@@ -981,8 +946,8 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
local_barrier();
}
}
%(ctype)s row_sum = buf[0];
local_barrier();
for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){
buf[__i] = buf2[__i] / row_sum;
}
......@@ -1005,14 +970,12 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
{
GA_DECL_SHARED_BODY(%(type_acc)s, buf);
x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
b = (GLOBAL_MEM const %(type_b)s *)(((GLOBAL_MEM char *)b)+offset_b);
sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
GLOBAL_MEM const %(type_x)s *x_ptr = &x[blockIDX * sx0];
GLOBAL_MEM %(type_sm)s *sm_ptr = &sm[blockIDX * sm_s0];
{
// This function trashes buf[1..n_threads],
// leaving the reduction result in buf[0].
......@@ -1036,7 +999,6 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
local_barrier();
}
}
%(ctype)s row_max = buf[0];
local_barrier();
{
......@@ -1062,8 +1024,8 @@ class GpuSoftmaxWithBias(GpuKernelBase, Op):
local_barrier();
}
}
%(ctype)s row_sum = buf[0];
local_barrier();
for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
sm_ptr[tx * sm_s1] = %(write_sm)s(exp(%(load_x)s(x_ptr[tx * sx1]) + %(load_b)s(b[tx * sb0]) - row_max) / row_sum);
}
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论