提交 d893ae53 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove the Multinomial Op to MultinomialFromUniform, that's what it does.

上级 b22677b0
......@@ -9,7 +9,8 @@ if cuda_available:
from theano.sandbox.cuda import CudaNdarrayType
from theano.sandbox.cuda.basic_ops import host_from_gpu, gpu_from_host
class Multinomial(Op):
class MultinomialFromUniform(Op):
'''Converts samples from a uniform into sample from a multinomial.'''
def __init__(self, odtype):
self.odtype=odtype
def __eq__(self, other):
......@@ -91,7 +92,7 @@ class Multinomial(Op):
const int nb_outcomes = %(pvals)s->dimensions[1];
//
// For each multinomials, loop over each possible outcome
// For each multinomial, loop over each possible outcome
//
for (int n = 0; n < nb_multi; ++n)
{
......@@ -117,10 +118,9 @@ class Multinomial(Op):
}
} // END NESTED SCOPE
""" % locals()
#multinomial = Multinomial()
class GpuMultinomial(Multinomial):
class GpuMultinomialFromUniform(MultinomialFromUniform):
def make_node(self, pvals, unis):
assert pvals.dtype == 'float32'
......@@ -134,7 +134,9 @@ class GpuMultinomial(Multinomial):
else:
odtype = self.odtype
if odtype != pvals.dtype:
raise NotImplementedError('GpuMultinomial works only if self.odtype == pvals.dtype', odtype, pvals.dtype)
raise NotImplementedError(
'GpuMultinomialFromUniform works only if '
'self.odtype == pvals.dtype', odtype, pvals.dtype)
return Apply(self, [pvals, unis], [pvals.type()])
def c_code_cache_version(self):
......@@ -281,11 +283,11 @@ class GpuMultinomial(Multinomial):
@local_optimizer()
def use_gpu_multinomial(node):
if node.op == multinomial:
if isinstance(node.op, MultinomialFromUniform):
p, u = node.inputs
m, = node.outputs
if p.dtype == u.dtype == m.dtype == 'float32':
gpu_op = GpuMultinomial(op.odtype)
gpu_op = GpuMultinomialFromUniform(op.odtype)
return [host_from_gpu(gpu_op(*[gpu_from_host(i) for i in node.inputs]))]
if cuda_enabled:#theano.config.device.startswith('gpu'):
register_specialize(use_gpu_multinomial)
......@@ -759,7 +759,7 @@ class MRG_RandomStreams(object):
assert ndim==1
bcast = bcast+(pvals.type.broadcastable[-1],)
unis = self.uniform(size=size, ndim=1)
op = multinomial.Multinomial(dtype)
op = multinomial.MultinomialFromUniform(dtype)
return op(pvals, unis)
else:
raise NotImplementedError(("MRG_RandomStreams.multinomial only"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论