提交 52bc1c46 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix bug

上级 6d6dcb46
......@@ -19,7 +19,6 @@ from theano.tensor import (raw_random, TensorType, as_tensor_variable,
from theano.tensor import sqrt, log, sin, cos, join, prod
from theano.compile import optdb
from theano.gof import local_optimizer
from theano.scalar import constant
from . import multinomial
from theano.sandbox.cuda import cuda_available, cuda_enabled, GpuOp
......@@ -1318,12 +1317,11 @@ class MRG_RandomStreams(object):
def multinomial(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None):
"""
Sample `n` (currently `n` needs to be > 1) times from a multinomial
Sample `n` (`n` needs to be >= 1) times from a multinomial
distribution defined by probabilities pvals.
TODO: MODIFY_ME
Example : pvals = [[.98, .01, .01], [.01, .98, .01]] will
probably result in [[1,0,0],[0,1,0]].
Example : pvals = [[.98, .01, .01], [.01, .49, .50]] and n=2 will
probably result in [[2,0,0],[0,1,1]].
Notes
-----
......@@ -1355,7 +1353,7 @@ class MRG_RandomStreams(object):
"MRG_RandomStreams.multinomial, which does not use "
"the ndim argument.")
if pvals.ndim == 2:
size = pvals[:,0].shape
size = pvals[:,0].shape * n
unis = self.uniform(size=size, ndim=1, nstreams=nstreams)
op = multinomial.MultinomialFromUniform(dtype)
n_samples = as_tensor_variable(n)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论