提交 a7d91eea authored 作者: Ian Goodfellow's avatar Ian Goodfellow

merged

.. _developer:
==============================================
Theano Design and Implementation Documentation
==============================================
.. toctree::
:maxdepth: 2
tensor
.. _tensor:
=======
Tensor
=======
This file describes the design of theano.tensor.
Elemwise grad and R_op
======================
Here's another straightforward example, though a bit more elaborate
than adding two numbers together. Let's say that you want to compute
the logistic curve, which is given by:
.. math::
s(x) = \frac{1}{1 + e^{-x}}
......@@ -78,7 +78,8 @@ Roughly in order of what you'll want to check out:
* :ref:`libdoc` -- Theano's functionality, module by module.
* :ref:`optimizations` -- Guide to Theano's graph optimizations.
* :ref:`extending` -- Learn to add a Type, Op, or graph optimization.
* :ref:`internal` -- How to maintain Theano, LISA-specific tips, and more...
* :ref:`developer` -- Primarily of interest to developers of Theano
* :ref:`internal` -- How to maintain Theano, LISA-specific tips, and more...
* :ref:`release` -- How our release should work.
You can download the latest `PDF documentation <http://deeplearning.net/software/theano/theano.pdf>`_, rather than reading it online.
......
......@@ -36,7 +36,14 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
them)
:rtype: dictionary whose keys and values are of type `Variable`
:return: mapping from each Variable encountered in the backward traversal to its gradient.
:return: mapping from each Variable encountered in the backward traversal to the gradient with respect to that Variable.
It is assumed that there is some objective J shared between all members of
sources, so that for each v, gradient-on-v is the gradient of J with respect to v
"""
gmap = {}
for (r, g_r) in sources:
......@@ -78,7 +85,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
else:
new_input_arg.append(input)
input_arg = new_input_arg
#note that this function is not in a try-except block
# the rationale:
# If the op implements grad, then any exception should be passed to the
......@@ -93,8 +100,8 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
g_inputs = op_grad
assert isinstance(g_inputs, (list, tuple))
if len(g_inputs) != len(node.inputs):
raise ValueError(_msg_badlen,
node.op,
raise ValueError(_msg_badlen,
node.op,
len(g_inputs),
len(node.inputs))
for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)):
......@@ -106,7 +113,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
node.op, g_r_type, ii, r_type))
if g_r and len(sources) == 1 and sources[0][0].name and r.name:
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None:
if g_r is not None:
assert r is not None
if r in gmap:
gmap[r] = gmap[r] + g_r
......@@ -125,3 +132,14 @@ def unimplemented_grad(op, x_pos, x):
"""
msg = '%s.grad not implemented for input %i'%(op, x_pos)
return Raise(msg=msg)(x)
class GradientUndefined(Exception): pass
def undefined_grad(op, x_pos, x):
    """Build a graph node that raises at runtime if the gradient is used.

    Mirrors ``unimplemented_grad`` above, but for gradients that are
    mathematically undefined rather than merely not implemented.

    :param op: the Op whose gradient is undefined (currently unused in the
        message; kept for interface symmetry with ``unimplemented_grad``)
    :param x_pos: position of the input whose gradient is undefined
        (currently unused in the message)
    :param x: the input Variable; the returned node is applied to it

    NOTE(review): raises ``RuntimeError`` even though ``GradientUndefined``
    is defined just above — presumably one of the two was meant; verify
    intent with the author before changing the raised type.
    """
    error_message = "Undefined gradient - do not use in computations"
    raiser = Raise(msg=error_message, exc=RuntimeError)
    return raiser(x)
def grad(self, inputs, out_storage):
    # NOTE(review): fragment from the commit diff — appears to be an example
    # usage of `undefined_grad` for an Op whose second input has no defined
    # gradient. `g_x0` is not defined in this view; presumably the gradient
    # with respect to inputs[0], computed earlier in the full method body.
    # Verify against the complete file before relying on this.
    return [g_x0, undefined_grad(self, 1, inputs[1])]
......@@ -77,6 +77,16 @@ def constant(x):
class Scalar(Type):
"""
Internal class, should not be used by clients
Primarily used by tensor.elemwise and tensor.reduce
Analogous to TensorType, but for zero-dimensional objects
Maps directly to C primitives
TODO: refactor to be named ScalarType for consistency with TensorType
"""
def __init__(self, dtype):
if dtype == 'floatX':
dtype = config.floatX
......
......@@ -537,7 +537,10 @@ class Elemwise(Op):
def grad(self, inputs, ograds):
# Gradients (especially on the final costs) don't have to be symbolic
# e.g., ograds will be [ 1. ] if your objective is c and the output
# of the current apply node is c
ograds = map(as_tensor_variable, ograds)
scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs]
scalar_ograds = [Scalar(dtype = ograd.type.dtype)() for ograd in ograds]
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论