提交 f408152d authored 作者: Amjad Almahairi's avatar Amjad Almahairi

add suggested fixes

上级 846804d0
......@@ -79,12 +79,12 @@ class MultinomialFromUniform(Op):
return """
if (PyArray_NDIM(%(pvals)s) != 2)
{
PyErr_Format(PyExc_TypeError, "pvals wrong rank");
PyErr_Format(PyExc_TypeError, "pvals ndim should be 2");
%(fail)s;
}
if (PyArray_NDIM(%(unis)s) != 1)
{
PyErr_Format(PyExc_TypeError, "unis wrong rank");
PyErr_Format(PyExc_TypeError, "unis ndim should be 2");
%(fail)s;
}
......@@ -236,18 +236,18 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
if (PyArray_NDIM(%(pvals)s) != 2)
{
PyErr_Format(PyExc_TypeError, "pvals wrong rank");
PyErr_Format(PyExc_TypeError, "pvals ndim should be 2");
%(fail)s;
}
if (PyArray_NDIM(%(unis)s) != 1)
{
PyErr_Format(PyExc_TypeError, "unis wrong rank");
PyErr_Format(PyExc_TypeError, "unis ndim should be 2");
%(fail)s;
}
if ( %(n)s > (PyArray_DIMS(%(pvals)s)[1]) )
{
PyErr_Format(PyExc_ValueError, "n > pvals.shape[1]");
PyErr_Format(PyExc_ValueError, "Cannot sample without replacement n samples bigger than the size of the distribution.");
%(fail)s;
}
......@@ -295,20 +295,6 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
const int nb_outcomes = PyArray_DIMS(%(pvals)s)[1];
const int n_samples = %(n)s;
// for (int n = 0; n < nb_multi; ++n)
// {
// for (int m = 0; m < nb_outcomes; ++m)
// {
// const dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, m);
// dtype_%(pvals)s* pvals_copy_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
// *pvals_copy_nm = *pvals_nm;
//
// // if (*pvals_copy_nm == *pvals_nm){
// // printf("OK");
// // }
// }
// }
//
// For each multinomial, loop over each possible outcome,
// and set selected pval to 0 after being selected
......@@ -321,17 +307,16 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
dtype_%(z)s* z_nc = (dtype_%(z)s*)PyArray_GETPTR2(%(z)s, n, c);
for (int m = 0; m < nb_outcomes; ++m)
{
dtype_%(z)s* z_nc = (dtype_%(z)s*)PyArray_GETPTR2(%(z)s, n, c);
dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
cummul += *pvals_nm;
if (cummul > *unis_n)
{
*z_nc = m;
// renormalize the nth row of pvals, reuse (cummul-*pvals_nm) to initialize the sum
dtype_%(pvals)s sum = cummul - *pvals_nm;
dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m+1);
*pvals_nm = 0.;
// renormalize the nth row of pvals
dtype_%(pvals)s sum = 0.;
dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, 0);
for (int k = 0; k < nb_outcomes; ++k)
for (int k = m+1; k < nb_outcomes; ++k)
{
sum = sum + *pvals_n;
pvals_n++;
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论