提交 f6537797 authored 作者: Frederic's avatar Frederic

Cache the OpFromGraph sub-gradient computation. This removes duplicate op creation.

上级 6d87357d
......@@ -84,18 +84,24 @@ class OpFromGraph(gof.Op):
# compute the right numerical value for the gradients but
# could fail to raise the disconnected inputs error in some
# cases.
gs = G.grad(cost=None, known_grads=dict(zip(self.outputs, output_grads)),
wrt=self.inputs, disconnected_inputs='ignore')
grad_ops = []
for g in gs:
if g is None:
grad_ops.append(lambda *args: None)
else:
# It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them.
grad_ops.append(OpFromGraph(self.inputs + output_grads,
[g],
on_unused_input='ignore'))
if hasattr(self, "grad_ops"):
grad_ops = self.grad_ops
else:
gs = G.grad(cost=None,
known_grads=dict(zip(self.outputs, output_grads)),
wrt=self.inputs, disconnected_inputs='ignore')
grad_ops = []
for g in gs:
if g is None:
grad_ops.append(lambda *args: None)
else:
# It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them.
grad_ops.append(OpFromGraph(self.inputs + output_grads,
[g],
on_unused_input='ignore'))
self.grad_ops = grad_ops
return [go(*(inputs + output_grads)) for go in grad_ops]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论