提交 f6537797 authored 作者: Frederic's avatar Frederic

Cache the OpFromGraph sub-gradient computation. This removes duplicate op creation.

上级 6d87357d
...@@ -84,8 +84,13 @@ class OpFromGraph(gof.Op): ...@@ -84,8 +84,13 @@ class OpFromGraph(gof.Op):
# compute the right numerical value for the gradients but # compute the right numerical value for the gradients but
# could fail to raise the disconnected inputs error in some # could fail to raise the disconnected inputs error in some
# cases. # cases.
gs = G.grad(cost=None, known_grads=dict(zip(self.outputs, output_grads)), if hasattr(self, "grad_ops"):
grad_ops = self.grad_ops
else:
gs = G.grad(cost=None,
known_grads=dict(zip(self.outputs, output_grads)),
wrt=self.inputs, disconnected_inputs='ignore') wrt=self.inputs, disconnected_inputs='ignore')
grad_ops = [] grad_ops = []
for g in gs: for g in gs:
if g is None: if g is None:
...@@ -96,6 +101,7 @@ class OpFromGraph(gof.Op): ...@@ -96,6 +101,7 @@ class OpFromGraph(gof.Op):
grad_ops.append(OpFromGraph(self.inputs + output_grads, grad_ops.append(OpFromGraph(self.inputs + output_grads,
[g], [g],
on_unused_input='ignore')) on_unused_input='ignore'))
self.grad_ops = grad_ops
return [go(*(inputs + output_grads)) for go in grad_ops] return [go(*(inputs + output_grads)) for go in grad_ops]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论