Commit 20b23e6e authored by Pascal Lamblin

Normalize 1 scalar instead of a full vector

It also fixes a race condition.
Parent aa2bb423
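Before this commit, the without-replacement kernel renormalized the whole row of probabilities after every draw; now a single scalar, pvals_sum, tracks the probability mass that has not been drawn yet, and the uniform sample is scaled by it instead. Below is a minimal single-threaded CPU sketch of that idea, not code from the repository; the function and variable names (sample_wor, pvals, unis, outs) are illustrative only.

/* Minimal CPU sketch of the scalar-renormalization trick (illustrative
 * names, not repository code).  Instead of dividing every remaining
 * probability by the leftover mass after each draw, keep that mass in
 * one scalar and scale the uniform draw by it. */
#include <stddef.h>

void sample_wor(float *pvals, size_t nb_outcomes,    /* row of probabilities, modified in place */
                const float *unis, size_t n_samples, /* uniform draws in [0, 1) */
                long *outs)                          /* sampled outcome indices */
{
    float pvals_sum = 1.f;                  /* mass not yet drawn */
    for (size_t c = 0; c < n_samples; ++c)
    {
        /* Scaling the draw by pvals_sum is equivalent to renormalizing
         * every remaining pval so that the row sums to 1 again. */
        float unis_n = unis[c] * pvals_sum;
        float cummul = 0.f;
        for (size_t m = 0; m < nb_outcomes; ++m)
        {
            cummul += pvals[m];
            if (unis_n < cummul)
            {
                outs[c] = (long)m;
                pvals_sum -= pvals[m];      /* shrink the remaining mass */
                pvals[m] = 0.f;             /* outcome cannot be drawn again */
                break;
            }
        }
    }
}

Because each row only needs its own scalar, the per-row work stays in a single thread, which is what lets the kernel launch later in this diff drop its second dimension.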
@@ -309,39 +309,29 @@ KERNEL void k_multi_warp_multinomial_wor(
     if (n < nb_multi)
     {
+        // Sum of the remaining p_vals in global_pvals_copy[n]
+        float pvals_sum = 1.;
         for (int c = 0; c < n_samples; ++c)
         {
             float cummul = 0.;
-            bool done = false;
-            const float unis_n = global_unis[(c * nb_multi + n)*unis_stride];
+            const float unis_n = global_unis[(c * nb_multi + n)*unis_stride] * pvals_sum;
             for (ga_size m = 0; m < nb_outcomes; ++m)
             {
                 float pvals_nm = global_pvals_copy[m * pvals_col_stride + n * pvals_row_stride];
                 cummul += pvals_nm;
-                if (!done && unis_n < cummul)
+                if (unis_n < cummul)
                 {
-                    //write out transposed for speed.
+                    // write out transposed for speed.
                     global_outs[n * outs_col_stride +
                                 c * outs_row_stride] = m;
                     if (! %(replace)s )
                     {
                         global_pvals_copy[m * pvals_col_stride + n * pvals_row_stride] = 0.0;
-                        cummul -= pvals_nm;
+                        pvals_sum -= pvals_nm;
                     }
-                    done = true;
+                    break;
                 }
             }
-            // No need to renormalize after the last samples.
-            if (c == (n_samples - 1))
-                break;
-            if (! %(replace)s )
-            {
-                // parallel renormalize the multinomial
-                for (ga_int k = LID_1; k < nb_outcomes; k+=LDIM_1)
-                {
-                    global_pvals_copy[k * pvals_col_stride + n * pvals_row_stride] /= cummul;
-                }
-            }
         }
     }
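A reading of why the scalar is enough (not text from the commit): whenever an outcome is drawn without replacement, its probability is zeroed and pvals_sum drops by the same amount, so pvals_sum always equals the mass left in global_pvals_copy for row n. Comparing unis_n = u * pvals_sum against the raw cumulative sums then selects outcome m with probability pvals_nm / pvals_sum, exactly what renormalizing the whole row would give. For example, with pvals = (0.5, 0.3, 0.2) and outcome 0 already drawn, pvals_sum is 0.5, and u * 0.5 falls below the cumulative 0.3 with probability 0.3 / 0.5 = 0.6. Removing the old renormalization loop over LID_1 / LDIM_1 also removes unsynchronized cross-thread writes to global_pvals_copy, which is presumably the race condition the commit message mentions.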
@@ -466,18 +456,7 @@ KERNEL void k_multi_warp_multinomial_wor(
         args[9] = (void*)&strides[3];
         args[10] = (void*)&strides[4];
-        size_t nb_threads2[2], nb_blocks2[2];
-        nb_threads2[0] = nb_threads;
-        nb_threads2[1] = 1;
-        // If we can't schedule enough threads parallelize the renormalization.
-        // I do this because we don't always use those extra threads.
-        if ((nb_threads * nb_blocks < 2048) && ! %(replace)d )
-            nb_threads2[1] = 1024 / nb_threads;
-        nb_blocks2[0] = nb_blocks;
-        nb_blocks2[1] = 1;
-        err = GpuKernel_call(&%(kname)s, 2, nb_blocks2, nb_threads2, 0, args);
+        err = GpuKernel_call(&%(kname)s, 1, &nb_blocks, &nb_threads, 0, args);
         if (err != GA_NO_ERROR) {
             PyErr_Format(
                 PyExc_RuntimeError,
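With the renormalization loop gone, the kernel no longer needs the extra thread dimension along LID_1, so the launch reverts to a one-dimensional grid: GpuKernel_call is given 1 grid dimension with nb_blocks and nb_threads passed directly, instead of 2 dimensions through the nb_blocks2 / nb_threads2 arrays whose second entry existed only to feed the parallel renormalization.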
@@ -495,7 +474,7 @@ KERNEL void k_multi_warp_multinomial_wor(
         return s

     def c_code_cache_version(self):
-        return (4,)
+        return (6,)

 @register_opt('fast_compile')