Commit 0c116e86 authored by James Bergstra

updated comments to grad functions

Parent ae0027ac
"""Driver for general gradient calculations."""
__docformat__ = "restructuredtext en"
import gof #, gof.variable
import numpy #for numeric_grad
@@ -9,32 +13,34 @@ _msg_badlen = 'op.grad(...) returned wrong number of gradients'
def grad_sources_inputs(sources, graph_inputs):
    """
    A gradient source is a pair (``r``, ``g_r``), in which ``r`` is a `Variable`,
    and ``g_r`` is a `Variable` that is a gradient wrt ``r``.

    This function traverses the graph backward from the ``r`` sources,
    calling ``op.grad(...)`` for all ops with some non-None gradient on an output.

    The ``op.grad(...)`` functions are called like this:

    .. code-block:: python

        op.grad(op.inputs[:], [total_gradient(v) for v in op.outputs])

    This call to ``op.grad`` should return a list or tuple: one symbolic gradient
    per input. If ``op`` has a single input, then ``op.grad`` should return a list
    or tuple of length 1. For each input wrt to which ``op`` is not differentiable,
    it should return ``None`` instead of a `Variable` instance.

    If a source ``r`` receives a gradient from another source ``r2``, then the
    effective gradient on ``r`` is the sum of both gradients.

    :type sources: list of pairs of Variable: (v, gradient-on-v)
    :param sources: gradients to back-propagate using chain rule
    :type graph_inputs: list of Variable
    :param graph_inputs: variables considered to be constant (do not backpropagate
        through them)

    :rtype: dictionary whose keys and values are of type `Variable`
    :return: mapping from each Variable encountered in the backward traversal to
        its gradient.
    """
    gmap = {}
    for (r, g_r) in sources:
...
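The summing rule stated in the docstring above can be illustrated with a short, self-contained sketch. This is not Theano's actual implementation: `gmap` here merely mirrors the dictionary the function builds, and `accumulate` is a hypothetical helper introduced only for illustration.

```python
# Sketch of the accumulation rule: when a variable receives gradients
# from several paths, its effective gradient is the SUM of them,
# collected in a dict keyed by variable (here, by name).
gmap = {}

def accumulate(r, g_r):
    # add gradient g_r to the running total for variable r
    gmap[r] = gmap.get(r, 0.0) + g_r

# toy cost = a*b + a : the variable `a` receives two gradient contributions
a, b = 2.0, 5.0
accumulate('a', b)    # d(a*b)/da = b = 5.0
accumulate('b', a)    # d(a*b)/db = a = 2.0
accumulate('a', 1.0)  # d(a)/da   = 1.0
print(gmap['a'])      # 6.0
```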
@@ -2389,20 +2389,24 @@ outer = Outer()
def grad(cost, wrt, g_cost=None, consider_constant=[]):
    """
    :type cost: `Variable`
    :type wrt: `Variable` or list of `Variable`s
    :type g_cost: `Variable` broadcastable to size of `cost`, or None
    :param g_cost: an expression for the gradient through cost. The default is
        ``ones_like(cost)``.
    :param consider_constant: a list of expressions not to backpropagate through
    :rtype: `Variable` or list of `Variable`s (depending upon `wrt`)

    :return: symbolic expression of gradient of `cost` with respect to `wrt`.
        If `wrt` is a list, then return a list containing the gradient of `cost`
        wrt each element of the list. If an element of `wrt` is not differentiable
        with respect to the output, then a `TensorConstant` with an appropriate
        kind of zero is returned.

    This function is a wrapper around the more general function
    `theano.gradient.grad_sources_inputs`.
    """
    if not isinstance(cost, TensorVariable):
        raise TypeError('In tensor.grad(), cost argument should be a TensorVariable.', cost)
    ...
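The module imports numpy "for numeric_grad", i.e. for finite-difference checking of symbolic gradients. For readers without a working Theano install, the contract of `grad` can be sanity-checked numerically; the standalone `numeric_grad` below is a hypothetical helper written for this sketch, not the module's own class:

```python
import numpy

def numeric_grad(f, x, eps=1e-6):
    # central-difference approximation of df/dx at scalar x
    return (f(x + eps) - f(x - eps)) / (2.0 * eps)

# for cost = x**2, grad(cost, x) should behave like 2*x; check at x = 3.0
g = numeric_grad(lambda x: x ** 2, 3.0)
print(numpy.allclose(g, 6.0, atol=1e-4))  # True
```

In practice this kind of check is run against the compiled symbolic gradient at several random points, and the two are required to agree to within a small tolerance.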