提交 ff50c400 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fixed bugs

上级 c42f938a
...@@ -222,7 +222,7 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -222,7 +222,7 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
(pvals, unis, n) = ins (pvals, unis, n) = ins
(z,) = outs (z,) = outs
if self.odtype == 'auto': if self.odtype == 'auto':
t = "NPY_INTP" t = "NPY_INT64"
else: else:
t = theano.scalar.Scalar(self.odtype).dtype_specs()[1] t = theano.scalar.Scalar(self.odtype).dtype_specs()[1]
if t.startswith('theano_complex'): if t.startswith('theano_complex'):
...@@ -231,6 +231,9 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -231,6 +231,9 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
t = t.upper() t = t.upper()
fail = sub['fail'] fail = sub['fail']
return """ return """
// create a copy of pvals matrix
PyArrayObject* pvals_copy = NULL;
if (PyArray_NDIM(%(pvals)s) != 2) if (PyArray_NDIM(%(pvals)s) != 2)
{ {
PyErr_Format(PyExc_TypeError, "pvals wrong rank"); PyErr_Format(PyExc_TypeError, "pvals wrong rank");
...@@ -254,6 +257,18 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -254,6 +257,18 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
%(fail)s; %(fail)s;
} }
pvals_copy = (PyArrayObject*) PyArray_EMPTY(2,
PyArray_DIMS(%(pvals)s),
PyArray_TYPE(%(pvals)s),
0);
if (!pvals_copy)
{
PyErr_SetString(PyExc_MemoryError, "failed to alloc pvals_copy");
%(fail)s;
}
PyArray_CopyInto(pvals_copy, %(pvals)s);
if ((NULL == %(z)s) if ((NULL == %(z)s)
|| ((PyArray_DIMS(%(z)s))[0] != (PyArray_DIMS(%(pvals)s))[0]) || ((PyArray_DIMS(%(z)s))[0] != (PyArray_DIMS(%(pvals)s))[0])
|| ((PyArray_DIMS(%(z)s))[1] != %(n)s) || ((PyArray_DIMS(%(z)s))[1] != %(n)s)
...@@ -280,6 +295,20 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -280,6 +295,20 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
const int nb_outcomes = PyArray_DIMS(%(pvals)s)[1]; const int nb_outcomes = PyArray_DIMS(%(pvals)s)[1];
const int n_samples = %(n)s; const int n_samples = %(n)s;
// for (int n = 0; n < nb_multi; ++n)
// {
// for (int m = 0; m < nb_outcomes; ++m)
// {
// const dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, m);
// dtype_%(pvals)s* pvals_copy_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
// *pvals_copy_nm = *pvals_nm;
//
// // if (*pvals_copy_nm == *pvals_nm){
// // printf("OK");
// // }
// }
// }
// //
// For each multinomial, loop over each possible outcome, // For each multinomial, loop over each possible outcome,
// and set selected pval to 0 after being selected // and set selected pval to 0 after being selected
...@@ -293,7 +322,7 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -293,7 +322,7 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
for (int m = 0; m < nb_outcomes; ++m) for (int m = 0; m < nb_outcomes; ++m)
{ {
dtype_%(z)s* z_nc = (dtype_%(z)s*)PyArray_GETPTR2(%(z)s, n, c); dtype_%(z)s* z_nc = (dtype_%(z)s*)PyArray_GETPTR2(%(z)s, n, c);
dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, m); dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
cummul += *pvals_nm; cummul += *pvals_nm;
if (cummul > *unis_n) if (cummul > *unis_n)
{ {
...@@ -301,13 +330,13 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -301,13 +330,13 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
*pvals_nm = 0.; *pvals_nm = 0.;
// renormalize the nth row of pvals // renormalize the nth row of pvals
dtype_%(pvals)s sum = 0.; dtype_%(pvals)s sum = 0.;
dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, 0); dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, 0);
for (int k = 0; k < nb_outcomes; ++k) for (int k = 0; k < nb_outcomes; ++k)
{ {
sum = sum + *pvals_n; sum = sum + *pvals_n;
pvals_n++; pvals_n++;
} }
pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(%(pvals)s, n, 0); pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, 0);
for (int k = 0; k < nb_outcomes; ++k) for (int k = 0; k < nb_outcomes; ++k)
{ {
*pvals_n = *pvals_n / sum; *pvals_n = *pvals_n / sum;
...@@ -318,6 +347,11 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform): ...@@ -318,6 +347,11 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
} }
} }
} }
// delete pvals_copy
{
Py_XDECREF(pvals_copy);
}
} // END NESTED SCOPE } // END NESTED SCOPE
""" % locals() """ % locals()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论