Commit 5dc2f764 authored by Frederic

Automatically have all the theano.gradient functions generated.

Also merge duplicate documentation of grad_sources_inputs.
Parent 6c76c8d3
@@ -16,57 +16,5 @@
 function does the underlying work, and is more flexible, but is also more
 awkward to use when :func:`tensor.grad` can do the job.
 
-.. function:: grad_sources_inputs(sources, graph_inputs, warn_type=True)
-
-    A gradient source is a pair (``v``, ``g_v``), in which ``v`` is
-    a `Variable`, and ``g_v`` is a `Variable` that is a gradient wrt
-    ``v``. More specifically, ``g_v`` is the gradient of an external
-    scalar cost, ``cost`` (that is not explicitly used), wrt ``v``.
-
-    This function traverses the graph backward from the sources,
-    calling ``op.grad(...)`` for all ops with some non-None gradient
-    on an output, to compute gradients of ``cost`` wrt intermediate
-    variables and ``graph_inputs``.
-
-    The ``op.grad(...)`` functions are called like this:
-
-    .. code-block:: python
-
-        op.grad(op.inputs[:], [total_gradient(v) for v in op.outputs])
-
-    This call to ``op.grad`` should return a list or tuple: one symbolic
-    gradient per input. These gradients represent the gradients of
-    the same implicit ``cost`` mentioned above, wrt ``op.inputs``. Note
-    that this is **not** the same as the gradient of ``op.outputs`` wrt
-    ``op.inputs``.
-
-    If ``op`` has a single input, then ``op.grad`` should return a list
-    or tuple of length 1.
-
-    For each input wrt which ``op`` is not differentiable, it should
-    return ``None`` instead of a `Variable` instance.
-
-    If a source ``r`` receives a gradient from another source ``r2``,
-    then the effective gradient on ``r`` is the sum of both gradients.
-
-    :type sources: list of pairs of Variable: (v, gradient-on-v) to
-        initialize the total_gradient dictionary
-    :param sources: gradients to back-propagate using chain rule
-    :type warn_type: bool
-    :param warn_type: True will trigger warnings via the logging module when
-        the gradient on an expression has a different type than the original
-        expression
-    :type graph_inputs: list of Variable
-    :param graph_inputs: variables considered to be constant
-        (do not backpropagate through them)
-    :rtype: dictionary whose keys and values are of type `Variable`
-    :returns: mapping from each Variable encountered in the backward
-        traversal to its [total] gradient.
+.. automodule:: theano.gradient
+    :members:
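The removed text above (its content now lives in the ``theano.gradient`` docstrings picked up by ``automodule``) specifies the ``op.grad`` contract: return one gradient per input, in order, with ``None`` for any input the op is not differentiable with respect to. A minimal sketch of that contract with a toy op — the class name and the use of plain floats instead of symbolic `Variable`\ s are illustrative assumptions, not theano's actual API:

```python
class ScaleByIntOp:
    """Toy op: out = x * k, where k is an integer parameter we treat
    as non-differentiable (e.g. an index or a repeat count)."""

    def __init__(self, x, k):
        self.inputs = [x, k]

    def grad(self, inputs, output_grads):
        x, k = inputs
        g_out, = output_grads  # one incoming gradient per output
        # One entry per input, in order; ``None`` marks the input
        # wrt which this op is not differentiable.
        return [g_out * k, None]


op = ScaleByIntOp(3.0, 2)
# Gradient of the implicit cost wrt x is g_out * k = 1.0 * 2 = 2.0;
# no gradient flows to the integer parameter k.
print(op.grad(op.inputs, [1.0]))  # → [2.0, None]
```

Note that even with a single differentiable input, the contract requires a list or tuple, never a bare gradient value.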
@@ -58,14 +58,50 @@ def format_as(use_list, use_tuple, outputs):
 def grad_sources_inputs(sources, graph_inputs, warn_type=True):
     """
-    :type sources: list of pairs of Variable: (v, gradient-on-v)
+    A gradient source is a pair (``v``, ``g_v``), in which ``v`` is
+    a `Variable`, and ``g_v`` is a `Variable` that is a gradient wrt
+    ``v``. More specifically, ``g_v`` is the gradient of an external
+    scalar cost, ``cost`` (that is not explicitly used), wrt ``v``.
+
+    This function traverses the graph backward from the sources,
+    calling ``op.grad(...)`` for all ops with some non-None gradient
+    on an output, to compute gradients of ``cost`` wrt intermediate
+    variables and ``graph_inputs``.
+
+    The ``op.grad(...)`` functions are called like this:
+
+    .. code-block:: python
+
+        op.grad(op.inputs[:], [total_gradient(v) for v in op.outputs])
+
+    This call to ``op.grad`` should return a list or tuple: one symbolic
+    gradient per input. These gradients represent the gradients of
+    the same implicit ``cost`` mentioned above, wrt ``op.inputs``. Note
+    that this is **not** the same as the gradient of ``op.outputs`` wrt
+    ``op.inputs``.
+
+    If ``op`` has a single input, then ``op.grad`` should return a list
+    or tuple of length 1.
+
+    For each input wrt which ``op`` is not differentiable, it should
+    return ``None`` instead of a `Variable` instance.
+
+    If a source ``r`` receives a gradient from another source ``r2``,
+    then the effective gradient on ``r`` is the sum of both gradients.
+
+    :type sources: list of pairs of Variable: (v, gradient-on-v) to
+        initialize the total_gradient dictionary
     :param sources: gradients to back-propagate using chain rule
     :type graph_inputs: list of Variable
     :param graph_inputs: variables considered to be constant
         (do not backpropagate through them)
+    :type warn_type: bool
+    :param warn_type: True will trigger warnings via the logging module when
+        the gradient on an expression has a different type than the original
+        expression
     :rtype: dictionary whose keys and values are of type Variable
     :return: mapping from each Variable encountered in the backward
         traversal to the gradient with respect to that Variable.
@@ -73,9 +109,6 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
     sources, so that for each v, gradient-on-v is the gradient of J with
     respect to v
     """
     gmap = {}
     for (r, g_r) in sources:
...
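The docstring added above describes the core mechanics of ``grad_sources_inputs``: seed a gradient map from ``(v, g_v)`` pairs, walk the graph backward calling ``op.grad(...)``, and *sum* gradients whenever a variable receives contributions from several sources. A self-contained sketch of that accumulation logic — not theano's implementation: the ``Var``/``AddOp`` classes and the use of floats for gradients are illustrative assumptions, and a real implementation must visit ops in reverse topological order, whereas this sketch assumes a tree-shaped graph:

```python
class Var:
    """Stand-in for a theano Variable: a named graph node."""
    def __init__(self, name, owner=None):
        self.name = name
        self.owner = owner  # the op that produced this Var, or None


class AddOp:
    """Toy op: out = a + b; passes the output gradient to both inputs."""
    def __init__(self, a, b):
        self.inputs = [a, b]
        self.outputs = [Var("(%s+%s)" % (a.name, b.name), owner=self)]

    def grad(self, inputs, output_grads):
        g_out, = output_grads
        return [g_out, g_out]  # d(a+b)/da = d(a+b)/db = 1


def grad_sources_inputs_sketch(sources, graph_inputs):
    """Back-propagate from (v, g_v) pairs, summing gradients whenever a
    variable receives contributions from several sources."""
    gmap = {}
    for v, g_v in sources:
        gmap[v] = gmap.get(v, 0.0) + g_v
    frontier = [v for v in gmap if v.owner is not None]
    while frontier:
        v = frontier.pop()
        op = v.owner
        input_grads = op.grad(op.inputs, [gmap[v]])
        for inp, g in zip(op.inputs, input_grads):
            if g is None:  # op not differentiable wrt this input
                continue
            gmap[inp] = gmap.get(inp, 0.0) + g  # accumulate by summation
            # graph_inputs are constants: record their gradient but
            # do not back-propagate through them.
            if inp.owner is not None and inp not in graph_inputs:
                frontier.append(inp)
    return gmap


x, y = Var("x"), Var("y")
z = AddOp(x, y).outputs[0]
gmap = grad_sources_inputs_sketch([(z, 1.0)], graph_inputs=[x, y])
print(gmap[x], gmap[y])  # → 1.0 1.0

# A variable used twice receives the *sum* of both gradients:
w = AddOp(x, x).outputs[0]
gmap2 = grad_sources_inputs_sketch([(w, 1.0)], graph_inputs=[x])
print(gmap2[x])  # → 2.0
```

The second call shows the summation rule from the docstring: ``x`` feeds both inputs of the add, so it accumulates ``1.0 + 1.0 = 2.0``.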