提交 c1d9d010 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed inaccuracy in SelfGrad

上级 403d94df
...@@ -126,10 +126,12 @@ class SelfGrad (UpdateGradient): ...@@ -126,10 +126,12 @@ class SelfGrad (UpdateGradient):
This class defines update_gradient (necessary for Grad.bprop) to call a This class defines update_gradient (necessary for Grad.bprop) to call a
self.grad function like this: self.grad function like this:
if len(self.outputs) > 1: passed_inputs = self.inputs
self.grad(self.inputs, [grad_d[o] for o in self.outputs]) if len(self.inputs) == 1: passed_inputs = passed_inputs[0]
else passed_ograds = [grad_d[o] for o in self.outputs]
self.grad(self.inputs, grad_d[output[0]]) if len(self.outputs) == 1: passed_ograds = passed_ograds[0]
igrads = self.grad(passed_inputs, passed_ograds)
if len(self.inputs) == 1: igrads = [igrads]
self.grad() is an Abstract function, see its documentation for the self.grad() is an Abstract function, see its documentation for the
expected behaviour. expected behaviour.
...@@ -139,15 +141,11 @@ class SelfGrad (UpdateGradient): ...@@ -139,15 +141,11 @@ class SelfGrad (UpdateGradient):
def update_gradient(self, grad_d): def update_gradient(self, grad_d):
#Call self.grad(inputs, output_gradients) and add the result to grad_d #Call self.grad(inputs, output_gradients) and add the result to grad_d
if len(self.outputs) > 1: inputgs = gof.utils.from_return_values(
inputgs = self.grad(self.inputs, [grad_d[o] for o in self.outputs]) self.grad(gof.utils.to_return_values(self.inputs),
else: gof.utils.to_return_values([grad_d[o] for o in self.outputs])))
inputgs = self.grad(self.inputs, grad_d[self.outputs[0]])
if len(self.inputs) == 1 and is_result(inputgs):
inputgs = [inputgs]
else:
assert len(inputgs) == len(self.inputs) assert len(inputgs) == len(self.inputs)
for input, inputgrad in zip(self.inputs, inputgs): for input, inputgrad in zip(self.inputs, inputgs):
grad_d.add(input, inputgrad) grad_d.add(input, inputgrad)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论