提交 20b23e6e authored 作者: Pascal Lamblin

Normalize 1 scalar instead of a full vector

It also fixes a race condition.
上级 aa2bb423
......@@ -309,39 +309,29 @@ KERNEL void k_multi_warp_multinomial_wor(
if (n < nb_multi)
{
// Sum of the remaining p_vals in global_pvals_copy[n]
float pvals_sum = 1.;
for (int c = 0; c < n_samples; ++c)
{
float cummul = 0.;
bool done = false;
const float unis_n = global_unis[(c * nb_multi + n)*unis_stride];
const float unis_n = global_unis[(c * nb_multi + n)*unis_stride] * pvals_sum;
for (ga_size m = 0; m < nb_outcomes; ++m)
{
float pvals_nm = global_pvals_copy[m * pvals_col_stride + n * pvals_row_stride];
cummul += pvals_nm;
if (!done && unis_n < cummul)
if (unis_n < cummul)
{
//write out transposed for speed.
// write out transposed for speed.
global_outs[n * outs_col_stride +
c * outs_row_stride] = m;
if (! %(replace)s )
{
global_pvals_copy[m * pvals_col_stride + n * pvals_row_stride] = 0.0;
cummul -= pvals_nm;
pvals_sum -= pvals_nm;
}
done = true;
}
}
// No need to renormalize after the last samples.
if (c == (n_samples - 1))
break;
if (! %(replace)s )
{
// parallel renormalize the multinomial
for (ga_int k = LID_1; k < nb_outcomes; k+=LDIM_1)
{
global_pvals_copy[k * pvals_col_stride + n * pvals_row_stride] /= cummul;
break;
}
}
}
......@@ -466,18 +456,7 @@ KERNEL void k_multi_warp_multinomial_wor(
args[9] = (void*)&strides[3];
args[10] = (void*)&strides[4];
size_t nb_threads2[2], nb_blocks2[2];
nb_threads2[0] = nb_threads;
nb_threads2[1] = 1;
// If we can't schedule enough threads parallelize the renormalization.
// I do this because we don't always use those extra threads.
if ((nb_threads * nb_blocks < 2048) && ! %(replace)d )
nb_threads2[1] = 1024 / nb_threads;
nb_blocks2[0] = nb_blocks;
nb_blocks2[1] = 1;
err = GpuKernel_call(&%(kname)s, 2, nb_blocks2, nb_threads2, 0, args);
err = GpuKernel_call(&%(kname)s, 1, &nb_blocks, &nb_threads, 0, args);
if (err != GA_NO_ERROR) {
PyErr_Format(
PyExc_RuntimeError,
......@@ -495,7 +474,7 @@ KERNEL void k_multi_warp_multinomial_wor(
return s
def c_code_cache_version(self):
return (4,)
return (6,)
@register_opt('fast_compile')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论