提交 6150143e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Minor changes.

上级 1ccf33c5
......@@ -22,24 +22,24 @@ class Multinomial(Op):
self.__dict__.update(dct)
try:
self.odtype
except:
except AttributeError:
self.odtype='auto'
def make_node(self, pvals, unis):
pvals = T.as_tensor_variable(pvals)
unis = T.as_tensor_variable(unis)
if pvals.ndim != 2:
raise NotImplementedError('pvals ndim', pvals.ndim)
raise NotImplementedError('pvals ndim should be 2', pvals.ndim)
if unis.ndim != 1:
raise NotImplementedError('unis ndim', unis.ndim)
raise NotImplementedError('unis ndim should be 1', unis.ndim)
if self.odtype=='auto':
odtype = pvals.dtype
else:
odtype = self.odtype
return Apply(self, [pvals, unis], [T.matrix(dtype=odtype)])
def grad(self, ins, outs):
def grad(self, ins, outgrads):
pvals, unis = ins
(gz,) = outs
(gz,) = outgrads
return [None, None]
def c_code_cache_version(self):
......@@ -134,7 +134,7 @@ class GpuMultinomial(Multinomial):
else:
odtype = self.odtype
if odtype != pvals.dtype:
raise NotImplementedError()
raise NotImplementedError('GpuMultinomial works only if self.odtype == pvals.dtype', odtype, pvals.dtype)
return Apply(self, [pvals, unis], [pvals.type()])
def c_code_cache_version(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论