提交 ae635037 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

change weighted_selection to multinomial_wo_replacement

上级 d775ab46
......@@ -194,7 +194,7 @@ class MultinomialFromUniform(Op):
break
class WeightedSelectionFromUniform(Op):
class MultinomialWOReplacementFromUniform(Op):
"""
Converts samples from a uniform into sample from a multinomial.
......
......@@ -1363,8 +1363,8 @@ class MRG_RandomStreams(object):
raise NotImplementedError(("MRG_RandomStreams.multinomial only"
" implemented for pvals.ndim = 2"))
def weighted_selection(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None):
def multinomial_wo_replacement(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None):
"""
Sample `n` times *WITHOUT replacement* from a multinomial distribution
defined by probabilities pvals, and returns the indices of the sampled
......@@ -1394,21 +1394,21 @@ class MRG_RandomStreams(object):
if size is not None:
raise ValueError("Provided a size argument to "
"MRG_RandomStreams.weighted_selection, which does not use "
"MRG_RandomStreams.multinomial_wo_replacement, which does not use "
"the size argument.")
if ndim is not None:
raise ValueError("Provided an ndim argument to "
"MRG_RandomStreams.weighted_selection, which does not use "
"MRG_RandomStreams.multinomial_wo_replacement, which does not use "
"the ndim argument.")
if pvals.ndim == 2:
# size = [pvals.shape[0], as_tensor_variable(n)]
size = pvals[:,0].shape * n
unis = self.uniform(size=size, ndim=1, nstreams=nstreams)
op = multinomial.WeightedSelectionFromUniform(dtype)
op = multinomial.MultinomialWOReplacementFromUniform(dtype)
n_samples = as_tensor_variable(n)
return op(pvals, unis, n_samples)
else:
raise NotImplementedError(("MRG_RandomStreams.weighted_selection only"
raise NotImplementedError(("MRG_RandomStreams.multinomial_wo_replacement only"
" implemented for pvals.ndim = 2"))
def normal(self, size, avg=0.0, std=1.0, ndim=None,
......
......@@ -9,12 +9,12 @@ class test_OP(unittest.TestCase):
def test_select_distinct(self):
"""
Tests that WeightedSelectionFromUniform always selects distinct elements
Tests that MultinomialWOReplacementFromUniform always selects distinct elements
"""
p = tensor.fmatrix()
u = tensor.fvector()
n = tensor.iscalar()
m = multinomial.WeightedSelectionFromUniform('auto')(p, u, n)
m = multinomial.MultinomialWOReplacementFromUniform('auto')(p, u, n)
f = function([p, u, n], m, allow_input_downcast=True)
......@@ -32,13 +32,13 @@ class test_OP(unittest.TestCase):
def test_fail_select_alot(self):
"""
Tests that WeightedSelectionFromUniform fails when asked to sample more
Tests that MultinomialWOReplacementFromUniform fails when asked to sample more
elements than the actual number of elements
"""
p = tensor.fmatrix()
u = tensor.fvector()
n = tensor.iscalar()
m = multinomial.WeightedSelectionFromUniform('auto')(p, u, n)
m = multinomial.MultinomialWOReplacementFromUniform('auto')(p, u, n)
f = function([p, u, n], m, allow_input_downcast=True)
......@@ -52,13 +52,13 @@ class test_OP(unittest.TestCase):
def test_select_proportional_to_weight(self):
"""
Tests that WeightedSelectionFromUniform selects elements, on average,
Tests that MultinomialWOReplacementFromUniform selects elements, on average,
proportional to the their probabilities
"""
p = tensor.fmatrix()
u = tensor.fvector()
n = tensor.iscalar()
m = multinomial.WeightedSelectionFromUniform('auto')(p, u, n)
m = multinomial.MultinomialWOReplacementFromUniform('auto')(p, u, n)
f = function([p, u, n], m, allow_input_downcast=True)
......@@ -83,13 +83,13 @@ class test_function(unittest.TestCase):
def test_select_distinct(self):
"""
Tests that weighted_selection always selects distinct elements
Tests that multinomial_wo_replacement always selects distinct elements
"""
th_rng = RandomStreams(12345)
p = tensor.fmatrix()
n = tensor.iscalar()
m = th_rng.weighted_selection(pvals=p, n=n)
m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True)
......@@ -106,14 +106,14 @@ class test_function(unittest.TestCase):
def test_fail_select_alot(self):
"""
Tests that weighted_selection fails when asked to sample more
Tests that multinomial_wo_replacement fails when asked to sample more
elements than the actual number of elements
"""
th_rng = RandomStreams(12345)
p = tensor.fmatrix()
n = tensor.iscalar()
m = th_rng.weighted_selection(pvals=p, n=n)
m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True)
......@@ -126,14 +126,14 @@ class test_function(unittest.TestCase):
def test_select_proportional_to_weight(self):
"""
Tests that weighted_selection selects elements, on average,
Tests that multinomial_wo_replacement selects elements, on average,
proportional to the their probabilities
"""
th_rng = RandomStreams(12345)
p = tensor.fmatrix()
n = tensor.iscalar()
m = th_rng.weighted_selection(pvals=p, n=n)
m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论