added support class update_gradient_via_grad

Parent 53e523b0
...@@ -144,6 +144,35 @@ def grad(cost, param=None, cost_grad = 1.0):
        else:
            return rval(param)
class update_gradient_via_grad:
    """Inherit from this class to add a convenient self.update_gradient function."""
    def update_gradient(self, grad_d):
        """Call self.grad() and add the result to grad_d.

        This function is called by grad.Grad.bprop() to construct a symbolic
        gradient graph. self.grad is called like this:

            self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))

        In general, grad() should return a list of PythonR instances whose
        length matches that of self.inputs, and whose elements are the
        gradients of self.inputs.

        There is a small (but often used) special feature in place to
        automatically wrap the return value of grad() in a list if it is a
        PythonR instance and the op is unary. This makes many grad
        implementations a little cuter.
        """
        inputgs = self.grad(*(self.inputs + [grad_d[output] for output in self.outputs]))
        if len(self.inputs) == 1 and isinstance(inputgs, gof.PythonR):
            inputgs = [inputgs]
        else:
            assert len(inputgs) == len(self.inputs)
        for input, inputg in zip(self.inputs, inputgs):
            grad_d.add(input, inputg)
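The list-wrapping behaviour described in the docstring can be sketched in isolation. Note that PythonR, GradDict, and the Neg op below are simplified hypothetical stand-ins, not the real gof/grad classes:

```python
class PythonR:
    """Stand-in for gof.PythonR: wraps a symbolic result."""
    def __init__(self, name):
        self.name = name

class GradDict:
    """Stand-in for grad_d: maps results to accumulated gradients."""
    def __init__(self):
        self.d = {}
    def __getitem__(self, r):
        return self.d[r]
    def add(self, r, g):
        self.d.setdefault(r, []).append(g)

class update_gradient_via_grad:
    def update_gradient(self, grad_d):
        inputgs = self.grad(*(self.inputs + [grad_d[o] for o in self.outputs]))
        # Special case: a unary op may return a bare PythonR instead of a list.
        if len(self.inputs) == 1 and isinstance(inputgs, PythonR):
            inputgs = [inputgs]
        else:
            assert len(inputgs) == len(self.inputs)
        for input, inputg in zip(self.inputs, inputgs):
            grad_d.add(input, inputg)

class Neg(update_gradient_via_grad):
    """Unary op whose grad() returns a bare PythonR; the mixin wraps it."""
    def __init__(self, x, out):
        self.inputs = [x]
        self.outputs = [out]
    def grad(self, x, gout):
        return PythonR("-(%s)" % gout.name)  # not a list

x, out = PythonR("x"), PythonR("out")
gd = GradDict()
gd.d[out] = PythonR("gout")          # gradient flowing into the output
Neg(x, out).update_gradient(gd)
print(gd.d[x][0].name)               # prints -(gout)
```

Without the unary special case, Neg.grad() would have to return `[PythonR(...)]`; the mixin lets one-input ops skip the list.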
#
# UNIT TEST
...