提交 18c54eca authored 作者: Iban Harlouchet's avatar Iban Harlouchet

numpydoc for theano/gof/destroyhandler.py

上级 5da33fd1
""" """
Classes and functions for validating graphs that contain view Classes and functions for validating graphs that contain view
and inplace operations. and inplace operations.
""" """
from collections import deque from collections import deque
...@@ -17,35 +18,41 @@ from six.moves.queue import Queue ...@@ -17,35 +18,41 @@ from six.moves.queue import Queue
class ProtocolError(Exception): class ProtocolError(Exception):
"""Raised when FunctionGraph calls DestroyHandler callbacks in """
Raised when FunctionGraph calls DestroyHandler callbacks in
an invalid way, for example, pruning or changing a node that has an invalid way, for example, pruning or changing a node that has
never been imported. never been imported.
""" """
pass pass
def _contains_cycle(fgraph, orderings): def _contains_cycle(fgraph, orderings):
""" """
fgraph - the FunctionGraph to check for cycles Parameters
----------
fgraph
The FunctionGraph to check for cycles.
orderings
Dictionary specifying extra dependencies besides those encoded in
Variable.owner / Apply.inputs.
orderings - dictionary specifying extra dependencies besides If orderings[my_apply] == dependencies, then my_apply is an Apply
those encoded in Variable.owner / Apply.inputs instance, dependencies is a set of Apply instances, and every member
of dependencies must be executed before my_apply.
If orderings[my_apply] == dependencies,
then my_apply is an Apply instance,
dependencies is a set of Apply instances,
and every member of dependencies must be executed
before my_apply.
The dependencies are typically used to prevent The dependencies are typically used to prevent
inplace apply nodes from destroying their input before inplace apply nodes from destroying their input before
other apply nodes with the same input access it. other apply nodes with the same input access it.
Returns True if the graph contains a cycle, False otherwise. Returns
""" -------
bool
True if the graph contains a cycle, False otherwise.
"""
# These are lists of Variable instances # These are lists of Variable instances
outputs = fgraph.outputs outputs = fgraph.outputs
...@@ -227,10 +234,15 @@ def _build_droot_impact(destroy_handler): ...@@ -227,10 +234,15 @@ def _build_droot_impact(destroy_handler):
def fast_inplace_check(inputs): def fast_inplace_check(inputs):
""" Return the variables in inputs that are posible candidate for as inputs of inplace operation """
Return the variables in inputs that are posible candidate for as inputs of
inplace operation.
Parameters
----------
inputs : list
Inputs Variable that you want to use as inplace destination.
:type inputs: list
:param inputs: inputs Variable that you want to use as inplace destination
""" """
fgraph = inputs[0].fgraph fgraph = inputs[0].fgraph
Supervisor = theano.compile.function_module.Supervisor Supervisor = theano.compile.function_module.Supervisor
...@@ -249,38 +261,42 @@ if 0: ...@@ -249,38 +261,42 @@ if 0:
# old, non-incremental version of the DestroyHandler # old, non-incremental version of the DestroyHandler
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper):
""" """
The DestroyHandler class detects when a graph is impossible to evaluate because of The DestroyHandler class detects when a graph is impossible to evaluate
aliasing and destructive operations. because of aliasing and destructive operations.
Several data structures are used to do this. Several data structures are used to do this.
When an Op uses its view_map property to declare that an output may be aliased When an Op uses its view_map property to declare that an output may be
to an input, then if that output is destroyed, the input is also considering to be aliased to an input, then if that output is destroyed, the input is also
destroyed. The view_maps of several Ops can feed into one another and form a directed graph. considering to be destroyed. The view_maps of several Ops can feed into
The consequence of destroying any variable in such a graph is that all variables in the graph one another and form a directed graph. The consequence of destroying any
must be considered to be destroyed, because they could all be refering to the same variable in such a graph is that all variables in the graph must be
underlying storage. In the current implementation, that graph is a tree, and the root of considered to be destroyed, because they could all be refering to the
that tree is called the foundation. The `droot` property of this class maps from every same underlying storage. In the current implementation, that graph is a
graph variable to its foundation. The `impact` property maps backward from the foundation tree, and the root of that tree is called the foundation. The `droot`
to all of the variables that depend on it. When any variable is destroyed, this class marks property of this class maps from every graph variable to its foundation.
the foundation of that variable as being destroyed, with the `root_destroyer` property. The `impact` property maps backward from the foundation to all of the
variables that depend on it. When any variable is destroyed, this class
marks the foundation of that variable as being destroyed, with the
`root_destroyer` property.
""" """
droot = {} droot = {}
""" """
destroyed view + nonview variables -> foundation destroyed view + nonview variables -> foundation.
"""
impact = {}
""" """
destroyed nonview variable -> it + all views of it impact = {}
""" """
destroyed nonview variable -> it + all views of it.
root_destroyer = {}
""" """
root -> destroyer apply root_destroyer = {}
""" """
root -> destroyer apply.
"""
def __init__(self, do_imports_on_attach=True): def __init__(self, do_imports_on_attach=True):
self.fgraph = None self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach self.do_imports_on_attach = do_imports_on_attach
...@@ -295,8 +311,8 @@ if 0: ...@@ -295,8 +311,8 @@ if 0:
compilation to be slower. compilation to be slower.
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
"""
"""
# Do the checking # # Do the checking #
already_there = False already_there = False
if self.fgraph not in [None, fgraph]: if self.fgraph not in [None, fgraph]:
...@@ -363,8 +379,10 @@ if 0: ...@@ -363,8 +379,10 @@ if 0:
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """
Add Apply instance to set which must be computed.
"""
# if app in self.debug_all_apps: raise ProtocolError("double import") # if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app) # self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) # print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...@@ -395,7 +413,10 @@ if 0: ...@@ -395,7 +413,10 @@ if 0:
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app, reason): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """
Remove Apply instance from set which must be computed.
"""
# if app not in self.debug_all_apps: raise ProtocolError("prune without import") # if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app) # self.debug_all_apps.remove(app)
...@@ -427,7 +448,10 @@ if 0: ...@@ -427,7 +448,10 @@ if 0:
self.stale_droot = True self.stale_droot = True
def on_change_input(self, fgraph, app, i, old_r, new_r, reason): def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""app.inputs[i] changed from old_r to new_r """ """
app.inputs[i] changed from old_r to new_r.
"""
if app == 'output': if app == 'output':
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
...@@ -466,14 +490,14 @@ if 0: ...@@ -466,14 +490,14 @@ if 0:
self.stale_droot = True self.stale_droot = True
def validate(self, fgraph): def validate(self, fgraph):
"""Return None """
Return None.
Raise InconsistencyError when Raise InconsistencyError when
a) orderings() raises an error a) orderings() raises an error
b) orderings cannot be topologically sorted. b) orderings cannot be topologically sorted.
""" """
if self.destroyers: if self.destroyers:
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
...@@ -487,7 +511,8 @@ if 0: ...@@ -487,7 +511,8 @@ if 0:
return True return True
def orderings(self, fgraph): def orderings(self, fgraph):
"""Return orderings induced by destructive operations. """
Return orderings induced by destructive operations.
Raise InconsistencyError when Raise InconsistencyError when
a) attempting to destroy indestructable variable, or a) attempting to destroy indestructable variable, or
...@@ -637,6 +662,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -637,6 +662,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
The following data structures remain to be converted: The following data structures remain to be converted:
<unknown> <unknown>
""" """
pickle_rm_attr = ["destroyers"] pickle_rm_attr = ["destroyers"]
...@@ -644,29 +670,38 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -644,29 +670,38 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.fgraph = None self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach self.do_imports_on_attach = do_imports_on_attach
"""maps every variable in the graph to its "foundation" (deepest """
ancestor in view chain) Maps every variable in the graph to its "foundation" (deepest
TODO: change name to var_to_vroot""" ancestor in view chain).
TODO: change name to var_to_vroot.
"""
self.droot = OrderedDict() self.droot = OrderedDict()
"""maps a variable to all variables that are indirect or direct views of it """
(including itself) Maps a variable to all variables that are indirect or direct views of it
essentially the inverse of droot (including itself) essentially the inverse of droot.
TODO: do all variables appear in this dict, or only those that are foundations? TODO: do all variables appear in this dict, or only those that are
TODO: do only destroyed variables go in here? one old docstring said so foundations?
TODO: rename to x_to_views after reverse engineering what x is""" TODO: do only destroyed variables go in here? one old docstring said so.
TODO: rename to x_to_views after reverse engineering what x is
"""
self.impact = OrderedDict() self.impact = OrderedDict()
"""if a var is destroyed, then this dict will map """
If a var is destroyed, then this dict will map
droot[var] to the apply node that destroyed var droot[var] to the apply node that destroyed var
TODO: rename to vroot_to_destroyer""" TODO: rename to vroot_to_destroyer
"""
self.root_destroyer = OrderedDict() self.root_destroyer = OrderedDict()
def on_attach(self, fgraph): def on_attach(self, fgraph):
""" """
When attaching to a new fgraph, check that When attaching to a new fgraph, check that
1) This DestroyHandler wasn't already attached to some fgraph 1) This DestroyHandler wasn't already attached to some fgraph
(its data structures are only set up to serve one) (its data structures are only set up to serve one).
2) The FunctionGraph doesn't already have a DestroyHandler. 2) The FunctionGraph doesn't already have a DestroyHandler.
This would result in it validating everything twice, causing This would result in it validating everything twice, causing
compilation to be slower. compilation to be slower.
...@@ -676,6 +711,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -676,6 +711,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
TODO: what does this do exactly? TODO: what does this do exactly?
2) A new attribute, "destroy_handler" 2) A new attribute, "destroy_handler"
TODO: WRITEME: what does this do besides the checks? TODO: WRITEME: what does this do besides the checks?
""" """
# Do the checking # # Do the checking #
...@@ -723,9 +759,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -723,9 +759,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
def refresh_droot_impact(self): def refresh_droot_impact(self):
""" """
Makes sure self.droot, self.impact, and self.root_destroyer are Makes sure self.droot, self.impact, and self.root_destroyer are up to
up to date, and returns them. date, and returns them (see docstrings for these properties above).
(see docstrings for these properties above)
""" """
if self.stale_droot: if self.stale_droot:
self.droot, self.impact, self.root_destroyer =\ self.droot, self.impact, self.root_destroyer =\
...@@ -747,7 +783,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -747,7 +783,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.fgraph = None self.fgraph = None
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """
Add Apply instance to set which must be computed.
"""
if app in self.debug_all_apps: if app in self.debug_all_apps:
raise ProtocolError("double import") raise ProtocolError("double import")
...@@ -780,7 +819,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -780,7 +819,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.stale_droot = True self.stale_droot = True
def on_prune(self, fgraph, app, reason): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """
Remove Apply instance from set which must be computed.
"""
if app not in self.debug_all_apps: if app not in self.debug_all_apps:
raise ProtocolError("prune without import") raise ProtocolError("prune without import")
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
...@@ -814,7 +856,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -814,7 +856,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.stale_droot = True self.stale_droot = True
def on_change_input(self, fgraph, app, i, old_r, new_r, reason): def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""app.inputs[i] changed from old_r to new_r """ """
app.inputs[i] changed from old_r to new_r.
"""
if app == 'output': if app == 'output':
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being # app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
...@@ -854,14 +899,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -854,14 +899,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.stale_droot = True self.stale_droot = True
def validate(self, fgraph): def validate(self, fgraph):
"""Return None """
Return None.
Raise InconsistencyError when Raise InconsistencyError when
a) orderings() raises an error a) orderings() raises an error
b) orderings cannot be topologically sorted. b) orderings cannot be topologically sorted.
""" """
if self.destroyers: if self.destroyers:
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
...@@ -882,7 +927,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -882,7 +927,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
return True return True
def orderings(self, fgraph): def orderings(self, fgraph):
"""Return orderings induced by destructive operations. """
Return orderings induced by destructive operations.
Raise InconsistencyError when Raise InconsistencyError when
a) attempting to destroy indestructable variable, or a) attempting to destroy indestructable variable, or
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论