提交 bb93f655 作者: Amjad Almahairi

first attempt

上级 42907a0c
......@@ -215,6 +215,112 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
out = T.tensor(dtype=odtype, broadcastable=pvals.type.broadcastable)
return Apply(self, [pvals, unis, as_scalar(n)], [out])
def c_code_cache_version(self):
    """Report the version tag of this op's generated C code.

    Always returns ``None``.

    NOTE(review): presumably a ``None`` version keeps the generated C code
    out of the compilation cache while the implementation is still
    changing -- confirm against the Theano documentation.
    """
    no_version = None
    return no_version
def c_code(self, node, name, ins, outs, sub):
    """Return the C source that samples without replacement from each row.

    The generated code validates the inputs, (re)allocates the output when
    its shape does not match, and then, for every sample index and every
    row of ``pvals``, picks the first outcome whose cumulative probability
    exceeds the matching uniform value.  The picked outcome's probability
    is set to zero and the row is renormalized before the next sample.

    All zeroing/renormalizing happens on a private C-contiguous copy of
    ``pvals``, so the caller's array is never modified (matching what
    ``perform`` does) and non-contiguous ``pvals`` are handled correctly.

    Parameters
    ----------
    node, name
        Unused here; part of the ``c_code`` signature.
    ins
        C variable names, unpacked as ``(pvals, unis, n)``.
    outs
        C variable names, unpacked as ``(z,)``.
    sub
        Substitution dict; ``sub['fail']`` is emitted at every error exit.

    Returns
    -------
    str
        The C code with all placeholders substituted.
    """
    (pvals, unis, n) = ins
    (z,) = outs
    if self.odtype == 'auto':
        # NOTE(review): NPY_INTP is platform-sized; confirm it matches the
        # dtype make_node assigns to the output when odtype is 'auto'.
        t = "NPY_INTP"
    else:
        t = theano.scalar.Scalar(self.odtype).dtype_specs()[1]
        if t.startswith('theano_complex'):
            t = t.replace('theano_complex', 'NPY_COMPLEX')
        else:
            t = t.upper()
    fail = sub['fail']
    # Only these names are referenced by the template below.
    substitutions = dict(pvals=pvals, unis=unis, n=n, z=z, t=t, fail=fail)
    return """
    if (PyArray_NDIM(%(pvals)s) != 2)
    {
        PyErr_Format(PyExc_TypeError, "pvals wrong rank");
        %(fail)s;
    }
    if (PyArray_NDIM(%(unis)s) != 1)
    {
        PyErr_Format(PyExc_TypeError, "unis wrong rank");
        %(fail)s;
    }

    if ( %(n)s > (PyArray_DIMS(%(pvals)s)[1]) )
    {
        PyErr_Format(PyExc_ValueError, "n > pvals.shape[1]");
        %(fail)s;
    }

    if (PyArray_DIMS(%(unis)s)[0] != (PyArray_DIMS(%(pvals)s)[0] * %(n)s))
    {
        PyErr_Format(PyExc_ValueError, "unis.shape[0] != pvals.shape[0] * n");
        %(fail)s;
    }

    if ((NULL == %(z)s)
        || ((PyArray_DIMS(%(z)s))[0] != (PyArray_DIMS(%(pvals)s))[0])
        || ((PyArray_DIMS(%(z)s))[1] != %(n)s)
    )
    {
        Py_XDECREF(%(z)s);
        npy_intp dims[2];
        dims[0] = PyArray_DIMS(%(pvals)s)[0];
        dims[1] = %(n)s;
        %(z)s = (PyArrayObject*) PyArray_EMPTY(2,
            dims,
            %(t)s,
            -1);
        if (!%(z)s)
        {
            PyErr_SetString(PyExc_MemoryError, "failed to alloc z output");
            %(fail)s;
        }
    }

    { // NESTED SCOPE

    const int nb_multi = PyArray_DIMS(%(pvals)s)[0];
    const int nb_outcomes = PyArray_DIMS(%(pvals)s)[1];
    const int n_samples = %(n)s;

    // Sampling zeroes and renormalizes probabilities, so work on a private
    // C-contiguous copy and leave the input array untouched.
    PyArrayObject* pvals_copy = (PyArrayObject*) PyArray_NewCopy(%(pvals)s, NPY_CORDER);
    if (!pvals_copy)
    {
        // PyArray_NewCopy has already set the Python exception.
        %(fail)s;
    }

    //
    // For each multinomial, loop over each possible outcome,
    // and set selected pval to 0 after being selected
    //
    for (int c = 0; c < n_samples; ++c)
    {
        for (int n = 0; n < nb_multi; ++n)
        {
            double cummul = 0.;
            const dtype_%(unis)s* unis_n = (dtype_%(unis)s*)PyArray_GETPTR1(%(unis)s, c*nb_multi + n);
            dtype_%(z)s* z_nc = (dtype_%(z)s*)PyArray_GETPTR2(%(z)s, n, c);
            // Row n of the copy; plain indexing is valid because the copy
            // was requested in C order.
            dtype_%(pvals)s* pvals_n = (dtype_%(pvals)s*)PyArray_GETPTR2(pvals_copy, n, 0);

            // NOTE(review): if cummul never exceeds the uniform value,
            // z[n, c] is left as allocated (unwritten).
            for (int m = 0; m < nb_outcomes; ++m)
            {
                cummul += pvals_n[m];
                if (cummul > *unis_n)
                {
                    *z_nc = m;
                    // the chosen outcome cannot be drawn again
                    pvals_n[m] = 0.;

                    // renormalize the nth row of the working copy
                    dtype_%(pvals)s sum = 0.;
                    for (int k = 0; k < nb_outcomes; ++k)
                    {
                        sum = sum + pvals_n[k];
                    }
                    for (int k = 0; k < nb_outcomes; ++k)
                    {
                        pvals_n[k] = pvals_n[k] / sum;
                    }
                    break;
                }
            }
        }
    }

    // release the working copy
    Py_DECREF(pvals_copy);

    } // END NESTED SCOPE
    """ % substitutions
def perform(self, node, ins, outs):
(pvals, unis, n_samples) = ins
# make a copy so we do not overwrite the input
......@@ -254,12 +360,6 @@ class MultinomialWOReplacementFromUniform(MultinomialFromUniform):
pvals[n] /= pvals[n].sum()
break
def c_code_cache_version(self):
    """Report the version tag of this op's generated C code.

    Always returns ``None``.

    NOTE(review): presumably a ``None`` version keeps the generated C code
    out of the compilation cache -- confirm against the Theano
    documentation.
    """
    no_version = None
    return no_version
def c_code(self, node, name, ins, outs, sub):
    """Placeholder C implementation hook.

    None of the arguments are used; the call always fails.

    Raises
    ------
    NotImplementedError
        Unconditionally, with the message ``'no C implementation yet!'``.
    """
    message = 'no C implementation yet!'
    raise NotImplementedError(message)
class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论