提交 75e573d1 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/gof/fg.py

上级 18c54eca
"""
fg.py: fg stands for FunctionGraph
Contains the FunctionGraph class and exception
types that it can raise
types that it can raise.
"""
from __future__ import print_function
import sys
......@@ -23,10 +24,13 @@ NullType = None
class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
"""
An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
"""
pass
......@@ -34,24 +38,28 @@ class InconsistencyError(Exception):
"""
This exception should be thrown by listeners to FunctionGraph when the
graph's state is invalid.
"""
pass
class MissingInputError(Exception):
"""
A symbolic input needed to compute the outputs is missing.
"""
pass
class FunctionGraph(utils.object2):
""" WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and a
set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs
on which the outputs depend. Variables of type Constant are
not counted as inputs.
"""
WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs on which the outputs depend.
Variables of type Constant are not counted as inputs.
The FunctionGraph supports the replace operation which allows to replace a
variable in the subgraph by another, e.g. replace (x + x).out by (2
......@@ -74,28 +82,35 @@ class FunctionGraph(utils.object2):
Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc.
"""
The constructor creates a FunctionGraph which operates on the subgraph
bound by the inputs and outputs sets.
def __init__(self, inputs, outputs, features=None, clone=True):
"""
Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets.
This class keeps a pointer to the inputs and outputs, and also modifies
them.
This class keeps a pointer to the inputs and outputs, and also modifies
them.
#TODO: document what variables are[not] set in the FunctionGraph when a
feature is added via the constructor. How constructed is the
FunctionGraph?
#TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the FunctionGraph?
Parameters
----------
inputs
Inputs nodes of the graph, usually declared by the user.
outputs
Outputs nodes of the graph.
clone
If true, we will clone the graph. This is useful to remove the constant
cache problem.
Note: the intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
Notes
-----
The intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
:param inputs: inputs nodes of the graph, usually declared by the user
:param outputs: outputs nodes of the graph.
:param clone: If true, we will clone the graph. This is
useful to remove the constant cache problem.
"""
def __init__(self, inputs, outputs, features=None, clone=True):
"""
if clone:
inputs, outputs = graph.clone(inputs, outputs)
......@@ -180,15 +195,17 @@ class FunctionGraph(utils.object2):
# self.execute_callbacks('on_setup_node', node)
def disown(self):
""" WRITEME
Cleans up all of this FunctionGraph's nodes and variables so they are not
associated with this FunctionGraph anymore.
"""
WRITEME
Cleans up all of this FunctionGraph's nodes and variables so they are
not associated with this FunctionGraph anymore.
The FunctionGraph should not be used anymore after disown is called.
This may not clean everything this FunctionGraph's features set in the
nodes and variables. If there are no features, this should set
them back to what they were originally.
"""
for apply_node in self.apply_nodes:
del apply_node.fgraph
......@@ -205,18 +222,25 @@ class FunctionGraph(utils.object2):
def clients(self, r):
"""
Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have
Told differently, a list of (node,i) such that each node have
r as input at index i.
"""
return r.clients
def __add_clients__(self, r, new_clients):
""" WRITEME
r -> variable
new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
"""
Updates the list of clients of r with new_clients.
WRITEME
Parameters
----------
r
Variable.
new_clients
List of (node, i) pairs such that node.inputs[i] is r.
"""
if set(r.clients).intersection(set(new_clients)):
print('ERROR: clients intersect!', file=sys.stderr)
......@@ -229,11 +253,18 @@ class FunctionGraph(utils.object2):
def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None):
""" WRITEME
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
"""
Removes all from the clients list of r.
WRITEME
Parameters
----------
r
Variable.
clients_to_remove
List of (op, i) pairs such that node.inputs[i] is not r anymore.
"""
for entry in clients_to_remove:
r.clients.remove(entry)
......@@ -286,11 +317,14 @@ class FunctionGraph(utils.object2):
if config.exception_verbosity == 'high':
def find_path_to(output_var, input_var):
""" Returns a list of each variable on a (not necessarily unique)
path from input_var to output_var, where each variable in the
list has the preceding variable as one of its inputs.
Returns None if no path exists"""
"""
Returns a list of each variable on a (not
necessarily unique) path from input_var to
output_var, where each variable in the list has
the preceding variable as one of its inputs.
Returns None if no path exists.
"""
# If output and input are the same we have a singleton path
if output_var is input_var:
return [output_var]
......@@ -376,12 +410,13 @@ class FunctionGraph(utils.object2):
# prune #
def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
"""
Should be called for variable that aren't used anymore:
len(var.clients) == 0.
This do not mean we will remove it from fgraph.variables. If
the owner stay in the fgraph as other outputs are still used,
the variable will be stay in fgraph.variables.
the variable will stay in fgraph.variables.
"""
# Prunes the owners of the variables.
......@@ -409,7 +444,8 @@ class FunctionGraph(utils.object2):
del variable.fgraph
def __prune__(self, apply_node, reason=None):
"""Always called on owner of pruned variable from the graph.
"""
Always called on owner of pruned variable from the graph.
This do not mean we will remove it from the graph. If other
outputs are still used, we will keep the node in the graph.
......@@ -433,14 +469,17 @@ class FunctionGraph(utils.object2):
# change input #
def change_input(self, node, i, new_r, reason=None):
"""WRITEME
"""
Changes node.inputs[i] to new_r.
WRITEME
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output':
......@@ -478,9 +517,12 @@ class FunctionGraph(utils.object2):
# replace #
def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME
"""
WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
"""
if verbose is None:
verbose = config.optimizer_verbose
......@@ -532,16 +574,19 @@ class FunctionGraph(utils.object2):
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def replace_all(self, pairs, reason=None):
"""WRITEME"""
"""
WRITEME
"""
for r, new_r in pairs:
self.replace(r, new_r, reason=reason)
def attach_feature(self, feature):
"""
Adds a gof.toolbox.Feature to this function_graph
and triggers its on_attach callback
"""
Adds a gof.toolbox.Feature to this function_graph and triggers its
on_attach callback.
"""
# Filter out literally identical features
if feature in self._features:
return # the feature is already present
......@@ -567,7 +612,9 @@ class FunctionGraph(utils.object2):
self._features.append(feature)
def remove_feature(self, feature):
"""WRITEME
"""
WRITEME
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method
......@@ -585,10 +632,13 @@ class FunctionGraph(utils.object2):
# callback utils #
def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME
"""
WRITEME
Calls
getattr(feature, name)(*args)
for each feature which has a method called after name.
"""
t0 = time.time()
for feature in self._features:
......@@ -605,10 +655,13 @@ class FunctionGraph(utils.object2):
self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args):
"""WRITEME
"""
WRITEME
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
"""
d = {}
for feature in self._features:
......@@ -621,16 +674,19 @@ class FunctionGraph(utils.object2):
# misc #
def toposort(self):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
"""
WRITEME
Return an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with
this FunctionGraph as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
if len(self.apply_nodes) < 2:
# optimization
......@@ -652,11 +708,12 @@ class FunctionGraph(utils.object2):
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their
outputs.
all clients of any destroyed inputs have already computed their outputs.
:note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
Notes
-----
This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
"""
ords = OrderedDict()
......@@ -682,8 +739,11 @@ class FunctionGraph(utils.object2):
return ords
def check_integrity(self):
"""WRITEME
"""
WRITEME
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs)
if self.apply_nodes != nodes:
......@@ -740,11 +800,17 @@ class FunctionGraph(utils.object2):
# clone #
def clone(self, check_integrity=True):
"""WRITEME"""
"""
WRITEME
"""
return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True):
"""WRITEME"""
"""
WRITEME
"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs)
if check_integrity:
self.check_integrity()
......@@ -757,8 +823,10 @@ class FunctionGraph(utils.object2):
return e, equiv
def __getstate__(self):
"""This is needed as some feature introduce instancemethod and
this is not picklable.
"""
This is needed as some features introduce instance methods.
This is not picklable.
"""
d = self.__dict__.copy()
for feature in self._features:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论