提交 ae635037 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

change weighted_selection to multinomial_wo_replacement

上级 d775ab46
...@@ -194,7 +194,7 @@ class MultinomialFromUniform(Op): ...@@ -194,7 +194,7 @@ class MultinomialFromUniform(Op):
break break
class WeightedSelectionFromUniform(Op): class MultinomialWOReplacementFromUniform(Op):
""" """
Converts samples from a uniform into sample from a multinomial. Converts samples from a uniform into sample from a multinomial.
......
...@@ -1363,7 +1363,7 @@ class MRG_RandomStreams(object): ...@@ -1363,7 +1363,7 @@ class MRG_RandomStreams(object):
raise NotImplementedError(("MRG_RandomStreams.multinomial only" raise NotImplementedError(("MRG_RandomStreams.multinomial only"
" implemented for pvals.ndim = 2")) " implemented for pvals.ndim = 2"))
def weighted_selection(self, size=None, n=1, pvals=None, ndim=None, dtype='int64', def multinomial_wo_replacement(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None): nstreams=None):
""" """
Sample `n` times *WITHOUT replacement* from a multinomial distribution Sample `n` times *WITHOUT replacement* from a multinomial distribution
...@@ -1394,21 +1394,21 @@ class MRG_RandomStreams(object): ...@@ -1394,21 +1394,21 @@ class MRG_RandomStreams(object):
if size is not None: if size is not None:
raise ValueError("Provided a size argument to " raise ValueError("Provided a size argument to "
"MRG_RandomStreams.weighted_selection, which does not use " "MRG_RandomStreams.multinomial_wo_replacement, which does not use "
"the size argument.") "the size argument.")
if ndim is not None: if ndim is not None:
raise ValueError("Provided an ndim argument to " raise ValueError("Provided an ndim argument to "
"MRG_RandomStreams.weighted_selection, which does not use " "MRG_RandomStreams.multinomial_wo_replacement, which does not use "
"the ndim argument.") "the ndim argument.")
if pvals.ndim == 2: if pvals.ndim == 2:
# size = [pvals.shape[0], as_tensor_variable(n)] # size = [pvals.shape[0], as_tensor_variable(n)]
size = pvals[:,0].shape * n size = pvals[:,0].shape * n
unis = self.uniform(size=size, ndim=1, nstreams=nstreams) unis = self.uniform(size=size, ndim=1, nstreams=nstreams)
op = multinomial.WeightedSelectionFromUniform(dtype) op = multinomial.MultinomialWOReplacementFromUniform(dtype)
n_samples = as_tensor_variable(n) n_samples = as_tensor_variable(n)
return op(pvals, unis, n_samples) return op(pvals, unis, n_samples)
else: else:
raise NotImplementedError(("MRG_RandomStreams.weighted_selection only" raise NotImplementedError(("MRG_RandomStreams.multinomial_wo_replacement only"
" implemented for pvals.ndim = 2")) " implemented for pvals.ndim = 2"))
def normal(self, size, avg=0.0, std=1.0, ndim=None, def normal(self, size, avg=0.0, std=1.0, ndim=None,
......
...@@ -9,12 +9,12 @@ class test_OP(unittest.TestCase): ...@@ -9,12 +9,12 @@ class test_OP(unittest.TestCase):
def test_select_distinct(self): def test_select_distinct(self):
""" """
Tests that WeightedSelectionFromUniform always selects distinct elements Tests that MultinomialWOReplacementFromUniform always selects distinct elements
""" """
p = tensor.fmatrix() p = tensor.fmatrix()
u = tensor.fvector() u = tensor.fvector()
n = tensor.iscalar() n = tensor.iscalar()
m = multinomial.WeightedSelectionFromUniform('auto')(p, u, n) m = multinomial.MultinomialWOReplacementFromUniform('auto')(p, u, n)
f = function([p, u, n], m, allow_input_downcast=True) f = function([p, u, n], m, allow_input_downcast=True)
...@@ -32,13 +32,13 @@ class test_OP(unittest.TestCase): ...@@ -32,13 +32,13 @@ class test_OP(unittest.TestCase):
def test_fail_select_alot(self): def test_fail_select_alot(self):
""" """
Tests that WeightedSelectionFromUniform fails when asked to sample more Tests that MultinomialWOReplacementFromUniform fails when asked to sample more
elements than the actual number of elements elements than the actual number of elements
""" """
p = tensor.fmatrix() p = tensor.fmatrix()
u = tensor.fvector() u = tensor.fvector()
n = tensor.iscalar() n = tensor.iscalar()
m = multinomial.WeightedSelectionFromUniform('auto')(p, u, n) m = multinomial.MultinomialWOReplacementFromUniform('auto')(p, u, n)
f = function([p, u, n], m, allow_input_downcast=True) f = function([p, u, n], m, allow_input_downcast=True)
...@@ -52,13 +52,13 @@ class test_OP(unittest.TestCase): ...@@ -52,13 +52,13 @@ class test_OP(unittest.TestCase):
def test_select_proportional_to_weight(self): def test_select_proportional_to_weight(self):
""" """
Tests that WeightedSelectionFromUniform selects elements, on average, Tests that MultinomialWOReplacementFromUniform selects elements, on average,
proportional to the their probabilities proportional to the their probabilities
""" """
p = tensor.fmatrix() p = tensor.fmatrix()
u = tensor.fvector() u = tensor.fvector()
n = tensor.iscalar() n = tensor.iscalar()
m = multinomial.WeightedSelectionFromUniform('auto')(p, u, n) m = multinomial.MultinomialWOReplacementFromUniform('auto')(p, u, n)
f = function([p, u, n], m, allow_input_downcast=True) f = function([p, u, n], m, allow_input_downcast=True)
...@@ -83,13 +83,13 @@ class test_function(unittest.TestCase): ...@@ -83,13 +83,13 @@ class test_function(unittest.TestCase):
def test_select_distinct(self): def test_select_distinct(self):
""" """
Tests that weighted_selection always selects distinct elements Tests that multinomial_wo_replacement always selects distinct elements
""" """
th_rng = RandomStreams(12345) th_rng = RandomStreams(12345)
p = tensor.fmatrix() p = tensor.fmatrix()
n = tensor.iscalar() n = tensor.iscalar()
m = th_rng.weighted_selection(pvals=p, n=n) m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True) f = function([p, n], m, allow_input_downcast=True)
...@@ -106,14 +106,14 @@ class test_function(unittest.TestCase): ...@@ -106,14 +106,14 @@ class test_function(unittest.TestCase):
def test_fail_select_alot(self): def test_fail_select_alot(self):
""" """
Tests that weighted_selection fails when asked to sample more Tests that multinomial_wo_replacement fails when asked to sample more
elements than the actual number of elements elements than the actual number of elements
""" """
th_rng = RandomStreams(12345) th_rng = RandomStreams(12345)
p = tensor.fmatrix() p = tensor.fmatrix()
n = tensor.iscalar() n = tensor.iscalar()
m = th_rng.weighted_selection(pvals=p, n=n) m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True) f = function([p, n], m, allow_input_downcast=True)
...@@ -126,14 +126,14 @@ class test_function(unittest.TestCase): ...@@ -126,14 +126,14 @@ class test_function(unittest.TestCase):
def test_select_proportional_to_weight(self): def test_select_proportional_to_weight(self):
""" """
Tests that weighted_selection selects elements, on average, Tests that multinomial_wo_replacement selects elements, on average,
proportional to the their probabilities proportional to the their probabilities
""" """
th_rng = RandomStreams(12345) th_rng = RandomStreams(12345)
p = tensor.fmatrix() p = tensor.fmatrix()
n = tensor.iscalar() n = tensor.iscalar()
m = th_rng.weighted_selection(pvals=p, n=n) m = th_rng.multinomial_wo_replacement(pvals=p, n=n)
f = function([p, n], m, allow_input_downcast=True) f = function([p, n], m, allow_input_downcast=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论