提交 6c88dd3d authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #6355 from yikangshen/MRG_add_target_parameter

MRG distribution, add a target parameter
......@@ -805,7 +805,7 @@ class MRG_RandomStreams(object):
return sample
def uniform(self, size, low=0.0, high=1.0, ndim=None, dtype=None,
nstreams=None):
nstreams=None, **kwargs):
# TODO : need description for parameter 'size', 'ndim', 'nstreams'
"""
Sample a tensor of given size whose element from a uniform
......@@ -865,7 +865,12 @@ class MRG_RandomStreams(object):
nstreams = self.n_streams(size)
rstates = self.get_substream_rstates(nstreams, dtype)
node_rstate = shared(rstates)
d = {}
if 'target' in kwargs:
d = dict(target=kwargs.pop('target'))
if len(kwargs) > 0:
raise TypeError("uniform() got unexpected keyword arguements %s" % (str(kwargs.keys())))
node_rstate = shared(rstates, **d)
u = self.pretty_return(node_rstate,
*mrg_uniform.new(node_rstate,
ndim, dtype, size),
......@@ -883,17 +888,17 @@ class MRG_RandomStreams(object):
return r
def binomial(self, size=None, n=1, p=0.5, ndim=None, dtype='int64',
nstreams=None):
nstreams=None, **kwargs):
# TODO : need description for method, parameter and return
if n == 1:
p = undefined_grad(as_tensor_variable(p))
x = self.uniform(size=size, nstreams=nstreams)
x = self.uniform(size=size, nstreams=nstreams, **kwargs)
return cast(x < p, dtype)
else:
raise NotImplementedError("MRG_RandomStreams.binomial with n > 1")
def multinomial(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None):
nstreams=None, **kwargs):
# TODO : need description for parameter and return
"""
Sample `n` (`n` needs to be >= 1, default 1) times from a multinomial
......@@ -935,7 +940,7 @@ class MRG_RandomStreams(object):
"which does not use the ndim argument.")
if pvals.ndim == 2:
size = pvals[:, 0].shape * n
unis = self.uniform(size=size, ndim=1, nstreams=nstreams)
unis = self.uniform(size=size, ndim=1, nstreams=nstreams, **kwargs)
op = multinomial.MultinomialFromUniform(dtype)
n_samples = as_tensor_variable(n)
return op(pvals, unis, n_samples)
......@@ -944,7 +949,7 @@ class MRG_RandomStreams(object):
" implemented for pvals.ndim = 2"))
def choice(self, size=1, a=None, replace=True, p=None, ndim=None,
dtype='int64', nstreams=None):
dtype='int64', nstreams=None, **kwargs):
"""
Sample `size` times from a multinomial distribution defined by
probabilities `p`, and returns the indices of the sampled elements.
......@@ -1011,18 +1016,18 @@ class MRG_RandomStreams(object):
"MRG_RandomStreams.choice is only implemented for p.ndim = 2")
shape = p[:, 0].shape * size
unis = self.uniform(size=shape, ndim=1, nstreams=nstreams)
unis = self.uniform(size=shape, ndim=1, nstreams=nstreams, **kwargs)
op = multinomial.ChoiceFromUniform(odtype=dtype)
return op(p, unis, as_tensor_variable(size))
def multinomial_wo_replacement(self, size=None, n=1, pvals=None,
ndim=None, dtype='int64', nstreams=None):
ndim=None, dtype='int64', nstreams=None, **kwargs):
warnings.warn('MRG_RandomStreams.multinomial_wo_replacement() is '
'deprecated and will be removed in the next release of '
'Theano. Please use MRG_RandomStreams.choice() instead.')
assert size is None
return self.choice(size=n, a=None, replace=False, p=pvals,
dtype=dtype, nstreams=nstreams, ndim=ndim)
dtype=dtype, nstreams=nstreams, ndim=ndim, **kwargs)
def normal(self, size, avg=0.0, std=1.0, ndim=None,
dtype=None, nstreams=None):
......
......@@ -753,6 +753,21 @@ def test_f16_nonzero(mode=None, op_to_check=rng_mrg.mrg_uniform):
assert np.all((0 < m_val) & (m_val < 1))
def test_target_parameter():
srng = MRG_RandomStreams()
pvals = np.array([[.98, .01, .01], [.01, .49, .50]])
def basic_target_parameter_test(x):
f = theano.function([], x)
assert isinstance(f(), np.ndarray)
basic_target_parameter_test(srng.uniform((3, 2), target='cpu'))
basic_target_parameter_test(srng.binomial((3, 2), target='cpu'))
basic_target_parameter_test(srng.multinomial(pvals=pvals.astype('float32'), target='cpu'))
basic_target_parameter_test(srng.choice(p=pvals.astype('float32'), replace=False, target='cpu'))
basic_target_parameter_test(srng.multinomial_wo_replacement(pvals=pvals.astype('float32'), target='cpu'))
if __name__ == "__main__":
rng = MRG_RandomStreams(np.random.randint(2147462579))
print(theano.__file__)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论