提交 abd0a0fc，作者：Arnaud Bergeron

Add support for float16 to one of the versions of GpuCrossentropy...

上级 cc7c365c
def work_dtype(dtype):
    """Return the dtype name to use for intermediate computations.

    'float16' is mapped to 'float32'; any other value is returned
    unchanged.

    Parameters
    ----------
    dtype
        The dtype name of the data (compared against the string
        'float16').

    Returns
    -------
    'float32' if `dtype` equals 'float16', otherwise `dtype` itself.
    """
    if dtype == 'float16':
        return 'float32'
    return dtype
def load_w(dtype):
    """Return the name of the function used to load a value of `dtype`.

    The result is meant to be placed directly in front of a
    parenthesized expression in generated code, e.g.
    ``%(load_x)s(x[0])``.  For 'float16' this yields
    ``__half2float(x[0])``; for every other dtype the empty string is
    returned, which leaves a plain parenthesized expression.

    Parameters
    ----------
    dtype
        The dtype name of the stored data (compared against the string
        'float16').

    Returns
    -------
    '__half2float' if `dtype` equals 'float16', otherwise ''.
    """
    if dtype == 'float16':
        return '__half2float'
    return ''
def write_w(dtype):
    """Return the name of the function used to store a value as `dtype`.

    The result is meant to be placed directly in front of a
    parenthesized expression in generated code, e.g.
    ``%(write_x)s(value)``.  For 'float16' this yields
    ``__float2half_rn(value)``; for every other dtype the empty string
    is returned, which leaves a plain parenthesized expression.

    Parameters
    ----------
    dtype
        The dtype name of the destination data (compared against the
        string 'float16').

    Returns
    -------
    '__float2half_rn' if `dtype` equals 'float16', otherwise ''.
    """
    if dtype == 'float16':
        return '__float2half_rn'
    return ''
...@@ -16,6 +16,7 @@ from .type import GpuArrayType ...@@ -16,6 +16,7 @@ from .type import GpuArrayType
from .kernel_codegen import (nvcc_kernel, from .kernel_codegen import (nvcc_kernel,
inline_softmax, inline_softmax,
inline_softmax_fixed_shared) inline_softmax_fixed_shared)
from .fp16_help import work_dtype, load_w, write_w
class GpuCrossentropySoftmaxArgmax1HotWithBias(Op): class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
...@@ -52,6 +53,12 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op): ...@@ -52,6 +53,12 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
dtype_x = node.inputs[0].dtype dtype_x = node.inputs[0].dtype
dtype_b = node.inputs[1].dtype dtype_b = node.inputs[1].dtype
dtype_y_idx = node.inputs[2].dtype dtype_y_idx = node.inputs[2].dtype
work_x = work_dtype(dtype_x)
work_b = work_dtype(dtype_b)
load_x = load_w(dtype_x)
load_b = load_w(dtype_b)
write_x = write_w(dtype_x)
write_b = write_w(dtype_b)
return """ return """
__global__ void k_xent_sm_1hot_bias_%(nodename)s(int M, int N, __global__ void k_xent_sm_1hot_bias_%(nodename)s(int M, int N,
const npy_%(dtype_x)s* x_data, int xs0, int xs1, const npy_%(dtype_x)s* x_data, int xs0, int xs1,
...@@ -67,12 +74,13 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op): ...@@ -67,12 +74,13 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
const npy_%(dtype_y_idx)s y_idx = y_idx_data[row * y_idxs0]; const npy_%(dtype_y_idx)s y_idx = y_idx_data[row * y_idxs0];
npy_%(dtype_x)s* sm = sm_data + sms0 * row; npy_%(dtype_x)s* sm = sm_data + sms0 * row;
npy_%(dtype_x)s sum = 0.0; npy_%(work_x)s sum = 0.0;
int row_max_j = 0; int row_max_j = 0;
npy_%(dtype_x)s row_max = x[0] + b[0]; npy_%(work_x)s row_max = %(load_x)s(x[0]) + %(load_b)s(b[0]);
for (int j = 1; j < N; ++j) for (int j = 1; j < N; ++j)
{ {
npy_%(dtype_x)s row_ij = x[j*xs1] + b[j*bs0]; npy_%(work_x)s row_ij = %(load_x)s(x[j*xs1]) +
%(load_b)s(b[j*bs0]);
//todo: store to shared memory //todo: store to shared memory
row_max_j = (row_ij > row_max) ? j : row_max_j; row_max_j = (row_ij > row_max) ? j : row_max_j;
row_max = (row_ij > row_max) ? row_ij : row_max; row_max = (row_ij > row_max) ? row_ij : row_max;
...@@ -80,27 +88,30 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op): ...@@ -80,27 +88,30 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
//compute the exp //compute the exp
for (int j = 0; j < N; ++j) for (int j = 0; j < N; ++j)
{ {
npy_%(dtype_x)s row_ij = x[j*xs1] + b[j*bs0]; npy_%(work_x)s row_ij = %(load_x)s(x[j*xs1]) +
npy_%(dtype_x)s sm_ij = exp(row_ij - row_max); %(load_b)s(b[j*bs0]);
npy_%(work_x)s sm_ij = exp(row_ij - row_max);
sum += sm_ij; sum += sm_ij;
sm[j * sms1] = sm_ij; sm[j * sms1] = %(write_x)s(sm_ij);
} }
npy_%(dtype_x)s sum_inv = 1.0 / sum; npy_%(work_x)s sum_inv = 1.0 / sum;
for (int j = 0; j < N; ++j) for (int j = 0; j < N; ++j)
{ {
sm[j * sms1] *= sum_inv; npy_%(work_x)s __tmp = %(load_x)s(sm[j * sms1]);
__tmp *= sum_inv;
sm[j * sms1] = %(write_x)s(__tmp);
} }
if ((y_idx >= N) || (y_idx < 0)) if ((y_idx >= N) || (y_idx < 0))
{ {
//TODO: set raise an error bit in a global var? //TODO: set raise an error bit in a global var?
nll_data[row*nlls0] = 0.0; // raise some suspicion at least... nll_data[row*nlls0] = %(write_x)s(0.0); // raise some suspicion at least...
} }
else else
{ {
nll_data[row*nlls0] = - x[y_idx*xs1] nll_data[row*nlls0] = %(write_x)s(- %(load_x)s(x[y_idx*xs1])
- b[y_idx*bs0] - %(load_b)s(b[y_idx*bs0])
+ row_max + row_max
+ log(sum); + log(sum));
} }
am_data[row*ams0] = row_max_j; am_data[row*ams0] = row_max_j;
} }
...@@ -259,7 +270,6 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op): ...@@ -259,7 +270,6 @@ class GpuCrossentropySoftmaxArgmax1HotWithBias(Op):
return sio.getvalue() return sio.getvalue()
def c_code_cache_version(self): def c_code_cache_version(self):
# return ()
return (5,) return (5,)
def c_compiler(self): def c_compiler(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论