提交 f6537797 authored 作者: Frederic's avatar Frederic

Cache OpFromGraph sub-gradient computation. This removes duplicate op creation.

上级 6d87357d
...@@ -84,18 +84,24 @@ class OpFromGraph(gof.Op): ...@@ -84,18 +84,24 @@ class OpFromGraph(gof.Op):
# compute the right numerical value for the gradients but # compute the right numerical value for the gradients but
# could fail to raise the disconnected inputs error in some # could fail to raise the disconnected inputs error in some
# cases. # cases.
gs = G.grad(cost=None, known_grads=dict(zip(self.outputs, output_grads)), if hasattr(self, "grad_ops"):
wrt=self.inputs, disconnected_inputs='ignore') grad_ops = self.grad_ops
grad_ops = [] else:
for g in gs: gs = G.grad(cost=None,
if g is None: known_grads=dict(zip(self.outputs, output_grads)),
grad_ops.append(lambda *args: None) wrt=self.inputs, disconnected_inputs='ignore')
else:
# It is normal if some inputs are not needed in order grad_ops = []
# to compute the gradient, so we ignore them. for g in gs:
grad_ops.append(OpFromGraph(self.inputs + output_grads, if g is None:
[g], grad_ops.append(lambda *args: None)
on_unused_input='ignore')) else:
# It is normal if some inputs are not needed in order
# to compute the gradient, so we ignore them.
grad_ops.append(OpFromGraph(self.inputs + output_grads,
[g],
on_unused_input='ignore'))
self.grad_ops = grad_ops
return [go(*(inputs + output_grads)) for go in grad_ops] return [go(*(inputs + output_grads)) for go in grad_ops]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论