提交 6150143e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Minor changes.

上级 1ccf33c5
...@@ -22,24 +22,24 @@ class Multinomial(Op): ...@@ -22,24 +22,24 @@ class Multinomial(Op):
self.__dict__.update(dct) self.__dict__.update(dct)
try: try:
self.odtype self.odtype
except: except AttributeError:
self.odtype='auto' self.odtype='auto'
def make_node(self, pvals, unis): def make_node(self, pvals, unis):
pvals = T.as_tensor_variable(pvals) pvals = T.as_tensor_variable(pvals)
unis = T.as_tensor_variable(unis) unis = T.as_tensor_variable(unis)
if pvals.ndim != 2: if pvals.ndim != 2:
raise NotImplementedError('pvals ndim', pvals.ndim) raise NotImplementedError('pvals ndim should be 2', pvals.ndim)
if unis.ndim != 1: if unis.ndim != 1:
raise NotImplementedError('unis ndim', unis.ndim) raise NotImplementedError('unis ndim should be 1', unis.ndim)
if self.odtype=='auto': if self.odtype=='auto':
odtype = pvals.dtype odtype = pvals.dtype
else: else:
odtype = self.odtype odtype = self.odtype
return Apply(self, [pvals, unis], [T.matrix(dtype=odtype)]) return Apply(self, [pvals, unis], [T.matrix(dtype=odtype)])
def grad(self, ins, outs): def grad(self, ins, outgrads):
pvals, unis = ins pvals, unis = ins
(gz,) = outs (gz,) = outgrads
return [None, None] return [None, None]
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -134,7 +134,7 @@ class GpuMultinomial(Multinomial): ...@@ -134,7 +134,7 @@ class GpuMultinomial(Multinomial):
else: else:
odtype = self.odtype odtype = self.odtype
if odtype != pvals.dtype: if odtype != pvals.dtype:
raise NotImplementedError() raise NotImplementedError('GpuMultinomial works only if self.odtype == pvals.dtype', odtype, pvals.dtype)
return Apply(self, [pvals, unis], [pvals.type()]) return Apply(self, [pvals, unis], [pvals.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论