Changed documentation.

Parent bf3df169
@@ -20,14 +20,6 @@ def _pack_result(arg):

 def grad_sources_inputs(sources, graph_inputs):
     """
-    @rtype: dictionary
-    @return: dictionary mapping each result necessary for a source to its gradient.
-    @type sources: list
-    @param sources: gradient sources (explained below)
-    @type graph_inputs: list
-    @param graph_inputs: results considered to be constant
-
     A gradient source is a pair (r, g_r), in which r is a result, and g_r is a
     result that is a gradient wrt r.
@@ -39,16 +31,16 @@ def grad_sources_inputs(sources, graph_inputs):
     convenience of the L{Op} implementer) depending on the number of inputs and
     outputs.

-    If there is one input and one output:
+    If there is one input and one output::
         op.grad( op.inputs[0], grad(op.outputs[0]))

-    If there are several inputs and one output:
+    If there are several inputs and one output::
         op.grad( op.inputs, grad(op.outputs[0]))

-    If there is one input and several outputs:
+    If there is one input and several outputs::
         op.grad( op.inputs[0], [grad(o) for o in op.outputs[0]])

-    If there are multiple inputs and outputs:
+    If there are multiple inputs and outputs::
         op.grad( op.inputs, [grad(o) for o in op.outputs[0]])

     This function expects the L{Op.grad}(...) function to return the gradient
@@ -60,6 +52,13 @@ def grad_sources_inputs(sources, graph_inputs):
     For each input wrt to which an L{Op} is not differentiable, it should return
     None instead of a result instance.

+    @type sources: list
+    @param sources: gradient sources (explained below)
+    @type graph_inputs: list
+    @param graph_inputs: results considered to be constant
+    @rtype: dictionary
+    @return: dictionary mapping each result necessary for a source to its gradient.
     """
     gmap = {}
     for (r, g_r) in sources:
......
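The conventions in this docstring can be sketched with a toy example. This is not the real Theano API: `SquareOp` and `accumulate_sources` are hypothetical names made up for illustration. The first shows an `op.grad(input, gradient_of_output)` call in the one-input/one-output form; the second mimics the accumulation into `gmap` that the diff's final lines begin.

```python
class SquareOp:
    """Toy op computing y = x**2 (hypothetical, for illustration only)."""

    def grad(self, x, g_out):
        # One input, one output: the op receives its single input and the
        # gradient of its single output, matching
        #   op.grad(op.inputs[0], grad(op.outputs[0]))
        # d(x**2)/dx = 2*x, chained with the incoming gradient g_out.
        return 2 * x * g_out


def accumulate_sources(sources):
    """Toy analogue of grad_sources_inputs' accumulation: sources is a list
    of (result, gradient) pairs; returns a dict mapping each result to the
    sum of the gradients contributed for it."""
    gmap = {}
    for (r, g_r) in sources:
        gmap[r] = gmap.get(r, 0) + g_r
    return gmap


print(SquareOp().grad(3.0, 1.0))                      # 6.0
print(accumulate_sources([("x", 2.0), ("x", 3.0)]))   # {'x': 5.0}
```

Summing repeated entries for the same result is what lets several gradient sources contribute to one result, which is why the real function returns a dictionary rather than a list.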