提交 e9d5e0ac authored 作者: Matthew Willson's avatar Matthew Willson

Improve performance of GpuCrossentropySoftmaxArgmax1HotWithBias significantly by…

Improve performance of GpuCrossentropySoftmaxArgmax1HotWithBias significantly by implementing a TODO: launch more threads per row and do parallel sum and max reductions
上级 e9e94b4c
...@@ -36,61 +36,124 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp): ...@@ -36,61 +36,124 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp):
am = y_idx.type() am = y_idx.type()
return Apply(self, [x, b, y_idx], [nll, sm, am]) return Apply(self, [x, b, y_idx], [nll, sm, am])
def c_headers(self):
return ['<float.h>']
def c_support_code(self): def c_support_code(self):
return """ return """
__global__ void k_xent_sm_1hot_bias(int M, int N, __global__ void k_xent_sm_1hot_bias(const int M, const int N,
const float * x_data, int xs0, int xs1, const float * x_data, const int xs0, const int xs1,
const float * b, int bs0, const float * b, const int bs0,
const float * y_idx_data, int y_idxs0, const float * y_idx_data, const int y_idxs0,
float * nll_data, int nlls0, float * nll_data, const int nlls0,
float * sm_data, int sms0, int sms1, float * sm_data, const int sms0, const int sms1,
float * am_data, int ams0) float * am_data, const int ams0)
{ {
for (int row = blockIdx.x; row < M; row += gridDim.x){ for (int row = blockIdx.x; row < M; row += gridDim.x){
const float * x = x_data + xs0 * row; const float * x = x_data + xs0 * row;
const int y_idx = (int)y_idx_data[row * y_idxs0]; float * sm = sm_data + sms0 * row;
float * sm = sm_data + sms0 * row;
extern __shared__ float per_thread_values[];
float sum = 0.0; __shared__ float row_max, sum, sum_inv;
int row_max_j = 0; __shared__ int row_max_threadIdx;
float row_max = x[0] + b[0];
for (int j = 1; j < N; ++j) float per_thread_row_max, per_thread_sum;
{ int per_thread_row_max_j;
float row_ij = x[j*xs1] + b[j*bs0];
//todo: store to shared memory // COMPUTE ROW MAX AND ARGMAX
row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max; // compute separate per-thread maximums and argmax's
} per_thread_row_max = -FLT_MAX;
//compute the exp per_thread_row_max_j = 0;
for (int j = 0; j < N; ++j) for (int j = threadIdx.x; j < N; j += blockDim.x)
{ {
float row_ij = x[j*xs1] + b[j*bs0]; float row_ij = x[j*xs1] + b[j*bs0];
float sm_ij = exp(row_ij - row_max); per_thread_row_max_j = (row_ij > per_thread_row_max) ? j : per_thread_row_max_j;
sum += sm_ij; per_thread_row_max = fmaxf(row_ij, per_thread_row_max);
sm[j * sms1] = sm_ij; }
} per_thread_values[threadIdx.x] = per_thread_row_max;
float sum_inv = 1.0 / sum;
for (int j = 0; j < N; ++j) // wait for access to shared per_thread_values to do final
{ // reduction in thread 0
sm[j * sms1] *= sum_inv; __syncthreads();
}
if ((y_idx >= N) || (y_idx < 0)) // Finish the reduction in one go in a single thread. Could be
{ // smarter about this with more hierarchical reductions but think
//TODO: set raise an error bit in a global var? // this will do for now.
nll_data[row*nlls0] = 0.0; // raise some suspicion at least... if (threadIdx.x == 0) {
} // compute overall maximum and the id of the thread which has it
else row_max = -FLT_MAX;
{ row_max_threadIdx = 0;
nll_data[row*nlls0] = - x[y_idx*xs1] for (int j = 0; j < blockDim.x; ++j)
- b[y_idx*bs0] {
+ row_max float per_thread_max = per_thread_values[j];
+ log(sum); row_max_threadIdx = (per_thread_max > row_max) ? j : row_max_threadIdx;
} row_max = fmaxf(per_thread_max, row_max);
am_data[row*ams0] = row_max_j; }
} }
}
// all threads wait for access to shared row_max and row_maxThreadIdx
__syncthreads();
// thread whose max was the overall max writes out the overall argmax:
if (threadIdx.x == row_max_threadIdx) am_data[row*ams0] = per_thread_row_max_j;
// COMPUTE SOFTMAX
// compute the exp and the per-thread sums of exps
per_thread_sum = 0.0;
for (int j = threadIdx.x; j < N; j += blockDim.x)
{
float row_ij = x[j*xs1] + b[j*bs0];
float sm_ij = __expf(row_ij - row_max);
per_thread_sum += sm_ij;
sm[j * sms1] = sm_ij;
}
per_thread_values[threadIdx.x] = per_thread_sum;
// wait for access to shared per_thread_values to do final
// reduction in thread 0
__syncthreads();
if (threadIdx.x == 0) {
// compute overall sum
sum = 0.0;
for (int j = 0; j < blockDim.x; ++j)
{
sum += per_thread_values[j];
}
sum_inv = 1.0 / sum;
}
// all threads wait for access to shared sum, sum_inv
__syncthreads();
// all threads normalize their softmax result using sum_inv
for (int j = threadIdx.x; j < N; j += blockDim.x)
{
sm[j * sms1] *= sum_inv;
}
// COMPUTE NEGATIVE LOG-LIKELIHOOD FOR TARGET INDEX
if (threadIdx.x == 0) {
const int y_idx = (int)y_idx_data[row * y_idxs0];
if ((y_idx >= N) || (y_idx < 0))
{
//TODO: set raise an error bit in a global var?
nll_data[row*nlls0] = 0.0; // raise some suspicion at least...
}
else
{
nll_data[row*nlls0] = - x[y_idx*xs1]
- b[y_idx*bs0]
+ row_max
+ logf(sum);
}
}
}
}
""" """
def c_code(self, node, nodename, inp, out, sub): def c_code(self, node, nodename, inp, out, sub):
...@@ -176,10 +239,9 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp): ...@@ -176,10 +239,9 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp):
{ {
int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0], int n_blocks = std::min(CudaNdarray_HOST_DIMS(%(x)s)[0],
NUM_VECTOR_OP_BLOCKS); NUM_VECTOR_OP_BLOCKS);
//TODO: launch more threads per row and do parallel sum and max reductions int n_threads = std::min(CudaNdarray_HOST_DIMS(%(x)s)[1],
int n_threads = 1; NUM_VECTOR_OP_THREADS_PER_BLOCK);
int n_shared_bytes = 0; //n_threads * sizeof(float); int n_shared_bytes = n_threads * sizeof(float);
k_xent_sm_1hot_bias<<<n_blocks, n_threads, n_shared_bytes>>>( k_xent_sm_1hot_bias<<<n_blocks, n_threads, n_shared_bytes>>>(
CudaNdarray_HOST_DIMS(%(x)s)[0], CudaNdarray_HOST_DIMS(%(x)s)[0],
CudaNdarray_HOST_DIMS(%(x)s)[1], CudaNdarray_HOST_DIMS(%(x)s)[1],
...@@ -216,7 +278,7 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp): ...@@ -216,7 +278,7 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuOp):
def c_code_cache_version(self): def c_code_cache_version(self):
# return () # return ()
return (4,) return (5,)
gpu_crossentropy_softmax_argmax_1hot_with_bias = GpuCrossentropySoftmaxArgmax1HotWithBias() gpu_crossentropy_softmax_argmax_1hot_with_bias = GpuCrossentropySoftmaxArgmax1HotWithBias()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论