Commit 4336f227 authored by bergstrj@iro.umontreal.ca

gradient.py written and tested

-import gof
+import gof, gof.result
+
+_msg_retNone = 'op.grad(...) returned None, consider returning [None]'
+_msg_badlen = 'op.grad(...) returned wrong number of gradients'
 
 def _unpack_result(lst):
     if len(lst) > 1:
         return lst
-    else
+    else:
         return lst[0]
 
 def _pack_result(arg):
-    if gof.result.is_result(arg): return [arg]
+    if isinstance(arg, gof.result.ResultBase):
+        return [arg]
     else:
         return arg
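These two helpers normalize between a single result and a list of results at the op.grad boundary. Below is a minimal standalone sketch of the same pack/unpack convention; the Result class here is a stand-in for gof.result.ResultBase, not the real gof API.

# Standalone sketch of the pack/unpack convention used above.
# `Result` is a stand-in for gof.result.ResultBase.

class Result:
    def __init__(self, name):
        self.name = name

def _unpack_result(lst):
    # A length-1 list collapses to its single element; longer lists pass through.
    if len(lst) > 1:
        return lst
    else:
        return lst[0]

def _pack_result(arg):
    # A bare Result is wrapped in a list; anything else (e.g. a list) passes through.
    if isinstance(arg, Result):
        return [arg]
    else:
        return arg

x = Result('x')
assert _pack_result(x) == [x]            # single result -> singleton list
assert _unpack_result([x]) is x          # singleton list -> single result
assert _unpack_result([x, x]) == [x, x]  # longer lists are left alone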
-def grad_sources_inputs(sources, inputs):
+def grad_sources_inputs(sources, graph_inputs):
     """Return a dictionary mapping each result necessary for a source to its gradient
 
     sources - a list of gradient sources (explained below)
-    inputs - a list of results considered to be constant
+    graph_inputs - a list of results considered to be constant
 
     A gradient source is a pair (r, g_r), in which r is a result, and g_r is a
     result that is a gradient wrt r.
@@ -49,33 +54,37 @@ def grad_sources_inputs(sources, inputs):
     None instead of a result instance.
     """
     gmap = {}
-    for (r, g_r) in self.sources:
+    for (r, g_r) in sources:
         if g_r is not None:
             if r in gmap:
-                gmap[r] = gmap[r] + dr
+                gmap[r] = gmap[r] + g_r
             else:
-                gmap[r] = dr
+                gmap[r] = g_r
 
-    outputs = gmap.keys()
-    if inputs is None:
-        inputs = gof.graph.inputs(outputs)
+    graph_outputs = gmap.keys()
+    if graph_inputs is None:
+        graph_inputs = gof.graph.inputs(graph_outputs)
 
-    for op in gof.graph.io_toposort(inputs, outputs).__reversed__():
-        g_outputs = [gmap[o] for o in self.outputs]
-        if all(map(lambda x: x is None, g_outputs)):
-            continue
+    for op in gof.graph.io_toposort(graph_inputs, graph_outputs).__reversed__():
+        g_outputs = [gmap.get(o, None) for o in op.outputs]
+        # if all output gradients are None, continue
+        if all(map(lambda x: x is None, g_outputs)): continue
 
-        output_arg = unpack_singleton(g_outputs)
-        input_arg = unpack_singleton(op.inputs)
+        output_arg = _unpack_result(g_outputs)
+        input_arg = _unpack_result(op.inputs)
         op_grad = op.grad(input_arg, output_arg)
         if op_grad is None:
-            raise Exception('If you really mean for grad(...) to return None, '
-                            'please return [None]', op.__class__)
-        g_inputs = pack_singleton(op_grad)
-        assert len(g_inputs) == len(op.inputs)
-        for r, g_r in zip(self.inputs, g_inputs):
+            raise ValueError(_msg_retNone, op.__class__)
+        g_inputs = _pack_result(op_grad)
+        if len(g_inputs) != len(op.inputs):
+            raise ValueError(_msg_badlen,
+                             op.__class__,
+                             len(g_inputs),
+                             len(op.inputs))
+        for r, g_r in zip(op.inputs, g_inputs):
             if g_r is not None:
                 if r in gmap:
                     gmap[r] = gmap[r] + g_r
@@ -83,17 +92,16 @@ def grad_sources_inputs(sources, inputs):
                 else:
                     gmap[r] = g_r
     return gmap
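grad_sources_inputs seeds gmap with the given (result, gradient) pairs, then sweeps the ops in reverse topological order, asking each op for its input gradients and summing contributions whenever a result feeds more than one consumer. The following self-contained toy mirrors that accumulation loop; the Node class, the ops/val dictionaries, and the hand-written ordering are illustrative stand-ins, not the gof API.

# Toy reverse-mode accumulation over a tiny expression graph.
class Node:
    """One op application. grad_fn maps (inputs, g_out) to a list of
    input gradients, one per input (None = no gradient)."""
    def __init__(self, inputs, grad_fn):
        self.inputs = inputs
        self.grad_fn = grad_fn

# Leaf variables are plain strings here; cost = (x * y) + x.
x, y = 'x', 'y'
mul = Node([x, y], lambda inp, g: [g * val[inp[1]], g * val[inp[0]]])
add = Node(['mul', x], lambda inp, g: [g, g])
ops = {'mul': mul, 'cost': add}         # output name -> producing op
val = {'x': 2.0, 'y': 3.0, 'mul': 6.0}  # point at which to evaluate

# Seed the gradient map with the source, then sweep ops in reverse
# topological order, mirroring the loop in grad_sources_inputs.
gmap = {'cost': 1.0}
for out in ['cost', 'mul']:             # reverse topological order, hand-written
    g_out = gmap.get(out, None)
    if g_out is None:
        continue
    op = ops[out]
    for r, g_r in zip(op.inputs, op.grad_fn(op.inputs, g_out)):
        if g_r is not None:
            # Accumulate: a result used on two paths gets the sum of both.
            gmap[r] = gmap.get(r, 0.0) + g_r

print(gmap['x'], gmap['y'])  # d(x*y+x)/dx = y + 1 = 4.0, d/dy = x = 2.0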
-def diff(cost, param):
+def grad(cost, param):
     """Return symbolic expression of gradient of <cost> wrt <param>.
 
     If <param> is a list, then return a list containing the gradient of cost wrt
     each element of the list.
     """
     inputs = gof.graph.inputs([cost])
     gmap = grad_sources_inputs([(cost, 1.0)], inputs)
-    if isinstance(param, lst):
-        return [gmap[p] for p in param]
+    if isinstance(param, list):
+        return [gmap.get(p, None) for p in param]
     else:
-        return gmap[param]
+        return gmap.get(param, None)
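The switch from gmap[param] to gmap.get(param, None) changes the behavior when cost does not depend on param: missing entries now come back as None instead of raising KeyError. A short illustration (the dictionary contents here are made up):

gmap = {'w': 4.0}                 # pretend grad_sources_inputs returned this
params = ['w', 'unused_bias']     # 'unused_bias' never appears in the graph
print([gmap.get(p, None) for p in params])   # -> [4.0, None]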