提交 9a0208da authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Replace dy with zeros if disconnected instead of dhy[-1].

上级 dd9de351
...@@ -2238,16 +2238,17 @@ class GpuDnnRNNOp(DnnBase): ...@@ -2238,16 +2238,17 @@ class GpuDnnRNNOp(DnnBase):
dcy = output_grads[3] if len(output_grads) == 4 else None dcy = output_grads[3] if len(output_grads) == 4 else None
# Since the op return two outputs which contain essentially # Since the op return two outputs which contain essentially
# the same information, the user will most likely only use one # the same information, the user will most likely only use one
# of them. This leads to the situation that the other is # of them. This leads to the situation that the other is
# considered "disconnected" by theano in the gradient. # considered "disconnected" by theano in the gradient.
# However we know that this isn't really the case so we fix it # However we know that this isn't really the case so we fix it
# up here. # here.
# If both dy and dhy are disconnected the fixup will fail, but # If all the ys are disconnected, then you get a boring
# that's ok as in that case we really are disconnected. # gradient instead of an error. But in that case you
# shouldn't call this method anyway.
if isinstance(dy.type, DisconnectedType): if isinstance(dy.type, DisconnectedType):
dy = as_gpuarray_variable(dhy[-1], dy = as_gpuarray_variable(y.zeros_like(),
context_name=dhy.type.context_name) context_name=y.type.context_name)
if isinstance(dhy.type, DisconnectedType): if isinstance(dhy.type, DisconnectedType):
dhy = None dhy = None
if dcy and isinstance(dcy.type, DisconnectedType): if dcy and isinstance(dcy.type, DisconnectedType):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论