提交 c1d9d010 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

fixed inaccuracy in SelfGrad

上级 403d94df
......@@ -126,10 +126,12 @@ class SelfGrad (UpdateGradient):
This class defines update_gradient (necessary for Grad.bprop) to call a
self.grad function like this:
if len(self.outputs) > 1:
self.grad(self.inputs, [grad_d[o] for o in self.outputs])
else
self.grad(self.inputs, grad_d[output[0]])
passed_inputs = self.inputs
if len(self.inputs) == 1: passed_inputs = passed_inputs[0]
passed_ograds = [grad_d[o] for o in self.outputs]
if len(self.outputs) == 1: passed_ograds = passed_ograds[0]
igrads = self.grad(passed_inputs, passed_ograds)
if len(self.inputs) == 1: igrads = [igrads]
self.grad() is an Abstract function, see its documentation for the
expected behaviour.
......@@ -139,15 +141,11 @@ class SelfGrad (UpdateGradient):
def update_gradient(self, grad_d):
#Call self.grad(inputs, output_gradients) and add the result to grad_d
if len(self.outputs) > 1:
inputgs = self.grad(self.inputs, [grad_d[o] for o in self.outputs])
else:
inputgs = self.grad(self.inputs, grad_d[self.outputs[0]])
if len(self.inputs) == 1 and is_result(inputgs):
inputgs = [inputgs]
else:
assert len(inputgs) == len(self.inputs)
inputgs = gof.utils.from_return_values(
self.grad(gof.utils.to_return_values(self.inputs),
gof.utils.to_return_values([grad_d[o] for o in self.outputs])))
assert len(inputgs) == len(self.inputs)
for input, inputgrad in zip(self.inputs, inputgs):
grad_d.add(input, inputgrad)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论