提交 1a3f4e6a 作者: Xavier Bouthillier

Merge pull request #3950 from aalmah/rand_weighted_select_c_impl

C implementation of sample without replacement OP
......@@ -79,12 +79,12 @@ class MultinomialFromUniform(Op):
return """
if (PyArray_NDIM(%(pvals)s) != 2)
{
PyErr_Format(PyExc_TypeError, "pvals wrong rank");
PyErr_Format(PyExc_TypeError, "pvals ndim should be 2");
%(fail)s;
}
if (PyArray_NDIM(%(unis)s) != 1)
{
PyErr_Format(PyExc_TypeError, "unis wrong rank");
PyErr_Format(PyExc_TypeError, "unis ndim should be 2");
%(fail)s;
}
......@@ -215,6 +215,131 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
out = T.tensor(dtype=odtype, broadcastable=pvals.type.broadcastable)
return Apply(self, [pvals, unis, as_scalar(n)], [out])
def c_code_cache_version(self):
    """Return the version tag of the C code emitted by ``c_code``.

    The tag is 2 (previously 1) because the generated C template changed:
    the ``unis`` rank error message was corrected and ``pvals_copy`` is now
    released on every failure path.
    """
    return (2,)


def c_code(self, node, name, ins, outs, sub):
    """Build the C source that samples outcomes without replacement.

    The generated code:

    * checks that ``pvals`` has 2 dimensions and ``unis`` has 1;
    * checks that ``n`` does not exceed ``pvals.shape[1]`` and that
      ``unis.shape[0] == pvals.shape[0] * n``;
    * works on a private copy of ``pvals`` so the input is not modified;
    * (re)allocates the output ``z`` with shape ``(pvals.shape[0], n)``
      when it is NULL or has another shape;
    * for each sample and each row, walks the cumulative sum of the row,
      stores the first outcome index whose cumulative probability exceeds
      the matching uniform value, zeroes that probability and renormalizes
      the row.

    Parameters
    ----------
    node, name
        Not used by this implementation.
    ins
        Three C variable names: ``(pvals, unis, n)``.
    outs
        One C variable name: ``(z,)``.
    sub
        Mapping that must provide the ``'fail'`` snippet, which is inserted
        after each error is set.

    Returns
    -------
    str
        The C code with every placeholder substituted.

    Raises
    ------
    KeyError
        If ``sub`` has no ``'fail'`` entry.
    """
    (pvals, unis, n) = ins
    (z,) = outs
    if self.odtype == 'auto':
        t = "NPY_INT64"
    else:
        t = theano.scalar.Scalar(self.odtype).dtype_specs()[1]
        if t.startswith('theano_complex'):
            t = t.replace('theano_complex', 'NPY_COMPLEX')
        else:
            t = t.upper()
    fail = sub['fail']
    # Explicit mapping instead of locals(): only these six names are
    # referenced by the template below.
    substitutions = dict(pvals=pvals, unis=unis, n=n, z=z, t=t, fail=fail)
    return """
    // private copy of the pvals matrix; it is modified while sampling
    PyArrayObject* pvals_copy = NULL;

    if (PyArray_NDIM(%(pvals)s) != 2)
    {
        PyErr_Format(PyExc_TypeError, "pvals ndim should be 2");
        %(fail)s;
    }
    if (PyArray_NDIM(%(unis)s) != 1)
    {
        PyErr_Format(PyExc_TypeError, "unis ndim should be 1");
        %(fail)s;
    }

    if ( %(n)s > (PyArray_DIMS(%(pvals)s)[1]) )
    {
        PyErr_Format(PyExc_ValueError, "Cannot sample without replacement n samples bigger than the size of the distribution.");
        %(fail)s;
    }

    if (PyArray_DIMS(%(unis)s)[0] != (PyArray_DIMS(%(pvals)s)[0] * %(n)s))
    {
        PyErr_Format(PyExc_ValueError, "unis.shape[0] != pvals.shape[0] * n");
        %(fail)s;
    }

    pvals_copy = (PyArrayObject*) PyArray_EMPTY(2,
        PyArray_DIMS(%(pvals)s),
        PyArray_TYPE(%(pvals)s),
        0);

    if (!pvals_copy)
    {
        PyErr_SetString(PyExc_MemoryError, "failed to alloc pvals_copy");
        %(fail)s;
    }
    // PyArray_CopyInto sets a Python error and returns non-zero on failure
    if (PyArray_CopyInto(pvals_copy, %(pvals)s) != 0)
    {
        Py_XDECREF(pvals_copy);
        %(fail)s;
    }

    if ((NULL == %(z)s)
        || ((PyArray_DIMS(%(z)s))[0] != (PyArray_DIMS(%(pvals)s))[0])
        || ((PyArray_DIMS(%(z)s))[1] != %(n)s)
    )
    {
        Py_XDECREF(%(z)s);
        npy_intp dims[2];
        dims[0] = PyArray_DIMS(%(pvals)s)[0];
        dims[1] = %(n)s;
        %(z)s = (PyArrayObject*) PyArray_EMPTY(2,
            dims,
            %(t)s,
            -1);
        if (!%(z)s)
        {
            // do not leak the private copy when the output cannot be allocated
            Py_XDECREF(pvals_copy);
            PyErr_SetString(PyExc_MemoryError, "failed to alloc z output");
            %(fail)s;
        }
    }

    { // NESTED SCOPE

    const int nb_multi = PyArray_DIMS(%(pvals)s)[0];
    const int nb_outcomes = PyArray_DIMS(%(pvals)s)[1];
    const int n_samples = %(n)s;

    //
    // For each multinomial, loop over each possible outcome,
    // and set selected pval to 0 after being selected
    //
    for (int c = 0; c < n_samples; ++c){
        for (int n = 0; n < nb_multi; ++n)
        {
            double cummul = 0.;
            const dtype_%(unis)s* unis_n = (dtype_%(unis)s*)PyArray_GETPTR1(%(unis)s, c*nb_multi + n);
            dtype_%(z)s* z_nc = (dtype_%(z)s*)PyArray_GETPTR2(%(z)s, n, c);
            for (int m = 0; m < nb_outcomes; ++m)
            {
                dtype_%(pvals)s* pvals_nm = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
                cummul += *pvals_nm;
                if (cummul > *unis_n)
                {
                    *z_nc = m;
                    // renormalize the nth row of pvals, reuse (cummul-*pvals_nm) to initialize the sum
                    dtype_%(pvals)s sum = cummul - *pvals_nm;
                    dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, m);
                    *pvals_nm = 0.;
                    for (int k = m; k < nb_outcomes; ++k)
                    {
                        sum = sum + *pvals_n;
                        pvals_n++;
                    }
                    pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, 0);
                    for (int k = 0; k < nb_outcomes; ++k)
                    {
                        *pvals_n = *pvals_n / sum;
                        pvals_n++;
                    }
                    break;
                }
            }
        }
    }

    // delete pvals_copy
    {
        Py_XDECREF(pvals_copy);
    }
    } // END NESTED SCOPE
    """ % substitutions
def perform(self, node, ins, outs):
(pvals, unis, n_samples) = ins
# make a copy so we do not overwrite the input
......@@ -254,12 +379,6 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
pvals[n] /= pvals[n].sum()
break
def c_code_cache_version(self):
    """Return ``None`` as the C code version of this op.

    NOTE(review): presumably ``None`` marks the C code as unversioned so it
    is not reused from a cache -- confirm against the framework docs.
    """
    version = None
    return version
def c_code(self, node, name, ins, outs, sub):
    """Signal that this op has no C implementation.

    All arguments are ignored.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = 'no C implementation yet!'
    raise NotImplementedError(message)
class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论