提交 52bc1c46 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix bug

上级 6d6dcb46
...@@ -19,7 +19,6 @@ from theano.tensor import (raw_random, TensorType, as_tensor_variable, ...@@ -19,7 +19,6 @@ from theano.tensor import (raw_random, TensorType, as_tensor_variable,
from theano.tensor import sqrt, log, sin, cos, join, prod from theano.tensor import sqrt, log, sin, cos, join, prod
from theano.compile import optdb from theano.compile import optdb
from theano.gof import local_optimizer from theano.gof import local_optimizer
from theano.scalar import constant
from . import multinomial from . import multinomial
from theano.sandbox.cuda import cuda_available, cuda_enabled, GpuOp from theano.sandbox.cuda import cuda_available, cuda_enabled, GpuOp
...@@ -1318,12 +1317,11 @@ class MRG_RandomStreams(object): ...@@ -1318,12 +1317,11 @@ class MRG_RandomStreams(object):
def multinomial(self, size=None, n=1, pvals=None, ndim=None, dtype='int64', def multinomial(self, size=None, n=1, pvals=None, ndim=None, dtype='int64',
nstreams=None): nstreams=None):
""" """
Sample `n` (currently `n` needs to be > 1) times from a multinomial Sample `n` (`n` needs to be >= 1) times from a multinomial
distribution defined by probabilities pvals. distribution defined by probabilities pvals.
TODO: MODIFY_ME Example : pvals = [[.98, .01, .01], [.01, .49, .50]] and n=2 will
Example : pvals = [[.98, .01, .01], [.01, .98, .01]] will probably result in [[2,0,0],[0,1,1]].
probably result in [[1,0,0],[0,1,0]].
Notes Notes
----- -----
...@@ -1355,11 +1353,11 @@ class MRG_RandomStreams(object): ...@@ -1355,11 +1353,11 @@ class MRG_RandomStreams(object):
"MRG_RandomStreams.multinomial, which does not use " "MRG_RandomStreams.multinomial, which does not use "
"the ndim argument.") "the ndim argument.")
if pvals.ndim == 2: if pvals.ndim == 2:
size = pvals[:,0].shape size = pvals[:,0].shape * n
unis = self.uniform(size=size, ndim=1, nstreams=nstreams) unis = self.uniform(size=size, ndim=1, nstreams=nstreams)
op = multinomial.MultinomialFromUniform(dtype) op = multinomial.MultinomialFromUniform(dtype)
n_samples = as_tensor_variable(n) n_samples = as_tensor_variable(n)
return op(pvals, unis, n_samples) return op(pvals, unis, n_samples)
else: else:
raise NotImplementedError(("MRG_RandomStreams.multinomial only" raise NotImplementedError(("MRG_RandomStreams.multinomial only"
" implemented for pvals.ndim = 2")) " implemented for pvals.ndim = 2"))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论