提交 afffb65c authored 作者: Frederic's avatar Frederic

Fix EighGrad that didn't returned the good dtype when the input is float32.

This was an error in the buildbot.
上级 93645004
...@@ -1033,7 +1033,18 @@ class EighGrad(Op): ...@@ -1033,7 +1033,18 @@ class EighGrad(Op):
def make_node(self, x, w, v, gw, gv): def make_node(self, x, w, v, gw, gv):
x, w, v, gw, gv = map(as_tensor_variable, (x, w, v, gw, gv)) x, w, v, gw, gv = map(as_tensor_variable, (x, w, v, gw, gv))
return Apply(self, [x, w, v, gw, gv], [x.type()]) assert x.ndim == 2
assert w.ndim == 1
assert v.ndim == 2
assert gw.ndim == 1
assert gv.ndim == 2
if x.dtype == "float32":
# The call to self.tri0 in perform upcast from float32 to
# float64.
out = theano.tensor.matrix(dtype="float64")
else:
out = x.type()
return Apply(self, [x, w, v, gw, gv], [out])
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
r""" r"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论