提交 60b3686f authored 作者: Amjad Almahairi's avatar Amjad Almahairi

working 1st version

上级 2176d21b
......@@ -236,11 +236,16 @@ class WeightedSelectionFromUniform(Op):
(pvals, unis, n_samples) = ins
(z,) = outs
if n_samples > pvals.shape[1]:
raise ValueError("Cannot sample without replacement n samples bigger "
"than the size of the distribution.")
if unis.shape[0] != pvals.shape[0] * n_samples:
raise ValueError("unis.shape[0] != pvals.shape[0] * n_samples",
unis.shape[0], pvals.shape[0], n_samples)
if z[0] is None or numpy.any(z[0].shape != [pvals.shape[0], n_samples]):
z[0] = numpy.zeros((pvals.shape[0], n_samples), dtype=node.outputs[0].dtype)
if z[0] is None or not numpy.all(z[0].shape == [pvals.shape[0], n_samples]):
z[0] = -1 * numpy.ones((pvals.shape[0], n_samples), dtype=node.outputs[0].dtype)
nb_multi = pvals.shape[0]
nb_outcomes = pvals.shape[1]
......@@ -255,9 +260,10 @@ class WeightedSelectionFromUniform(Op):
cummul += pvals[n, m]
if (cummul > unis_n):
z[0][n, c] = m
# set to zero so that it's not selected again
# set to zero and re-normalize so that it's not selected again
pvals[n, m] = 0.
pvals[n] /= pvals[n].sum()
break
class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
"""
......
......@@ -1366,8 +1366,9 @@ class MRG_RandomStreams(object):
def weighted_selection(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None):
"""
Sample `n` times (`n` needs to be in [1, m], where m is pvals.shape[1], default 1)
*WITHOUT replacement* from a multinomial distribution defined by probabilities pvals.
Sample `n` times *WITHOUT replacement* from a multinomial distribution
defined by probabilities pvals. `n` needs to be in [1, m], where m is the number of
elements to select from, i.e. m == pvals.shape[1]. By default n = 1.
Example : WRITEME
......@@ -1387,9 +1388,6 @@ class MRG_RandomStreams(object):
raise TypeError("You have to specify pvals")
pvals = as_tensor_variable(pvals)
if n > pvals.shape[1]:
raise ValueError("Cannot sample without replacement n samples bigger "
"than the size of the distribution.")
if size is not None:
raise ValueError("Provided a size argument to "
"MRG_RandomStreams.weighted_selection, which does not use "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论