提交 68db88d1 authored 作者: AdeB's avatar AdeB

change the mapping between multinomial_wo_replacement and choice

上级 e69d47e0
......@@ -1446,26 +1446,26 @@ class MRG_RandomStreams(object):
raise NotImplementedError(("MRG_RandomStreams.multinomial only"
" implemented for pvals.ndim = 2"))
def choice(self, size=None, a=2, replace=True, p=None, ndim=None,
def choice(self, size=1, a=None, replace=True, p=None, ndim=None,
dtype='int64', nstreams=None):
"""
Sample `size` times from a multinomial distribution defined by
probabilities `p` and values `a`.
probabilities `p`. Sampled values are between 0 and `p.shape[1]-1`.
Only sampling without replacement is implemented for now.
Parameters
----------
size: integer or None (default None)
the number of samples to be generated. If None, a single sample is
generated.
a: integer or 1d or 2d numpy array or theano tensor
the values of the samples. If a is an integer, the values are
generated between 0 and a-1.
p: 1d or 2d numpy array or theano tensor
the probabilities of the distribution associated to each value of
`a`. It should have the same shape as a (if a is an array/tensor).
size: integer or integer tensor (default 1)
The number of samples. It should be between 1 and `p.shape[1]-1`
a: None
For now, a should be None. This function will sample
values between 0 and `p.shape[1]-1`.
replace: bool (default True)
Whether the sample is with or without replacement.
Only replace=False is implemented for now.
p: 2d numpy array or theano tensor
the probabilities of the distribution, corresponding to values
0 to `p.shape[1]-1`.
Example : p = [[.98, .01, .01], [.01, .49, .50]] and size=1 will
probably result in [[0],[2]]. When setting size=2, this
......@@ -1473,17 +1473,15 @@ class MRG_RandomStreams(object):
Notes
-----
- `size` and `ndim` are only there keep the same signature as other
-`size` and `ndim` are only there keep the same signature as other
uniform, binomial, normal, etc.
- Does not do any value checking on pvals, i.e. there is no
-Does not do any value checking on pvals, i.e. there is no
check that the elements are non-negative, less than 1, or
sum to 1. passing pvals = [[-2., 2.]] will result in
sampling [[0, 0]]
- When `a` and `p` are tensors their shape should be the same.
- Only replace=False is implemented for now.
-Only replace=False is implemented for now.
"""
if replace:
......@@ -1491,40 +1489,37 @@ class MRG_RandomStreams(object):
"MRG_RandomStreams.choice only works without replacement "
"for now.")
if a is not None:
raise TypeError("For now, a has to be None in "
"MRG_RandomStreams.choice. Sampled values are "
"beween 0 and p.shape[1]-1")
if p is None:
raise TypeError("You have to specify p.")
raise TypeError("For now, p has to be specified in "
"MRG_RandomStreams.choice.")
p = as_tensor_variable(p)
if ndim is not None:
raise ValueError("ndim argument to "
"MRG_RandomStreams.multinomial_wo_replacement "
"MRG_RandomStreams.choice "
"is not used.")
if p.ndim == 1:
p = tensor.shape_padleft(p)
if p.ndim != 2:
raise NotImplementedError(
"MRG_RandomStreams.multinomial_wo_replacement only implemented"
" for p.ndim = 1 or p.ndim = 2")
"MRG_RandomStreams.choice is only implemented for p.ndim = 2")
shape = p[:, 0].shape * size
unis = self.uniform(size=shape, ndim=1, nstreams=nstreams)
op = multinomial.MultinomialWOReplacementFromUniform(dtype)
sampled_indices = op(p, unis, as_tensor_variable(size))
if isinstance(a, int):
return sampled_indices
a = tensor.as_tensor_variable(a)
return a[sampled_indices]
return op(p, unis, as_tensor_variable(size))
def multinomial_wo_replacement(self, size=None, n=1, pvals=None,
ndim=None, dtype='int64', nstreams=None):
warnings.warn('MRG_RandomStreams.multinomial_wo_replacement() is '
'deprecated and will be removed in the next release of '
'Theano. Please use MRG_RandomStreams.choice() instead.')
return self.choice(size=n, replace=False, p=pvals, dtype=dtype,
nstreams=nstreams, ndim=ndim)
return self.choice(size=n, a=None, replace=False, p=pvals,
dtype=dtype, nstreams=nstreams, ndim=ndim)
def normal(self, size, avg=0.0, std=1.0, ndim=None,
dtype=None, nstreams=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论