提交 75e573d1 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/gof/fg.py

上级 18c54eca
""" """
fg.py: fg stands for FunctionGraph fg.py: fg stands for FunctionGraph
Contains the FunctionGraph class and exception Contains the FunctionGraph class and exception
types that it can raise types that it can raise.
""" """
from __future__ import print_function from __future__ import print_function
import sys import sys
...@@ -23,10 +24,13 @@ NullType = None ...@@ -23,10 +24,13 @@ NullType = None
class CachedConstantError(Exception): class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant """
that is cached. This should not happen as the user can reuse this An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph. cached constant in other FunctionGraph.
""" """
pass pass
...@@ -34,24 +38,28 @@ class InconsistencyError(Exception): ...@@ -34,24 +38,28 @@ class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to FunctionGraph when the This exception should be thrown by listeners to FunctionGraph when the
graph's state is invalid. graph's state is invalid.
""" """
pass pass
class MissingInputError(Exception): class MissingInputError(Exception):
""" """
A symbolic input needed to compute the outputs is missing. A symbolic input needed to compute the outputs is missing.
""" """
pass pass
class FunctionGraph(utils.object2): class FunctionGraph(utils.object2):
""" WRITEME """
A FunctionGraph represents a subgraph bound by a set of input variables and a WRITEME
set of output variables, ie a subgraph that specifies a theano function. A FunctionGraph represents a subgraph bound by a set of input variables and
The inputs list should contain all the inputs a set of output variables, ie a subgraph that specifies a theano function.
on which the outputs depend. Variables of type Constant are The inputs list should contain all the inputs on which the outputs depend.
not counted as inputs. Variables of type Constant are not counted as inputs.
The FunctionGraph supports the replace operation which allows to replace a The FunctionGraph supports the replace operation which allows to replace a
variable in the subgraph by another, e.g. replace (x + x).out by (2 variable in the subgraph by another, e.g. replace (x + x).out by (2
...@@ -74,28 +82,35 @@ class FunctionGraph(utils.object2): ...@@ -74,28 +82,35 @@ class FunctionGraph(utils.object2):
Historically, the FunctionGraph was called an Env. Keep this in mind Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc. while reading out-of-date documentation, e-mail support threads, etc.
""" The constructor creates a FunctionGraph which operates on the subgraph
bound by the inputs and outputs sets.
def __init__(self, inputs, outputs, features=None, clone=True): This class keeps a pointer to the inputs and outputs, and also modifies
""" them.
Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets.
This class keeps a pointer to the inputs and outputs, and also modifies #TODO: document what variables are[not] set in the FunctionGraph when a
them. feature is added via the constructor. How constructed is the
FunctionGraph?
#TODO: document what variables are[not] set in the FunctionGraph when a feature Parameters
is added via the constructor. How constructed is the FunctionGraph? ----------
inputs
Inputs nodes of the graph, usually declared by the user.
outputs
Outputs nodes of the graph.
clone
If true, we will clone the graph. This is useful to remove the constant
cache problem.
Note: the intermediate nodes between 'inputs' and 'outputs' are not explicitely Notes
passed. -----
The intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
:param inputs: inputs nodes of the graph, usually declared by the user """
:param outputs: outputs nodes of the graph.
:param clone: If true, we will clone the graph. This is def __init__(self, inputs, outputs, features=None, clone=True):
useful to remove the constant cache problem.
"""
if clone: if clone:
inputs, outputs = graph.clone(inputs, outputs) inputs, outputs = graph.clone(inputs, outputs)
...@@ -180,15 +195,17 @@ class FunctionGraph(utils.object2): ...@@ -180,15 +195,17 @@ class FunctionGraph(utils.object2):
# self.execute_callbacks('on_setup_node', node) # self.execute_callbacks('on_setup_node', node)
def disown(self): def disown(self):
""" WRITEME """
Cleans up all of this FunctionGraph's nodes and variables so they are not WRITEME
associated with this FunctionGraph anymore. Cleans up all of this FunctionGraph's nodes and variables so they are
not associated with this FunctionGraph anymore.
The FunctionGraph should not be used anymore after disown is called. The FunctionGraph should not be used anymore after disown is called.
This may not clean everything this FunctionGraph's features set in the This may not clean everything this FunctionGraph's features set in the
nodes and variables. If there are no features, this should set nodes and variables. If there are no features, this should set
them back to what they were originally. them back to what they were originally.
""" """
for apply_node in self.apply_nodes: for apply_node in self.apply_nodes:
del apply_node.fgraph del apply_node.fgraph
...@@ -205,18 +222,25 @@ class FunctionGraph(utils.object2): ...@@ -205,18 +222,25 @@ class FunctionGraph(utils.object2):
def clients(self, r): def clients(self, r):
""" """
Set of all the (node, i) pairs such that node.inputs[i] is r. Set of all the (node, i) pairs such that node.inputs[i] is r.
Tell differently, a list of (node,i) such that each node have Told differently, a list of (node,i) such that each node have
r as input at index i. r as input at index i.
""" """
return r.clients return r.clients
def __add_clients__(self, r, new_clients): def __add_clients__(self, r, new_clients):
""" WRITEME """
r -> variable
new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
Updates the list of clients of r with new_clients. Updates the list of clients of r with new_clients.
WRITEME
Parameters
----------
r
Variable.
new_clients
List of (node, i) pairs such that node.inputs[i] is r.
""" """
if set(r.clients).intersection(set(new_clients)): if set(r.clients).intersection(set(new_clients)):
print('ERROR: clients intersect!', file=sys.stderr) print('ERROR: clients intersect!', file=sys.stderr)
...@@ -229,11 +253,18 @@ class FunctionGraph(utils.object2): ...@@ -229,11 +253,18 @@ class FunctionGraph(utils.object2):
def __remove_clients__(self, r, clients_to_remove, def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None): prune=True, reason=None):
""" WRITEME """
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
Removes all from the clients list of r. Removes all from the clients list of r.
WRITEME
Parameters
----------
r
Variable.
clients_to_remove
List of (op, i) pairs such that node.inputs[i] is not r anymore.
""" """
for entry in clients_to_remove: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
...@@ -286,11 +317,14 @@ class FunctionGraph(utils.object2): ...@@ -286,11 +317,14 @@ class FunctionGraph(utils.object2):
if config.exception_verbosity == 'high': if config.exception_verbosity == 'high':
def find_path_to(output_var, input_var): def find_path_to(output_var, input_var):
""" Returns a list of each variable on a (not necessarily unique) """
path from input_var to output_var, where each variable in the Returns a list of each variable on a (not
list has the preceding variable as one of its inputs. necessarily unique) path from input_var to
Returns None if no path exists""" output_var, where each variable in the list has
the preceding variable as one of its inputs.
Returns None if no path exists.
"""
# If output and input are the same we have a singleton path # If output and input are the same we have a singleton path
if output_var is input_var: if output_var is input_var:
return [output_var] return [output_var]
...@@ -376,12 +410,13 @@ class FunctionGraph(utils.object2): ...@@ -376,12 +410,13 @@ class FunctionGraph(utils.object2):
# prune # # prune #
def __prune_r__(self, variable, reason=None): def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore: """
len(var.clients) == 0 Should be called for variable that aren't used anymore:
len(var.clients) == 0.
This do not mean we will remove it from fgraph.variables. If This do not mean we will remove it from fgraph.variables. If
the owner stay in the fgraph as other outputs are still used, the owner stay in the fgraph as other outputs are still used,
the variable will be stay in fgraph.variables. the variable will stay in fgraph.variables.
""" """
# Prunes the owners of the variables. # Prunes the owners of the variables.
...@@ -409,7 +444,8 @@ class FunctionGraph(utils.object2): ...@@ -409,7 +444,8 @@ class FunctionGraph(utils.object2):
del variable.fgraph del variable.fgraph
def __prune__(self, apply_node, reason=None): def __prune__(self, apply_node, reason=None):
"""Always called on owner of pruned variable from the graph. """
Always called on owner of pruned variable from the graph.
This do not mean we will remove it from the graph. If other This do not mean we will remove it from the graph. If other
outputs are still used, we will keep the node in the graph. outputs are still used, we will keep the node in the graph.
...@@ -433,14 +469,17 @@ class FunctionGraph(utils.object2): ...@@ -433,14 +469,17 @@ class FunctionGraph(utils.object2):
# change input # # change input #
def change_input(self, node, i, new_r, reason=None): def change_input(self, node, i, new_r, reason=None):
"""WRITEME """
Changes node.inputs[i] to new_r. Changes node.inputs[i] to new_r.
WRITEME
new_r.type == old_r.type must be True, where old_r is the new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace. current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls: For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, reason) feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
""" """
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?) # TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output': if node == 'output':
...@@ -478,9 +517,12 @@ class FunctionGraph(utils.object2): ...@@ -478,9 +517,12 @@ class FunctionGraph(utils.object2):
# replace # # replace #
def replace(self, r, new_r, reason=None, verbose=None): def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME """
WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph. This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
""" """
if verbose is None: if verbose is None:
verbose = config.optimizer_verbose verbose = config.optimizer_verbose
...@@ -532,16 +574,19 @@ class FunctionGraph(utils.object2): ...@@ -532,16 +574,19 @@ class FunctionGraph(utils.object2):
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients # print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def replace_all(self, pairs, reason=None): def replace_all(self, pairs, reason=None):
"""WRITEME""" """
WRITEME
"""
for r, new_r in pairs: for r, new_r in pairs:
self.replace(r, new_r, reason=reason) self.replace(r, new_r, reason=reason)
def attach_feature(self, feature): def attach_feature(self, feature):
""" """
Adds a gof.toolbox.Feature to this function_graph Adds a gof.toolbox.Feature to this function_graph and triggers its
and triggers its on_attach callback on_attach callback.
"""
"""
# Filter out literally identical features # Filter out literally identical features
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
...@@ -567,7 +612,9 @@ class FunctionGraph(utils.object2): ...@@ -567,7 +612,9 @@ class FunctionGraph(utils.object2):
self._features.append(feature) self._features.append(feature)
def remove_feature(self, feature): def remove_feature(self, feature):
"""WRITEME """
WRITEME
Removes the feature from the graph. Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method Calls feature.on_detach(function_graph) if an on_detach method
...@@ -585,10 +632,13 @@ class FunctionGraph(utils.object2): ...@@ -585,10 +632,13 @@ class FunctionGraph(utils.object2):
# callback utils # # callback utils #
def execute_callbacks(self, name, *args, **kwargs): def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME """
WRITEME
Calls Calls
getattr(feature, name)(*args) getattr(feature, name)(*args)
for each feature which has a method called after name. for each feature which has a method called after name.
""" """
t0 = time.time() t0 = time.time()
for feature in self._features: for feature in self._features:
...@@ -605,10 +655,13 @@ class FunctionGraph(utils.object2): ...@@ -605,10 +655,13 @@ class FunctionGraph(utils.object2):
self.execute_callbacks_time += time.time() - t0 self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""WRITEME """
WRITEME
Returns a dictionary d such that: Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args) d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name. For each feature which has a method called after name.
""" """
d = {} d = {}
for feature in self._features: for feature in self._features:
...@@ -621,16 +674,19 @@ class FunctionGraph(utils.object2): ...@@ -621,16 +674,19 @@ class FunctionGraph(utils.object2):
# misc # # misc #
def toposort(self): def toposort(self):
"""WRITEME """
Returns an ordering of the graph's Apply nodes such that: WRITEME
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has Return an ordering of the graph's Apply nodes such that:
an 'orderings' method. - All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with If a feature has an 'orderings' method, it will be called with
this FunctionGraph as sole argument. It should return a dictionary of this FunctionGraph as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes {node: predecessors} where predecessors is a list of nodes
that should be computed before the key node. that should be computed before the key node.
""" """
if len(self.apply_nodes) < 2: if len(self.apply_nodes) < 2:
# optimization # optimization
...@@ -652,11 +708,12 @@ class FunctionGraph(utils.object2): ...@@ -652,11 +708,12 @@ class FunctionGraph(utils.object2):
before node itself can be evaluated. before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their all clients of any destroyed inputs have already computed their outputs.
outputs.
:note: This only calls the orderings() fct on all features. It does not Notes
take care of computing dependencies by itself. -----
This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
""" """
ords = OrderedDict() ords = OrderedDict()
...@@ -682,8 +739,11 @@ class FunctionGraph(utils.object2): ...@@ -682,8 +739,11 @@ class FunctionGraph(utils.object2):
return ords return ords
def check_integrity(self): def check_integrity(self):
"""WRITEME """
WRITEME
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
""" """
nodes = graph.ops(self.inputs, self.outputs) nodes = graph.ops(self.inputs, self.outputs)
if self.apply_nodes != nodes: if self.apply_nodes != nodes:
...@@ -740,11 +800,17 @@ class FunctionGraph(utils.object2): ...@@ -740,11 +800,17 @@ class FunctionGraph(utils.object2):
# clone # # clone #
def clone(self, check_integrity=True): def clone(self, check_integrity=True):
"""WRITEME""" """
WRITEME
"""
return self.clone_get_equiv(check_integrity)[0] return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True): def clone_get_equiv(self, check_integrity=True):
"""WRITEME""" """
WRITEME
"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
if check_integrity: if check_integrity:
self.check_integrity() self.check_integrity()
...@@ -757,8 +823,10 @@ class FunctionGraph(utils.object2): ...@@ -757,8 +823,10 @@ class FunctionGraph(utils.object2):
return e, equiv return e, equiv
def __getstate__(self): def __getstate__(self):
"""This is needed as some feature introduce instancemethod and """
this is not picklable. This is needed as some features introduce instance methods.
This is not picklable.
""" """
d = self.__dict__.copy() d = self.__dict__.copy()
for feature in self._features: for feature in self._features:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论