提交 ff50c400 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fixed bugs

上级 c42f938a
......@@ -222,7 +222,7 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
(pvals, unis, n) = ins
(z,) = outs
if self.odtype == 'auto':
t = "NPY_INTP"
t = "NPY_INT64"
else:
t = theano.scalar.Scalar(self.odtype).dtype_specs()[1]
if t.startswith('theano_complex'):
......@@ -231,6 +231,9 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
t = t.upper()
fail = sub['fail']
return """
// create a copy of pvals matrix
PyArrayObject* pvals_copy = NULL;
if (PyArray_NDIM(%(pvals)s) != 2)
{
PyErr_Format(PyExc_TypeError, "pvals wrong rank");
......@@ -254,6 +257,18 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
%(fail)s;
}
pvals_copy = (PyArrayObject*) PyArray_EMPTY(2,
PyArray_DIMS(%(pvals)s),
PyArray_TYPE(%(pvals)s),
0);
if (!pvals_copy)
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc pvals_copy");
%(fail)s;
}
PyArray_CopyInto(pvals_copy, %(pvals)s);
if ((NULL == %(z)s)
|| ((PyArray_DIMS(%(z)s))[0] != (PyArray_DIMS(%(pvals)s))[0])
|| ((PyArray_DIMS(%(z)s))[1] != %(n)s)
......@@ -275,11 +290,25 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
}
{ // NESTED SCOPE
const int nb_multi = PyArray_DIMS(%(pvals)s)[0];
const int nb_outcomes = PyArray_DIMS(%(pvals)s)[1];
const int n_samples = %(n)s;
// for (int n = 0; n < nb_multi; ++n)
// {
// for (int m = 0; m < nb_outcomes; ++m)
// {
// const dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, m);
// dtype_%(pvals)s* pvals_copy_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
// *pvals_copy_nm = *pvals_nm;
//
// // if (*pvals_copy_nm == *pvals_nm){
// // printf("OK");
// // }
// }
// }
//
// For each multinomial, loop over each possible outcome,
// and set selected pval to 0 after being selected
......@@ -293,7 +322,7 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
for (int m = 0; m < nb_outcomes; ++m)
{
dtype_%(z)s* z_nc = (dtype_%(z)s*)PyArray_GETPTR2(%(z)s, n, c);
dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, m);
dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
cummul += *pvals_nm;
if (cummul > *unis_n)
{
......@@ -301,13 +330,13 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
*pvals_nm = 0.;
// renormalize the nth row of pvals
dtype_%(pvals)s sum = 0.;
dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, 0);
dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, 0);
for (int k = 0; k < nb_outcomes; ++k)
{
sum = sum + *pvals_n;
pvals_n++;
}
pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, 0);
pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, 0);
for (int k = 0; k < nb_outcomes; ++k)
{
*pvals_n = *pvals_n / sum;
......@@ -318,6 +347,11 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
}
}
}
// delete pvals_copy
{
Py_XDECREF(pvals_copy);
}
} // END NESTED SCOPE
""" % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论