提交 8f442150 authored 作者: Caglar's avatar Caglar

fixed the multi grad bug.

上级 b46e24dc
......@@ -52,7 +52,11 @@ class MultinomialFromUniform(Op):
def grad(self, ins, outgrads):
pvals, unis, n = ins
(gz,) = outgrads
return [T.zeros_like(x) for x in ins]
if x.dtype in T.discrete_dtypes:
return [T.zeros_like(x, dtype=theano.config.floatX) for x in ins]
else:
return [T.zeros_like(x) for x in ins]
def c_code_cache_version(self):
return (8,)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论