提交 896cd396 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed OpFromGraph's handling of disconnected inputs

上级 9fdede03
...@@ -55,8 +55,12 @@ class OpFromGraph(gof.Op): ...@@ -55,8 +55,12 @@ class OpFromGraph(gof.Op):
if grad_depth > 0: if grad_depth > 0:
output_grads = [t() for t in self.output_types] output_grads = [t() for t in self.output_types]
# OpFromGraph doesn't implement a connection_pattern, so for now we regard
# all inputs and outputs as connected. This will compute the right numerical
# value for the gradients but could fail to raise the disconnected inputs error
# in some cases.
gs = G.grad(cost=None, known_grads=dict(zip(self.outputs, output_grads)), gs = G.grad(cost=None, known_grads=dict(zip(self.outputs, output_grads)),
wrt=self.inputs) wrt=self.inputs, disconnected_inputs='ignore')
self.grad_ops = [] self.grad_ops = []
for g in gs: for g in gs:
if g is None: if g is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论