Commit 0c116e86 authored by James Bergstra

updated comments to grad functions

Parent ae0027ac
"""Driver for general gradient calculations."""
__docformat__ = "restructuredtext en"
import gof #, gof.variable
import numpy #for numeric_grad
......
@@ -9,32 +13,34 @@ _msg_badlen = 'op.grad(...) returned wrong number of gradients'
def grad_sources_inputs(sources, graph_inputs):
"""
-    A gradient source is a pair (r, g_r), in which r is a variable, and g_r is a
-    variable that is a gradient wrt r.
+    A gradient source is a pair (``r``, ``g_r``), in which ``r`` is a `Variable`, and ``g_r`` is a
+    `Variable` that is a gradient wrt ``r``.

-    This function traverses the graph backward from the 'r' sources,
-    calling L{Op.grad}(...) when it is provided by an L{Op}, and at least one of the
-    outputs of the L{Op} has an associated gradient.
+    This function traverses the graph backward from the ``r`` sources,
+    calling ``op.grad(...)`` for all ops with some non-None gradient on an output.

-    The L{Op.grad}(...) functions are called as such:
-    op.grad( op.inputs[0], grad(op.outputs[0]))
+    The ``op.grad(...)`` functions are called like this:
+
+    .. code-block:: python
+
+        op.grad(op.inputs[:], [total_gradient(v for v in op.outputs)])

-    This function expects the L{Op.grad}(...) function to return the gradient
-    expression [variables] associated with the inputs of the L{Op}. The L{Op} should
-    return a list of variables corresponding to the gradients in the same order
-    as the inputs. If it has a single output it should return a list or tuple
-    of length 1.
+    This call to ``op.grad`` should return a list or tuple: one symbolic gradient per input.
+    If ``op`` has a single input, then ``op.grad`` should return a list or tuple of length 1.

-    For each input wrt to which an L{Op} is not differentiable, it should return
-    None instead of a variable instance.
+    For each input wrt to which ``op`` is not differentiable, it should return ``None`` instead
+    of a `Variable` instance.
+
+    If a source ``r`` receives a gradient from another source ``r2``, then the effective
+    gradient on ``r`` is the sum of both gradients.

-    @type sources: list
-    @param sources: gradient sources (explained below)
-    @type graph_inputs: list
-    @param graph_inputs: variables considered to be constant
-    @rtype: dictionary
-    @return: dictionary mapping each variable necessary for a source to its gradient.
+    :type sources: list of pairs of Variable: (v, gradient-on-v)
+    :param sources: gradients to back-propagate using chain rule
+    :type graph_inputs: list of Variable
+    :param graph_inputs: variables considered to be constant (do not backpropagate through
+        them)
+    :rtype: dictionary whose keys and values are of type `Variable`
+    :return: mapping from each Variable encountered in the backward traversal to its gradient.
"""
gmap = {}
for (r, g_r) in sources:
......
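The new docstring states that when a source ``r`` receives a gradient from another source ``r2``, the effective gradient on ``r`` is the sum of both contributions. That accumulation rule can be sketched with a plain dictionary; `accumulate_gradients` below is an illustrative stand-in for the bookkeeping the `gmap` loop begins, not Theano's actual implementation.

```python
def accumulate_gradients(sources):
    """Sum gradient contributions arriving at the same variable.

    `sources` is a list of (variable, gradient) pairs, mirroring the
    (r, g_r) gradient sources in the docstring above. When the same
    variable appears more than once, its effective gradient is the sum
    of all contributions.
    """
    gmap = {}
    for r, g_r in sources:
        # First contribution initializes the entry; later ones add to it.
        gmap[r] = gmap.get(r, 0.0) + g_r
    return gmap


# Two gradients flow into "r"; their sum is the effective gradient on "r".
result = accumulate_gradients([("r", 1.5), ("r", 0.5), ("r2", 2.0)])
assert result == {"r": 2.0, "r2": 2.0}
```

With symbolic variables the addition would build a sum expression rather than add floats, but the per-variable accumulation pattern is the same.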
......
@@ -2389,20 +2389,24 @@ outer = Outer()
def grad(cost, wrt, g_cost=None, consider_constant=[]):
"""
-    @type cost: L{Variable}
-    @type wrt: L{Variable} or list of L{Variable}s.
-    @type g_cost: L{Variable} broadcastable to size of I{cost}, or None
-    @param g_cost: an expression for the gradient through cost. The default is
-        {{{ones_like(cost)}}}
-    @param consider_constant: a list of expressions not to backpropagate through
-    @rtype: L{Variable} or list of L{Variable}s (depending upon I{wrt})
-    @return: symbolic expression of gradient of I{cost} with respect to I{wrt}.
-        If I{wrt} is a list, then return a list containing the gradient of I{cost} wrt
-        each element of the list. If an element of I{wrt} is not differentiable
-        with respect to the output, then a L{TensorConstant} with an appropriate
+    :type cost: `Variable`
+    :type wrt: `Variable` or list of `Variable`s.
+    :type g_cost: `Variable` broadcastable to size of `cost`, or None
+    :param g_cost: an expression for the gradient through cost. The default is
+        ``ones_like(cost)``.
+    :param consider_constant: a list of expressions not to backpropagate through
+    :rtype: `Variable` or list of `Variable`s (depending upon `wrt`)
+    :return: symbolic expression of gradient of `cost` with respect to `wrt`.
+        If `wrt` is a list, then return a list containing the gradient of `cost` wrt
+        each element of the list. If an element of `wrt` is not differentiable
+        with respect to the output, then a `TensorConstant` with an appropriate
        kind of zero is returned.
+
+    This function is a wrapper around the more general function
+    `theano.gradient.grad_sources_inputs`.
"""
if not isinstance(cost, TensorVariable):
raise TypeError('In tensor.grad(), cost argument should be a TensorVariable.', cost)
......
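The file's header imports numpy "for numeric_grad", i.e. for numerically checking symbolic gradients like the ones `grad` returns. A central-difference check can be sketched as follows; `numeric_gradient` is a hypothetical helper written for illustration, not this module's actual `numeric_grad` API.

```python
import numpy


def numeric_gradient(f, x, eps=1e-6):
    """Approximate the gradient of scalar-valued f at x by central differences.

    Perturbs each coordinate of x by +/- eps and divides the difference
    of the two function values by 2*eps.
    """
    x = numpy.asarray(x, dtype=float)
    g = numpy.zeros_like(x)
    for i in range(x.size):
        step = numpy.zeros_like(x)
        step.flat[i] = eps
        g.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return g


# Check against the analytic gradient of cost(x) = sum(x**2), which is 2*x.
x = numpy.array([1.0, -2.0, 3.0])
approx = numeric_gradient(lambda v: (v ** 2).sum(), x)
assert numpy.allclose(approx, 2 * x, atol=1e-4)
```

In practice one would compare such a finite-difference estimate against the compiled output of the symbolic `grad` expression; agreement within a small tolerance is the usual correctness test for an `Op.grad` implementation.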