提交 7f3bfb23 authored 作者: James Bergstra's avatar James Bergstra

added WRITEME many places

上级 77c55988
...@@ -7,7 +7,6 @@ from graph import Result, Apply ...@@ -7,7 +7,6 @@ from graph import Result, Apply
from op import Op from op import Op
from opt import * from opt import *
from ext import *
import destroyhandler import destroyhandler
from env import Env, InconsistencyError from env import Env, InconsistencyError
from toolbox import ReplaceValidate from toolbox import ReplaceValidate
......
...@@ -68,7 +68,7 @@ def get_compiledir(): ...@@ -68,7 +68,7 @@ def get_compiledir():
class CodeBlock: class CodeBlock:
""" """WRITEME
Represents a computation unit composed of declare, behavior, and cleanup. Represents a computation unit composed of declare, behavior, and cleanup.
@ivar declare: C code that declares variables for use by the computation @ivar declare: C code that declares variables for use by the computation
@ivar behavior: C code that performs the computation @ivar behavior: C code that performs the computation
...@@ -94,11 +94,12 @@ class CodeBlock: ...@@ -94,11 +94,12 @@ class CodeBlock:
def failure_code(sub): def failure_code(sub):
"""WRITEME"""
return "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub return "{%(failure_var)s = %(id)s; goto __label_%(id)i;}" % sub
def code_gen(blocks): def code_gen(blocks):
""" """WRITEME
From a list of L{CodeBlock} instances, returns a string that executes them From a list of L{CodeBlock} instances, returns a string that executes them
all in sequence. eg for C{(decl1, task1, cleanup1)} and C{(decl2, task2, cleanup2)} all in sequence. eg for C{(decl1, task1, cleanup1)} and C{(decl2, task2, cleanup2)}
the returned string will be of the form:: the returned string will be of the form::
...@@ -126,7 +127,7 @@ def code_gen(blocks): ...@@ -126,7 +127,7 @@ def code_gen(blocks):
def struct_gen(args, struct_builders, blocks, sub): def struct_gen(args, struct_builders, blocks, sub):
""" """WRITEME
Generates a struct conforming to the following specifications: Generates a struct conforming to the following specifications:
* args -> all of the PyObject* type, stored in the struct * args -> all of the PyObject* type, stored in the struct
they represent the storage and must be length 1 python lists. they represent the storage and must be length 1 python lists.
...@@ -253,16 +254,18 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -253,16 +254,18 @@ def struct_gen(args, struct_builders, blocks, sub):
# with handling of the py_<name> variable. # with handling of the py_<name> variable.
def get_nothing(r, name, sub): def get_nothing(r, name, sub):
"" """WRITEME"""
return "" return ""
def get_c_declare(r, name, sub): def get_c_declare(r, name, sub):
"""WRITEME"""
pre = """ pre = """
PyObject* py_%(name)s; PyObject* py_%(name)s;
""" % locals() """ % locals()
return pre + r.type.c_declare(name, sub) return pre + r.type.c_declare(name, sub)
def get_c_init(r, name, sub): def get_c_init(r, name, sub):
"""WRITEME"""
pre = "" """ pre = "" """
py_%(name)s = Py_None; py_%(name)s = Py_None;
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
...@@ -270,6 +273,7 @@ def get_c_init(r, name, sub): ...@@ -270,6 +273,7 @@ def get_c_init(r, name, sub):
return pre + r.type.c_init(name, sub) return pre + r.type.c_init(name, sub)
def get_c_extract(r, name, sub): def get_c_extract(r, name, sub):
"""WRITEME"""
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
Py_XINCREF(py_%(name)s); Py_XINCREF(py_%(name)s);
...@@ -277,12 +281,14 @@ def get_c_extract(r, name, sub): ...@@ -277,12 +281,14 @@ def get_c_extract(r, name, sub):
return pre + r.type.c_extract(name, sub) return pre + r.type.c_extract(name, sub)
def get_c_cleanup(r, name, sub): def get_c_cleanup(r, name, sub):
"""WRITEME"""
post = """ post = """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_%(name)s);
""" % locals() """ % locals()
return r.type.c_cleanup(name, sub) + post return r.type.c_cleanup(name, sub) + post
def get_c_sync(r, name, sub): def get_c_sync(r, name, sub):
"""WRITEME"""
return """ return """
if (!%(failure_var)s) { if (!%(failure_var)s) {
%(sync)s %(sync)s
...@@ -294,7 +300,7 @@ def get_c_sync(r, name, sub): ...@@ -294,7 +300,7 @@ def get_c_sync(r, name, sub):
""" % dict(sync = r.type.c_sync(name, sub), name = name, **sub) """ % dict(sync = r.type.c_sync(name, sub), name = name, **sub)
def apply_policy(policy, r, name, sub): def apply_policy(policy, r, name, sub):
""" """WRITEME
@param policy: list of functions that map a L{Result} to a string, or a single such function @param policy: list of functions that map a L{Result} to a string, or a single such function
@type r: L{Result} @type r: L{Result}
@return: C{policy[0](r) + policy[1](r) + ...} @return: C{policy[0](r) + policy[1](r) + ...}
...@@ -309,7 +315,7 @@ def apply_policy(policy, r, name, sub): ...@@ -309,7 +315,7 @@ def apply_policy(policy, r, name, sub):
def struct_result_codeblocks(result, policies, id, symbol_table, sub): def struct_result_codeblocks(result, policies, id, symbol_table, sub):
""" """WRITEME
result -> a Result result -> a Result
policies -> a pair of tuples ((declare_policy, behavior_policy, cleanup_policy), -- at construction policies -> a pair of tuples ((declare_policy, behavior_policy, cleanup_policy), -- at construction
(declare_policy, behavior_policy, cleanup_policy)) -- at execution (declare_policy, behavior_policy, cleanup_policy)) -- at execution
...@@ -339,7 +345,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub): ...@@ -339,7 +345,7 @@ def struct_result_codeblocks(result, policies, id, symbol_table, sub):
class CLinker(link.Linker): class CLinker(link.Linker):
""" """WRITEME
Creates C code for an env, compiles it and returns callables Creates C code for an env, compiles it and returns callables
through make_thunk and make_function that make use of the compiled through make_thunk and make_function that make use of the compiled
...@@ -354,6 +360,7 @@ class CLinker(link.Linker): ...@@ -354,6 +360,7 @@ class CLinker(link.Linker):
self.env = None self.env = None
def accept(self, env, no_recycling = []): def accept(self, env, no_recycling = []):
"""WRITEME"""
if self.env is not None and self.env is not env: if self.env is not None and self.env is not env:
return type(self)().accept(env, no_recycling) return type(self)().accept(env, no_recycling)
#raise Exception("Cannot accept from a Linker that is already tied to another Env.") #raise Exception("Cannot accept from a Linker that is already tied to another Env.")
...@@ -363,7 +370,7 @@ class CLinker(link.Linker): ...@@ -363,7 +370,7 @@ class CLinker(link.Linker):
return self return self
def fetch_results(self): def fetch_results(self):
""" """WRITEME
Fills the inputs, outputs, results, orphans, temps and node_order fields. Fills the inputs, outputs, results, orphans, temps and node_order fields.
""" """
env = self.env env = self.env
...@@ -376,7 +383,7 @@ class CLinker(link.Linker): ...@@ -376,7 +383,7 @@ class CLinker(link.Linker):
self.node_order = env.toposort() self.node_order = env.toposort()
def code_gen(self): def code_gen(self):
""" """WRITEME
Generates code for a struct that does the computation of the env and Generates code for a struct that does the computation of the env and
stores it in the struct_code field of the instance. stores it in the struct_code field of the instance.
...@@ -542,7 +549,7 @@ class CLinker(link.Linker): ...@@ -542,7 +549,7 @@ class CLinker(link.Linker):
return self.struct_code return self.struct_code
def support_code(self): def support_code(self):
""" """WRITEME
Returns a list of support code strings that are needed by Returns a list of support code strings that are needed by
one or more Results or Ops. The support code from Results is one or more Results or Ops. The support code from Results is
added before the support code from Ops. added before the support code from Ops.
...@@ -556,7 +563,7 @@ class CLinker(link.Linker): ...@@ -556,7 +563,7 @@ class CLinker(link.Linker):
return ret return ret
def compile_args(self): def compile_args(self):
""" """WRITEME
Returns a list of compile args that are needed by one Returns a list of compile args that are needed by one
or more Results or Ops. or more Results or Ops.
...@@ -569,7 +576,7 @@ class CLinker(link.Linker): ...@@ -569,7 +576,7 @@ class CLinker(link.Linker):
return ret return ret
def headers(self): def headers(self):
""" """WRITEME
Returns a list of headers that are needed by one Returns a list of headers that are needed by one
or more Results or Ops. or more Results or Ops.
...@@ -582,7 +589,7 @@ class CLinker(link.Linker): ...@@ -582,7 +589,7 @@ class CLinker(link.Linker):
return ret return ret
def libraries(self): def libraries(self):
""" """WRITEME
Returns a list of libraries that are needed by one Returns a list of libraries that are needed by one
or more Results or Ops. or more Results or Ops.
...@@ -595,7 +602,7 @@ class CLinker(link.Linker): ...@@ -595,7 +602,7 @@ class CLinker(link.Linker):
return ret return ret
def __compile__(self, input_storage = None, output_storage = None): def __compile__(self, input_storage = None, output_storage = None):
""" """WRITEME
Compiles this linker's env. Compiles this linker's env.
@type input_storage: list or None @type input_storage: list or None
...@@ -629,7 +636,7 @@ class CLinker(link.Linker): ...@@ -629,7 +636,7 @@ class CLinker(link.Linker):
error_storage error_storage
def make_thunk(self, input_storage = None, output_storage = None): def make_thunk(self, input_storage = None, output_storage = None):
""" """WRITEME
Compiles this linker's env and returns a function to perform the Compiles this linker's env and returns a function to perform the
computations, as well as lists of storage cells for both the computations, as well as lists of storage cells for both the
inputs and outputs. inputs and outputs.
...@@ -655,7 +662,7 @@ class CLinker(link.Linker): ...@@ -655,7 +662,7 @@ class CLinker(link.Linker):
return _execute(cthunk, self.init_tasks, self.tasks, error_storage), in_storage, out_storage return _execute(cthunk, self.init_tasks, self.tasks, error_storage), in_storage, out_storage
def cthunk_factory(self, error_storage, in_storage, out_storage): def cthunk_factory(self, error_storage, in_storage, out_storage):
""" """WRITEME
error_storage -> list of length 3 error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input in_storage -> list of lists of length 1, one per input
out_storage -> list of lists of length 1, one per output out_storage -> list of lists of length 1, one per output
...@@ -754,6 +761,7 @@ class CLinker(link.Linker): ...@@ -754,6 +761,7 @@ class CLinker(link.Linker):
def _execute(cthunk, init_tasks, tasks, error_storage): def _execute(cthunk, init_tasks, tasks, error_storage):
"""WRITEME"""
def find_task(failure_code): def find_task(failure_code):
""" """
Maps a failure code to the task that is associated to it. Maps a failure code to the task that is associated to it.
...@@ -782,7 +790,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage): ...@@ -782,7 +790,7 @@ def _execute(cthunk, init_tasks, tasks, error_storage):
class OpWiseCLinker(link.LocalLinker): class OpWiseCLinker(link.LocalLinker):
""" """WRITEME
Uses CLinker on the individual Ops that comprise an env and loops Uses CLinker on the individual Ops that comprise an env and loops
over them in Python. The result is slower than a compiled version of over them in Python. The result is slower than a compiled version of
the whole env, but saves on compilation time because small changes the whole env, but saves on compilation time because small changes
...@@ -881,7 +889,7 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -881,7 +889,7 @@ class OpWiseCLinker(link.LocalLinker):
def _default_checker(x, y): def _default_checker(x, y):
""" """WRITEME
Default checker for DualLinker. This checks that the Default checker for DualLinker. This checks that the
results contain the same data using ==. results contain the same data using ==.
""" """
...@@ -889,7 +897,7 @@ def _default_checker(x, y): ...@@ -889,7 +897,7 @@ def _default_checker(x, y):
raise Exception("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]}) raise Exception("Output mismatch.", {'performlinker': x[0], 'clinker': y[0]})
class DualLinker(link.Linker): class DualLinker(link.Linker):
""" """WRITEME
Runs the env in parallel using PerformLinker and CLinker. Runs the env in parallel using PerformLinker and CLinker.
The thunk/function produced by DualLinker uses PerformLinker as the The thunk/function produced by DualLinker uses PerformLinker as the
......
"""WRITEME"""
from collections import defaultdict from collections import defaultdict
import toolbox import toolbox
...@@ -5,9 +6,12 @@ import graph ...@@ -5,9 +6,12 @@ import graph
from env import InconsistencyError from env import InconsistencyError
class ProtocolError(Exception): pass class ProtocolError(Exception):
"""WRITEME"""
pass
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper):
"""WRITEME"""
def __init__(self): def __init__(self):
self.map = {} self.map = {}
...@@ -36,6 +40,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -36,6 +40,8 @@ class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandlerHelper2(toolbox.Bookkeeper): class DestroyHandlerHelper2(toolbox.Bookkeeper):
"""WRITEME"""
def __init__(self): def __init__(self):
self.env = None self.env = None
......
"""WRITEME"""
from copy import copy from copy import copy
import graph import graph
import utils import utils
...@@ -15,7 +17,7 @@ class InconsistencyError(Exception): ...@@ -15,7 +17,7 @@ class InconsistencyError(Exception):
class Env(utils.object2): class Env(utils.object2):
""" """ WRITEME
An Env represents a subgraph bound by a set of input results and a An Env represents a subgraph bound by a set of input results and a
set of output results. The inputs list should contain all the inputs set of output results. The inputs list should contain all the inputs
on which the outputs depend. Results of type Value or Constant are on which the outputs depend. Results of type Value or Constant are
...@@ -35,6 +37,8 @@ class Env(utils.object2): ...@@ -35,6 +37,8 @@ class Env(utils.object2):
""" """
Create an Env which operates on the subgraph bound by the inputs and outputs Create an Env which operates on the subgraph bound by the inputs and outputs
sets. sets.
WRITEME
""" """
self._features = [] self._features = []
...@@ -79,7 +83,7 @@ class Env(utils.object2): ...@@ -79,7 +83,7 @@ class Env(utils.object2):
node.deps = {} node.deps = {}
def disown(self): def disown(self):
""" """ WRITEME
Cleans up all of this Env's nodes and results so they are not Cleans up all of this Env's nodes and results so they are not
associated with this Env anymore. associated with this Env anymore.
...@@ -104,11 +108,12 @@ class Env(utils.object2): ...@@ -104,11 +108,12 @@ class Env(utils.object2):
### clients ### ### clients ###
def clients(self, r): def clients(self, r):
"Set of all the (node, i) pairs such that node.inputs[i] is r." """WRITEME
Set of all the (node, i) pairs such that node.inputs[i] is r."""
return r.clients return r.clients
def __add_clients__(self, r, new_clients): def __add_clients__(self, r, new_clients):
""" """ WRITEME
r -> result r -> result
new_clients -> list of (node, i) pairs such that node.inputs[i] is r. new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
...@@ -117,7 +122,7 @@ class Env(utils.object2): ...@@ -117,7 +122,7 @@ class Env(utils.object2):
r.clients += new_clients r.clients += new_clients
def __remove_clients__(self, r, clients_to_remove, prune = True): def __remove_clients__(self, r, clients_to_remove, prune = True):
""" """ WRITEME
r -> result r -> result
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore. clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
...@@ -213,7 +218,7 @@ class Env(utils.object2): ...@@ -213,7 +218,7 @@ class Env(utils.object2):
### change input ### ### change input ###
def change_input(self, node, i, new_r): def change_input(self, node, i, new_r):
""" """WRITEME
Changes node.inputs[i] to new_r. Changes node.inputs[i] to new_r.
new_r.type == old_r.type must be True, where old_r is the new_r.type == old_r.type must be True, where old_r is the
...@@ -246,7 +251,7 @@ class Env(utils.object2): ...@@ -246,7 +251,7 @@ class Env(utils.object2):
### replace ### ### replace ###
def replace(self, r, new_r): def replace(self, r, new_r):
""" """ WRITEME
This is the main interface to manipulate the subgraph in Env. This is the main interface to manipulate the subgraph in Env.
For every node that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
""" """
...@@ -264,6 +269,7 @@ class Env(utils.object2): ...@@ -264,6 +269,7 @@ class Env(utils.object2):
self.change_input(node, i, new_r) self.change_input(node, i, new_r)
def replace_all(self, pairs): def replace_all(self, pairs):
"""WRITEME"""
for r, new_r in pairs: for r, new_r in pairs:
self.replace(r, new_r) self.replace(r, new_r)
...@@ -271,7 +277,7 @@ class Env(utils.object2): ...@@ -271,7 +277,7 @@ class Env(utils.object2):
### features ### ### features ###
def extend(self, feature): def extend(self, feature):
""" """WRITEME
Adds a feature to this env. The feature may define one Adds a feature to this env. The feature may define one
or more of the following methods: or more of the following methods:
...@@ -310,7 +316,7 @@ class Env(utils.object2): ...@@ -310,7 +316,7 @@ class Env(utils.object2):
self._features.append(feature) self._features.append(feature)
def remove_feature(self, feature): def remove_feature(self, feature):
""" """WRITEME
Removes the feature from the graph. Removes the feature from the graph.
Calls feature.on_detach(env) if an on_detach method is defined. Calls feature.on_detach(env) if an on_detach method is defined.
...@@ -327,7 +333,7 @@ class Env(utils.object2): ...@@ -327,7 +333,7 @@ class Env(utils.object2):
### callback utils ### ### callback utils ###
def execute_callbacks(self, name, *args): def execute_callbacks(self, name, *args):
""" """WRITEME
Calls Calls
getattr(feature, name)(*args) getattr(feature, name)(*args)
for each feature which has a method called after name. for each feature which has a method called after name.
...@@ -340,7 +346,7 @@ class Env(utils.object2): ...@@ -340,7 +346,7 @@ class Env(utils.object2):
fn(self, *args) fn(self, *args)
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
""" """WRITEME
Returns a dictionary d such that: Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args) d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name. For each feature which has a method called after name.
...@@ -358,7 +364,7 @@ class Env(utils.object2): ...@@ -358,7 +364,7 @@ class Env(utils.object2):
### misc ### ### misc ###
def toposort(self): def toposort(self):
""" """WRITEME
Returns an ordering of the graph's Apply nodes such that: Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node. - All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has - Satisfies the orderings provided by each feature that has
...@@ -379,7 +385,7 @@ class Env(utils.object2): ...@@ -379,7 +385,7 @@ class Env(utils.object2):
return order return order
def nclients(self, r): def nclients(self, r):
"Same as len(self.clients(r))." """WRITEME Same as len(self.clients(r))."""
return len(self.clients(r)) return len(self.clients(r))
# def edge(self, r): # def edge(self, r):
...@@ -395,7 +401,7 @@ class Env(utils.object2): ...@@ -395,7 +401,7 @@ class Env(utils.object2):
# return node.inputs # return node.inputs
def check_integrity(self): def check_integrity(self):
""" """WRITEME
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
""" """
nodes = graph.ops(self.inputs, self.outputs) nodes = graph.ops(self.inputs, self.outputs)
...@@ -438,9 +444,11 @@ class Env(utils.object2): ...@@ -438,9 +444,11 @@ class Env(utils.object2):
### clone ### ### clone ###
def clone(self): def clone(self):
"""WRITEME"""
return self.clone_get_equiv()[0] return self.clone_get_equiv()[0]
def clone_get_equiv(self): def clone_get_equiv(self):
"""WRITEME"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs) equiv = graph.clone_get_equiv(self.inputs, self.outputs)
self.check_integrity() self.check_integrity()
e = Env([equiv[i] for i in self.inputs], e = Env([equiv[i] for i in self.inputs],
......
"""Node classes (Apply, Result) and expression graph algorithms."""
from copy import copy from copy import copy
from collections import deque from collections import deque
...@@ -137,11 +137,32 @@ class Apply(utils.object2): ...@@ -137,11 +137,32 @@ class Apply(utils.object2):
class Result(utils.object2): class Result(utils.object2):
""" """
Represents the result of some computation (pointed to by its owner field), A variable in a theano expression graph.
or an input to the graph (if owner is None)
A Result which is the output of a symbolic computation has a reference to the Apply
instance to which it belongs (property: owner) and the position of itself in the owner's
output list (property: index).
A Result which is not the output of a symbolic computation will have an owner == None.
""" """
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None): def __init__(self, type, owner = None, index = None, name = None):
"""Initialize type, owner, index, name.
@type type: a Type instance
@param type: the type governs the kind of data that can be associated with this
variable
@type owner: None or Apply instance
@param owner: the Apply instance which computes the value for this variable
@type index: None or int
@param index: the position of this Result in owner.outputs
@type name: None or str
@param name: a string for pretty-printing and debugging
"""
self.tag = utils.scratchpad() self.tag = utils.scratchpad()
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
...@@ -167,6 +188,14 @@ class Result(utils.object2): ...@@ -167,6 +188,14 @@ class Result(utils.object2):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def clone(self): def clone(self):
"""Return a new Result like self.
@rtype: Result instance
@return: a new Result instance (or subclass instance) with no owner or index.
@note: tags are copied to the returned instance.
@note: name is copied to the returned instance.
"""
#return copy(self) #return copy(self)
cp = self.__class__(self.type, None, None, self.name) cp = self.__class__(self.type, None, None, self.name)
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
...@@ -174,13 +203,18 @@ class Result(utils.object2): ...@@ -174,13 +203,18 @@ class Result(utils.object2):
class Value(Result): class Value(Result):
""" """
Result with a data field. The data field is filtered by what is Result with a default 'data' field.
The data field is filtered by what is
provided in the constructor for the Value's type field. provided in the constructor for the Value's type field.
Its owner field is always None. Its owner field is always None.
""" """
#__slots__ = ['data'] #__slots__ = ['data']
def __init__(self, type, data, name = None): def __init__(self, type, data, name = None):
"""Initialize self.
WRITEME
"""
Result.__init__(self, type, None, None, name) Result.__init__(self, type, None, None, name)
self.data = type.filter(data) self.data = type.filter(data)
def __str__(self): def __str__(self):
...@@ -188,6 +222,7 @@ class Value(Result): ...@@ -188,6 +222,7 @@ class Value(Result):
return self.name return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type) return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self): def clone(self):
"""WRITEME"""
return self.__class__(self.type, copy(self.data), self.name) return self.__class__(self.type, copy(self.data), self.name)
def __set_owner(self, value): def __set_owner(self, value):
if value is not None: if value is not None:
...@@ -218,7 +253,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False): ...@@ -218,7 +253,7 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through L{Result}s, either breadth- or depth-first """Search through L{Result}s, either breadth- or depth-first
@type start: deque @type start: deque
@param start: search from these nodes @param start: search from these nodes
@type explore: function @type explore: callable
@param explore: when we get to a node, add explore(node) to the list of @param explore: when we get to a node, add explore(node) to the list of
nodes to visit. This function should return a list, or None nodes to visit. This function should return a list, or None
@rtype: list of L{Result} @rtype: list of L{Result}
...@@ -256,7 +291,8 @@ def stack_search(start, expand, mode='bfs', build_inv = False): ...@@ -256,7 +291,8 @@ def stack_search(start, expand, mode='bfs', build_inv = False):
def inputs(result_list): def inputs(result_list):
""" """Return the inputs required to compute the given Results.
@type result_list: list of L{Result} @type result_list: list of L{Result}
@param result_list: output L{Result}s (from which to search backward through owners) @param result_list: output L{Result}s (from which to search backward through owners)
@returns: the list of L{Result}s with no owner, in the order found by a @returns: the list of L{Result}s with no owner, in the order found by a
...@@ -275,7 +311,7 @@ def inputs(result_list): ...@@ -275,7 +311,7 @@ def inputs(result_list):
def results_and_orphans(i, o): def results_and_orphans(i, o):
""" """WRITEME
""" """
def expand(r): def expand(r):
if r.owner and r not in i: if r.owner and r not in i:
...@@ -288,7 +324,8 @@ def results_and_orphans(i, o): ...@@ -288,7 +324,8 @@ def results_and_orphans(i, o):
def ops(i, o): def ops(i, o):
""" """ WRITEME
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
@type o: list @type o: list
...@@ -309,7 +346,8 @@ def ops(i, o): ...@@ -309,7 +346,8 @@ def ops(i, o):
def results(i, o): def results(i, o):
""" """ WRITEME
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
@type o: list @type o: list
...@@ -323,7 +361,8 @@ def results(i, o): ...@@ -323,7 +361,8 @@ def results(i, o):
def orphans(i, o): def orphans(i, o):
""" """ WRITEME
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
@type o: list @type o: list
...@@ -339,7 +378,8 @@ def orphans(i, o): ...@@ -339,7 +378,8 @@ def orphans(i, o):
def clone(i, o, copy_inputs = True): def clone(i, o, copy_inputs = True):
""" """ WRITEME
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
@type o: list @type o: list
...@@ -355,7 +395,8 @@ def clone(i, o, copy_inputs = True): ...@@ -355,7 +395,8 @@ def clone(i, o, copy_inputs = True):
def clone_get_equiv(i, o, copy_inputs_and_orphans = True): def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
""" """ WRITEME
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
@type o: list @type o: list
...@@ -400,7 +441,8 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True): ...@@ -400,7 +441,8 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = True):
return d return d
def general_toposort(r_out, deps, debug_print = False): def general_toposort(r_out, deps, debug_print = False):
""" """ WRITEME
@note: deps(i) should behave like a pure function (no funny business with @note: deps(i) should behave like a pure function (no funny business with
internal state) internal state)
...@@ -446,6 +488,8 @@ def general_toposort(r_out, deps, debug_print = False): ...@@ -446,6 +488,8 @@ def general_toposort(r_out, deps, debug_print = False):
def io_toposort(i, o, orderings = {}): def io_toposort(i, o, orderings = {}):
"""WRITEME
"""
iset = set(i) iset = set(i)
def deps(obj): def deps(obj):
rval = [] rval = []
...@@ -470,6 +514,7 @@ default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op, ...@@ -470,6 +514,7 @@ default_node_formatter = lambda op, argstrings: "%s(%s)" % (op.op,
def op_as_string(i, op, def op_as_string(i, op,
leaf_formatter = default_leaf_formatter, leaf_formatter = default_leaf_formatter,
node_formatter = default_node_formatter): node_formatter = default_node_formatter):
"""WRITEME"""
strs = as_string(i, op.inputs, leaf_formatter, node_formatter) strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
return node_formatter(op, strs) return node_formatter(op, strs)
...@@ -477,7 +522,8 @@ def op_as_string(i, op, ...@@ -477,7 +522,8 @@ def op_as_string(i, op,
def as_string(i, o, def as_string(i, o,
leaf_formatter = default_leaf_formatter, leaf_formatter = default_leaf_formatter,
node_formatter = default_node_formatter): node_formatter = default_node_formatter):
""" """WRITEME
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
@type o: list @type o: list
...@@ -549,6 +595,8 @@ def view_roots(r): ...@@ -549,6 +595,8 @@ def view_roots(r):
""" """
Utility function that returns the leaves of a search through Utility function that returns the leaves of a search through
consecutive view_map()s. consecutive view_map()s.
WRITEME
""" """
owner = r.owner owner = r.owner
if owner is not None: if owner is not None:
......
"""WRITEME"""
import utils import utils
import graph import graph
...@@ -8,7 +8,7 @@ from copy import copy ...@@ -8,7 +8,7 @@ from copy import copy
__excepthook = sys.excepthook __excepthook = sys.excepthook
def thunk_hook(type, value, trace): def thunk_hook(type, value, trace):
""" """WRITEME
This function is meant to replace excepthook and do some This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__ special work if the exception value has a __thunk_trace__
field. In that case, it retrieves the field, which should field. In that case, it retrieves the field, which should
...@@ -32,6 +32,7 @@ sys.excepthook = thunk_hook ...@@ -32,6 +32,7 @@ sys.excepthook = thunk_hook
def raise_with_op(op, exc_info = None): def raise_with_op(op, exc_info = None):
"""WRITEME"""
if exc_info is None: if exc_info is None:
exc_info = sys.exc_info() exc_info = sys.exc_info()
exc_type, exc_value, exc_trace = exc_info exc_type, exc_value, exc_trace = exc_info
...@@ -45,6 +46,7 @@ def raise_with_op(op, exc_info = None): ...@@ -45,6 +46,7 @@ def raise_with_op(op, exc_info = None):
class Linker(object): class Linker(object):
"""WRITEME"""
def make_thunk(self): def make_thunk(self):
""" """
...@@ -108,6 +110,7 @@ class Linker(object): ...@@ -108,6 +110,7 @@ class Linker(object):
class Filter(object): class Filter(object):
"""WRITEME"""
def __init__(self, r, storage, readonly = False, strict = False, trace = ()): def __init__(self, r, storage, readonly = False, strict = False, trace = ()):
self.r = r self.r = r
self.type = r.type self.type = r.type
...@@ -134,6 +137,7 @@ class Filter(object): ...@@ -134,6 +137,7 @@ class Filter(object):
def map_storage(env, order, input_storage, output_storage): def map_storage(env, order, input_storage, output_storage):
"""WRITEME"""
if input_storage is None: if input_storage is None:
input_storage = [[None] for input in env.inputs] input_storage = [[None] for input in env.inputs]
else: else:
...@@ -165,6 +169,7 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -165,6 +169,7 @@ def map_storage(env, order, input_storage, output_storage):
def streamline(env, thunks, order, no_recycling = [], profiler = None): def streamline(env, thunks, order, no_recycling = [], profiler = None):
"""WRITEME"""
def clear(): def clear():
for thunk in thunks: for thunk in thunks:
for output in thunk.outputs: for output in thunk.outputs:
...@@ -191,7 +196,7 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None): ...@@ -191,7 +196,7 @@ def streamline(env, thunks, order, no_recycling = [], profiler = None):
return f return f
class LocalLinker(Linker): class LocalLinker(Linker):
""" """WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run a Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node. thunk associated with each node.
""" """
...@@ -214,7 +219,7 @@ class LocalLinker(Linker): ...@@ -214,7 +219,7 @@ class LocalLinker(Linker):
class PerformLinker(LocalLinker): class PerformLinker(LocalLinker):
""" """WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{Env} in the order given by L{Env.toposort}. the L{Env} in the order given by L{Env.toposort}.
""" """
...@@ -262,7 +267,7 @@ class PerformLinker(LocalLinker): ...@@ -262,7 +267,7 @@ class PerformLinker(LocalLinker):
class WrapLinker(Linker): class WrapLinker(Linker):
""" """ WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run. offers some control over how each thunk is run.
...@@ -373,6 +378,7 @@ class WrapLinker(Linker): ...@@ -373,6 +378,7 @@ class WrapLinker(Linker):
import time import time
class Stats: class Stats:
"""WRITEME"""
def __init__(self): def __init__(self):
self.ncalls = 0 self.ncalls = 0
self.time = 0 self.time = 0
...@@ -384,7 +390,7 @@ class Stats: ...@@ -384,7 +390,7 @@ class Stats:
def inc_time_failures(self, v): self.time_failures += v def inc_time_failures(self, v): self.time_failures += v
class Profiler: class Profiler:
""" """WRITEME
Collects performance statistics on a function on a per-L{Op} Collects performance statistics on a function on a per-L{Op}
or per-L{Op}-class basis. or per-L{Op}-class basis.
""" """
...@@ -404,6 +410,7 @@ class Profiler: ...@@ -404,6 +410,7 @@ class Profiler:
self.by_class = by_class self.by_class = by_class
def profile_env(self, f, env): def profile_env(self, f, env):
"""WRITEME"""
stats = self.stats.setdefault('TOTAL', Stats()) stats = self.stats.setdefault('TOTAL', Stats())
n, t = stats.inc_ncalls, stats.inc_time n, t = stats.inc_ncalls, stats.inc_time
failed = False failed = False
...@@ -423,6 +430,7 @@ class Profiler: ...@@ -423,6 +430,7 @@ class Profiler:
raise ety, eva, etr raise ety, eva, etr
def profile_op(self, f, op): def profile_op(self, f, op):
"""WRITEME"""
if self.by_class: if self.by_class:
entry = op.__class__ entry = op.__class__
else: else:
...@@ -449,6 +457,7 @@ class Profiler: ...@@ -449,6 +457,7 @@ class Profiler:
def print_stats(self, sort_by = 'time'): def print_stats(self, sort_by = 'time'):
"""WRITEME"""
def compare_fn((op1, stat1), (op2, stat2)): def compare_fn((op1, stat1), (op2, stat2)):
x1 = getattr(stat2, sort_by) x1 = getattr(stat2, sort_by)
......
...@@ -11,6 +11,7 @@ class Op(utils.object2): ...@@ -11,6 +11,7 @@ class Op(utils.object2):
default_output = None default_output = None
"""@todo """@todo
WRITEME
""" """
############# #############
......
...@@ -15,14 +15,14 @@ from collections import deque ...@@ -15,14 +15,14 @@ from collections import deque
class Optimizer: class Optimizer:
""" """WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it. An L{Optimizer} can be applied to an L{Env} to transform it.
It can represent an optimization or in general any kind It can represent an optimization or in general any kind
of transformation you could apply to an L{Env}. of transformation you could apply to an L{Env}.
""" """
def apply(self, env): def apply(self, env):
""" """WRITEME
Applies the optimization to the provided L{Env}. It may use all Applies the optimization to the provided L{Env}. It may use all
the methods defined by the L{Env}. If the L{Optimizer} needs the methods defined by the L{Env}. If the L{Optimizer} needs
to use a certain tool, such as an L{InstanceFinder}, it can do to use a certain tool, such as an L{InstanceFinder}, it can do
...@@ -31,7 +31,7 @@ class Optimizer: ...@@ -31,7 +31,7 @@ class Optimizer:
pass pass
def optimize(self, env, *args, **kwargs): def optimize(self, env, *args, **kwargs):
""" """WRITEME
This is meant as a shortcut to:: This is meant as a shortcut to::
opt.add_requirements(env) opt.add_requirements(env)
opt.apply(env) opt.apply(env)
...@@ -40,13 +40,13 @@ class Optimizer: ...@@ -40,13 +40,13 @@ class Optimizer:
self.apply(env, *args, **kwargs) self.apply(env, *args, **kwargs)
def __call__(self, env): def __call__(self, env):
""" """WRITEME
Same as self.optimize(env) Same as self.optimize(env)
""" """
return self.optimize(env) return self.optimize(env)
def add_requirements(self, env): def add_requirements(self, env):
""" """WRITEME
Add features to the env that are required to apply the optimization. Add features to the env that are required to apply the optimization.
For example: For example:
env.extend(History()) env.extend(History())
...@@ -57,29 +57,33 @@ class Optimizer: ...@@ -57,29 +57,33 @@ class Optimizer:
class FromFunctionOptimizer(Optimizer): class FromFunctionOptimizer(Optimizer):
"""WRITEME"""
def __init__(self, fn): def __init__(self, fn):
self.apply = fn self.apply = fn
def add_requirements(self, env): def add_requirements(self, env):
"""WRITEME"""
env.extend(gof.toolbox.ReplaceValidate) env.extend(gof.toolbox.ReplaceValidate)
def optimizer(f): def optimizer(f):
"""WRITEME"""
return FromFunctionOptimizer(f) return FromFunctionOptimizer(f)
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
""" """WRITEME
Takes a list of L{Optimizer} instances and applies them Takes a list of L{Optimizer} instances and applies them
sequentially. sequentially.
""" """
def __init__(self, *opts): def __init__(self, *opts):
"""WRITEME"""
if len(opts) == 1 and isinstance(opts[0], (list, tuple)): if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0] opts = opts[0]
self[:] = opts self[:] = opts
def apply(self, env): def apply(self, env):
""" """WRITEME
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
for optimizer in self: for optimizer in self:
...@@ -94,6 +98,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -94,6 +98,7 @@ class SeqOptimizer(Optimizer, list):
class _metadict: class _metadict:
"""WRITEME"""
# dict that accepts unhashable keys # dict that accepts unhashable keys
# uses an associative list # uses an associative list
# for internal use only # for internal use only
...@@ -130,7 +135,7 @@ class _metadict: ...@@ -130,7 +135,7 @@ class _metadict:
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
""" """WRITEME
Merges parts of the graph that are identical, i.e. parts that Merges parts of the graph that are identical, i.e. parts that
take the same inputs and carry out the asme computations so we take the same inputs and carry out the asme computations so we
can avoid doing them more than once. Also merges results that can avoid doing them more than once. Also merges results that
...@@ -184,7 +189,7 @@ class MergeOptimizer(Optimizer): ...@@ -184,7 +189,7 @@ class MergeOptimizer(Optimizer):
def MergeOptMerge(opt): def MergeOptMerge(opt):
""" """WRITEME
Returns an Optimizer that merges the graph then applies the Returns an Optimizer that merges the graph then applies the
optimizer in opt and then merges the graph again in case the optimizer in opt and then merges the graph again in case the
opt introduced additional similarities. opt introduced additional similarities.
...@@ -199,22 +204,26 @@ def MergeOptMerge(opt): ...@@ -199,22 +204,26 @@ def MergeOptMerge(opt):
######################## ########################
class LocalOptimizer(utils.object2): class LocalOptimizer(utils.object2):
"""WRITEME"""
def transform(self, node): def transform(self, node):
raise utils.AbstractFunctionError() raise utils.AbstractFunctionError()
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
def __init__(self, fn): def __init__(self, fn):
self.transform = fn self.transform = fn
def add_requirements(self, env): def add_requirements(self, env):
env.extend(gof.toolbox.ReplaceValidate) env.extend(gof.toolbox.ReplaceValidate)
def local_optimizer(f): def local_optimizer(f):
"""WRITEME"""
return FromFunctionLocalOptimizer(f) return FromFunctionLocalOptimizer(f)
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(LocalOptimizer):
"""WRITEME"""
def __init__(self, *optimizers): def __init__(self, *optimizers):
self.opts = optimizers self.opts = optimizers
...@@ -229,6 +238,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -229,6 +238,7 @@ class LocalOptGroup(LocalOptimizer):
class LocalOpKeyOptGroup(LocalOptGroup): class LocalOpKeyOptGroup(LocalOptGroup):
"""WRITEME"""
def __init__(self, optimizers): def __init__(self, optimizers):
if any(not hasattr(opt, 'op_key'), optimizers): if any(not hasattr(opt, 'op_key'), optimizers):
...@@ -240,7 +250,7 @@ class LocalOpKeyOptGroup(LocalOptGroup): ...@@ -240,7 +250,7 @@ class LocalOpKeyOptGroup(LocalOptGroup):
class OpSub(LocalOptimizer): class OpSub(LocalOptimizer):
""" """WRITEME
Replaces the application of a certain op by the application of Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing. another op that take the same inputs as what they are replacing.
...@@ -277,7 +287,7 @@ class OpSub(LocalOptimizer): ...@@ -277,7 +287,7 @@ class OpSub(LocalOptimizer):
class OpRemove(LocalOptimizer): class OpRemove(LocalOptimizer):
""" """WRITEME
Removes all applications of an op by transferring each of its Removes all applications of an op by transferring each of its
outputs to the corresponding input. outputs to the corresponding input.
""" """
...@@ -304,7 +314,7 @@ class OpRemove(LocalOptimizer): ...@@ -304,7 +314,7 @@ class OpRemove(LocalOptimizer):
class PatternSub(LocalOptimizer): class PatternSub(LocalOptimizer):
""" """WRITEME
@todo update @todo update
Replaces all occurrences of the input pattern by the output pattern: Replaces all occurrences of the input pattern by the output pattern:
...@@ -448,6 +458,7 @@ class PatternSub(LocalOptimizer): ...@@ -448,6 +458,7 @@ class PatternSub(LocalOptimizer):
class NavigatorOptimizer(Optimizer): class NavigatorOptimizer(Optimizer):
"""WRITEME"""
def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None): def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
self.local_opt = local_opt self.local_opt = local_opt
...@@ -498,6 +509,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -498,6 +509,7 @@ class NavigatorOptimizer(Optimizer):
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NavigatorOptimizer):
"""WRITEME"""
def __init__(self, local_opt, order = 'in_to_out', ignore_newtrees = False, failure_callback = None): def __init__(self, local_opt, order = 'in_to_out', ignore_newtrees = False, failure_callback = None):
if order not in ['out_to_in', 'in_to_out']: if order not in ['out_to_in', 'in_to_out']:
...@@ -531,6 +543,7 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -531,6 +543,7 @@ class TopoOptimizer(NavigatorOptimizer):
class OpKeyOptimizer(NavigatorOptimizer): class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME"""
def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None): def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None):
if not hasattr(local_opt, 'op_key'): if not hasattr(local_opt, 'op_key'):
...@@ -570,6 +583,7 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -570,6 +583,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
"""WRITEME"""
pass pass
...@@ -578,6 +592,7 @@ def keep_going(exc, nav, repl_pairs): ...@@ -578,6 +592,7 @@ def keep_going(exc, nav, repl_pairs):
################# #################
def _check_chain(r, chain): def _check_chain(r, chain):
"""WRITEME"""
chain = list(reversed(chain)) chain = list(reversed(chain))
while chain: while chain:
elem = chain.pop() elem = chain.pop()
...@@ -600,6 +615,7 @@ def _check_chain(r, chain): ...@@ -600,6 +615,7 @@ def _check_chain(r, chain):
return r return r
def check_chain(r, *chain): def check_chain(r, *chain):
"""WRITEME"""
if isinstance(r, graph.Apply): if isinstance(r, graph.Apply):
r = r.outputs[0] r = r.outputs[0]
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论