提交 fcdd8c77 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #5346 from nouiz/multinomial_wor

[ENH] Multinomial without replacement
...@@ -315,8 +315,11 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -315,8 +315,11 @@ KERNEL void k_multi_warp_multinomial_wor(
done = true; done = true;
} }
} }
// renormalize the multinomial // No need to renormalize after the last samples.
for (ga_int k = 0; k < nb_outcomes; ++k) if (c == (n_samples - 1))
break;
// parallel renormalize the multinomial
for (ga_int k = LID_1; k < nb_outcomes; k+=LDIM_1)
{ {
global_pvals_copy[k * pvals_col_stride + n * pvals_row_stride] /= cummul; global_pvals_copy[k * pvals_col_stride + n * pvals_row_stride] /= cummul;
} }
...@@ -385,7 +388,8 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -385,7 +388,8 @@ KERNEL void k_multi_warp_multinomial_wor(
if (theano_prep_output(&out, 2, dims, GA_LONG, if (theano_prep_output(&out, 2, dims, GA_LONG,
GA_C_ORDER, %(ctx)s) != 0){ GA_C_ORDER, %(ctx)s) != 0){
%(fail)s Py_DECREF(pvals_copy);
%(fail)s
} }
%(out)s = out; %(out)s = out;
...@@ -413,6 +417,7 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -413,6 +417,7 @@ KERNEL void k_multi_warp_multinomial_wor(
PyExc_ValueError, PyExc_ValueError,
"Multinomial is not implemented for so many rows in the matrix (%%i)", "Multinomial is not implemented for so many rows in the matrix (%%i)",
nb_multi); nb_multi);
Py_DECREF(pvals_copy);
%(fail)s %(fail)s
} }
...@@ -439,23 +444,36 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -439,23 +444,36 @@ KERNEL void k_multi_warp_multinomial_wor(
args[9] = (void*)&strides[3]; args[9] = (void*)&strides[3];
args[10] = (void*)&strides[4]; args[10] = (void*)&strides[4];
err = GpuKernel_call(&%(kname)s, 1, &nb_threads, &nb_blocks, 0, args); size_t nb_threads2[2], nb_blocks2[2];
nb_threads2[0] = nb_threads;
nb_threads2[1] = 1;
// If we can't schedule enough threads parallelize the renormalization.
// I do this because we don't always use those extra threads.
if (nb_threads * nb_blocks < 2048)
nb_threads2[1] = 1024 / nb_threads;
nb_blocks2[0] = nb_blocks;
nb_blocks2[1] = 1;
err = GpuKernel_call(&%(kname)s, 2, nb_threads2, nb_blocks2, 0, args);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format( PyErr_Format(
PyExc_RuntimeError, PyExc_RuntimeError,
"gpuarray error: %%s: %%s.\\n", "gpuarray error: %%s: %%s.\\n",
"k_multi_warp_%(name)s", "k_multi_warp_%(name)s",
GpuKernel_error(&%(kname)s, err)); GpuKernel_error(&%(kname)s, err));
%(fail)s; Py_DECREF(pvals_copy);
%(fail)s;
} }
if(%(sync)d) if(%(sync)d)
GpuArray_sync(&(out->ga)); GpuArray_sync(&(out->ga));
Py_DECREF(pvals_copy);
} // END NESTED SCOPE } // END NESTED SCOPE
""" % locals() """ % locals()
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (4,)
@register_opt('fast_compile') @register_opt('fast_compile')
......
...@@ -329,6 +329,9 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -329,6 +329,9 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
if (cummul > *unis_n) if (cummul > *unis_n)
{ {
*z_nc = m; *z_nc = m;
// No need to renormalize after the last samples.
if (c == (n_samples - 1))
break;
// renormalize the nth row of pvals, reuse (cummul-*pvals_nm) to initialize the sum // renormalize the nth row of pvals, reuse (cummul-*pvals_nm) to initialize the sum
dtype_%(pvals)s sum = cummul - *pvals_nm; dtype_%(pvals)s sum = cummul - *pvals_nm;
dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m); dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
...@@ -434,7 +437,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -434,7 +437,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
return Op.perform(self, node, ins, outs) return Op.perform(self, node, ins, outs)
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (9,)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
return """ return """
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论