提交 c0b294ec authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4629 from chinnadhurai/ccw_4483_indent_fix

Ccw 4483 indent fix
...@@ -17,12 +17,26 @@ class CallCache(object): ...@@ -17,12 +17,26 @@ class CallCache(object):
self.cache = {} self.cache = {}
def persist(self, filename=None): def persist(self, filename=None):
"""
Cache "filename" as a pickle file
"""
if filename is None: if filename is None:
filename = self.filename filename = self.filename
with open(filename, 'w') as f: with open(filename, 'w') as f:
pickle.dump(self.cache, f) pickle.dump(self.cache, f)
def call(self, fn, args=(), key=None): def call(self, fn, args=(), key=None):
"""
Retrieve item from the cache(if available)
based on a key
Parameters:
----------
key
parameter to retrieve cache item
fn,args
key to retrieve if "key" is None
"""
if key is None: if key is None:
key = (fn, tuple(args)) key = (fn, tuple(args))
if key not in self.cache: if key not in self.cache:
......
...@@ -61,8 +61,6 @@ def get_persistent_module_cache(): ...@@ -61,8 +61,6 @@ def get_persistent_module_cache():
class CodeBlock: class CodeBlock:
""" """
WRITEME
Represents a computation unit composed of declare, behavior, and cleanup. Represents a computation unit composed of declare, behavior, and cleanup.
The constructor initializes a L{CodeBlock} with templatized declare, The constructor initializes a L{CodeBlock} with templatized declare,
...@@ -118,6 +116,12 @@ def failure_code_init(sub): ...@@ -118,6 +116,12 @@ def failure_code_init(sub):
""" """
Code for failure in the struct init. Code for failure in the struct init.
Parameters:
----------
sub
Dictionary used to template the struct.
* failure_var -> must contain a variable name to use for
the failure code.
""" """
return '''{ return '''{
if (!PyErr_Occurred()) { if (!PyErr_Occurred()) {
...@@ -131,10 +135,10 @@ def failure_code_init(sub): ...@@ -131,10 +135,10 @@ def failure_code_init(sub):
def code_gen(blocks): def code_gen(blocks):
""" """
WRITEME
From a list of L{CodeBlock} instances, returns a string From a list of L{CodeBlock} instances, returns a string
that executes them all in sequence. eg for C{(decl1, task1, that executes them all in sequence.
Eg for C{(decl1, task1,
cleanup1)} and C{(decl2, task2, cleanup2)} the returned string cleanup1)} and C{(decl2, task2, cleanup2)} the returned string
will be of the form: will be of the form:
...@@ -149,6 +153,12 @@ def code_gen(blocks): ...@@ -149,6 +153,12 @@ def code_gen(blocks):
cleanup1 cleanup1
} }
Parameters:
----------
blocks
List of CodeBlock instances such that
* declarations, behavior and cleanup are in the run()
method of the struct
""" """
decl = "" decl = ""
head = "" head = ""
...@@ -162,8 +172,6 @@ def code_gen(blocks): ...@@ -162,8 +172,6 @@ def code_gen(blocks):
def struct_gen(args, struct_builders, blocks, sub): def struct_gen(args, struct_builders, blocks, sub):
""" """
WRITEME
Generates a struct conforming to the following specifications: Generates a struct conforming to the following specifications:
Parameters Parameters
...@@ -453,7 +461,7 @@ def get_c_sync(r, name, sub): ...@@ -453,7 +461,7 @@ def get_c_sync(r, name, sub):
def apply_policy(policy, r, name, sub): def apply_policy(policy, r, name, sub):
""" """
WRITEME Apply the list of policies to name.r,sub
Parameters Parameters
---------- ----------
...@@ -478,7 +486,7 @@ def apply_policy(policy, r, name, sub): ...@@ -478,7 +486,7 @@ def apply_policy(policy, r, name, sub):
def struct_variable_codeblocks(variable, policies, id, symbol_table, sub): def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
""" """
WRITEME Update "sub" dict and create two codeblocks with different failure modes
Parameters Parameters
---------- ----------
...@@ -525,8 +533,6 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub): ...@@ -525,8 +533,6 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
class CLinker(link.Linker): class CLinker(link.Linker):
""" """
WRITEME
Creates C code for an fgraph, compiles it and returns callables Creates C code for an fgraph, compiles it and returns callables
through make_thunk and make_function that make use of the compiled through make_thunk and make_function that make use of the compiled
code. code.
...@@ -544,7 +550,7 @@ class CLinker(link.Linker): ...@@ -544,7 +550,7 @@ class CLinker(link.Linker):
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
""" """
WRITEME Associate linker with fgraph
""" """
if no_recycling is None: if no_recycling is None:
...@@ -559,8 +565,6 @@ class CLinker(link.Linker): ...@@ -559,8 +565,6 @@ class CLinker(link.Linker):
def fetch_variables(self): def fetch_variables(self):
""" """
WRITEME
Fills the inputs, outputs, variables, orphans, temps and node_order Fills the inputs, outputs, variables, orphans, temps and node_order
fields. fields.
...@@ -617,8 +621,6 @@ class CLinker(link.Linker): ...@@ -617,8 +621,6 @@ class CLinker(link.Linker):
def code_gen(self): def code_gen(self):
""" """
WRITEME
Generates code for a struct that does the computation of the fgraph and Generates code for a struct that does the computation of the fgraph and
stores it in the struct_code field of the instance. stores it in the struct_code field of the instance.
...@@ -890,14 +892,9 @@ class CLinker(link.Linker): ...@@ -890,14 +892,9 @@ class CLinker(link.Linker):
def support_code(self): def support_code(self):
""" """
WRITEME
Returns a list of support code strings that are needed by Returns a list of support code strings that are needed by
one or more Variables or Ops. The support code from Variables is one or more Variables or Ops.
added before the support code from Ops. The support code from Variables is added before the support code from Ops.This might contain duplicates.
This might contain duplicates.
""" """
ret = [] ret = []
# generic support code # generic support code
...@@ -911,8 +908,6 @@ class CLinker(link.Linker): ...@@ -911,8 +908,6 @@ class CLinker(link.Linker):
def compile_args(self): def compile_args(self):
""" """
WRITEME
Returns a list of compile args that are needed by one Returns a list of compile args that are needed by one
or more Variables or Ops. or more Variables or Ops.
...@@ -971,8 +966,6 @@ class CLinker(link.Linker): ...@@ -971,8 +966,6 @@ class CLinker(link.Linker):
def headers(self): def headers(self):
""" """
WRITEME
Returns a list of headers that are needed by one Returns a list of headers that are needed by one
or more Types or Ops. or more Types or Ops.
...@@ -1032,8 +1025,6 @@ class CLinker(link.Linker): ...@@ -1032,8 +1025,6 @@ class CLinker(link.Linker):
def header_dirs(self): def header_dirs(self):
""" """
WRITEME
Returns a list of lib directories that are needed by one Returns a list of lib directories that are needed by one
or more Types or Ops. or more Types or Ops.
...@@ -1055,8 +1046,6 @@ class CLinker(link.Linker): ...@@ -1055,8 +1046,6 @@ class CLinker(link.Linker):
def libraries(self): def libraries(self):
""" """
WRITEME
Returns a list of libraries that are needed by one Returns a list of libraries that are needed by one
or more Types or Ops. or more Types or Ops.
...@@ -1078,8 +1067,6 @@ class CLinker(link.Linker): ...@@ -1078,8 +1067,6 @@ class CLinker(link.Linker):
def lib_dirs(self): def lib_dirs(self):
""" """
WRITEME
Returns a list of lib directories that are needed by one Returns a list of lib directories that are needed by one
or more Types or Ops. or more Types or Ops.
...@@ -1101,7 +1088,7 @@ class CLinker(link.Linker): ...@@ -1101,7 +1088,7 @@ class CLinker(link.Linker):
def __compile__(self, input_storage=None, output_storage=None, def __compile__(self, input_storage=None, output_storage=None,
storage_map=None, keep_lock=False): storage_map=None, keep_lock=False):
"""WRITEME """
Compiles this linker's fgraph. Compiles this linker's fgraph.
Parameters Parameters
...@@ -1166,7 +1153,7 @@ class CLinker(link.Linker): ...@@ -1166,7 +1153,7 @@ class CLinker(link.Linker):
def make_thunk(self, input_storage=None, output_storage=None, def make_thunk(self, input_storage=None, output_storage=None,
storage_map=None, keep_lock=False): storage_map=None, keep_lock=False):
"""WRITEME """
Compiles this linker's fgraph and returns a function to perform the Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the inputs computations, as well as lists of storage cells for both the inputs
and outputs. and outputs.
...@@ -1183,8 +1170,10 @@ class CLinker(link.Linker): ...@@ -1183,8 +1170,10 @@ class CLinker(link.Linker):
be allocated. be allocated.
storage_map: dict that map variables to storages. storage_map: dict that map variables to storages.
This is used when you need to customize the storage of This is used when you need to customize the storage of
this thunk. this thunk
keep_lock:
If True, we won't release the lock on the compiledir
at the end of this function call.
Returns: thunk, input_storage, output_storage Returns: thunk, input_storage, output_storage
The return values can be used as follows: The return values can be used as follows:
...@@ -1568,7 +1557,12 @@ class CLinker(link.Linker): ...@@ -1568,7 +1557,12 @@ class CLinker(link.Linker):
def cthunk_factory(self, error_storage, in_storage, out_storage, def cthunk_factory(self, error_storage, in_storage, out_storage,
storage_map=None, keep_lock=False): storage_map=None, keep_lock=False):
"""WRITEME """
Returns a thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph
Parameters:
----------
error_storage -> list of length 3 error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input in_storage -> list of lists of length 1, one per input
out_storage -> list of lists of length 1, one per output out_storage -> list of lists of length 1, one per output
...@@ -1705,8 +1699,6 @@ class _CThunk(object): ...@@ -1705,8 +1699,6 @@ class _CThunk(object):
class OpWiseCLinker(link.LocalLinker): class OpWiseCLinker(link.LocalLinker):
""" """
WRITEME
Uses CLinker on the individual Ops that comprise an fgraph and loops Uses CLinker on the individual Ops that comprise an fgraph and loops
over them in Python. The variable is slower than a compiled version of over them in Python. The variable is slower than a compiled version of
the whole fgraph, but saves on compilation time because small changes the whole fgraph, but saves on compilation time because small changes
...@@ -1746,6 +1738,9 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1746,6 +1738,9 @@ class OpWiseCLinker(link.LocalLinker):
self.schedule = schedule self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
"""
Associate linker with fgraph
"""
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
...@@ -1846,11 +1841,14 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -1846,11 +1841,14 @@ class OpWiseCLinker(link.LocalLinker):
def _default_checker(x, y): def _default_checker(x, y):
""" """
WRITEME
Default checker for DualLinker. This checks that the Default checker for DualLinker. This checks that the
variables contain the same data using ==. variables contain the same data using ==.
Parameters:
----------
x,y
the variables to compare data
""" """
if x[0] != y[0]: if x[0] != y[0]:
raise Exception("Output mismatch.", raise Exception("Output mismatch.",
...@@ -1859,8 +1857,6 @@ def _default_checker(x, y): ...@@ -1859,8 +1857,6 @@ def _default_checker(x, y):
class DualLinker(link.Linker): class DualLinker(link.Linker):
""" """
WRITEME
Runs the fgraph in parallel using PerformLinker and CLinker. Runs the fgraph in parallel using PerformLinker and CLinker.
The thunk/function produced by DualLinker uses PerformLinker as the The thunk/function produced by DualLinker uses PerformLinker as the
...@@ -1902,6 +1898,9 @@ class DualLinker(link.Linker): ...@@ -1902,6 +1898,9 @@ class DualLinker(link.Linker):
self.schedule = schedule self.schedule = schedule
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
"""
Update/tie self with fgraph
"""
if no_recycling is None: if no_recycling is None:
no_recycling = [] no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph: if self.fgraph is not None and self.fgraph is not fgraph:
...@@ -1912,7 +1911,10 @@ class DualLinker(link.Linker): ...@@ -1912,7 +1911,10 @@ class DualLinker(link.Linker):
return self return self
def make_thunk(self, **kwargs): def make_thunk(self, **kwargs):
"""
Compiles this linker's fgraph and returns a function to perform the
computations
"""
fgraph = self.fgraph fgraph = self.fgraph
no_recycling = self.no_recycling no_recycling = self.no_recycling
......
...@@ -1474,10 +1474,25 @@ class ModuleCache(object): ...@@ -1474,10 +1474,25 @@ class ModuleCache(object):
def _rmtree(parent, ignore_nocleanup=False, msg='', level=logging.DEBUG, def _rmtree(parent, ignore_nocleanup=False, msg='', level=logging.DEBUG,
ignore_if_missing=False): ignore_if_missing=False):
# On NFS filesystems, it is impossible to delete a directory with open """
# files in it. So instead, some commands in this file will respond to a On NFS filesystems, it is impossible to delete a directory with open
# failed rmtree() by touching a 'delete.me' file. This file is a message files in it.
# for a future process to try deleting the directory.
So instead, some commands in this file will respond to a
failed rmtree() by touching a 'delete.me' file. This file is a message
for a future process to try deleting the directory.
Parameters:
----------
parent
Root node to start deleting from
ignore_nocleanup
Delete the tree if flag is TRUE
level
Python Logging level. Set to "DEBUG" by default
ignore_if_missing
If set to True, just return without any issue if parent is NULL
"""
if ignore_if_missing and not os.path.exists(parent): if ignore_if_missing and not os.path.exists(parent):
return return
try: try:
...@@ -1504,6 +1519,7 @@ _module_cache = None ...@@ -1504,6 +1519,7 @@ _module_cache = None
def get_module_cache(dirname, init_args=None): def get_module_cache(dirname, init_args=None):
""" """
Create a new module_cache with the (k, v) pairs in this dictionary
Parameters Parameters
---------- ----------
......
...@@ -94,6 +94,9 @@ def cleanup(): ...@@ -94,6 +94,9 @@ def cleanup():
def print_compiledir_content(): def print_compiledir_content():
"""
print list of %d compiled individual ops in the "theano.config.compiledir"
"""
max_key_file_size = 1 * 1024 * 1024 # 1M max_key_file_size = 1 * 1024 * 1024 # 1M
compiledir = theano.config.compiledir compiledir = theano.config.compiledir
...@@ -178,6 +181,9 @@ def compiledir_purge(): ...@@ -178,6 +181,9 @@ def compiledir_purge():
def basecompiledir_ls(): def basecompiledir_ls():
"""
Print list of files in the "theano.config.base_compiledir"
"""
subdirs = [] subdirs = []
others = [] others = []
for f in os.listdir(config.base_compiledir): for f in os.listdir(config.base_compiledir):
......
...@@ -32,6 +32,7 @@ class ProtocolError(Exception): ...@@ -32,6 +32,7 @@ class ProtocolError(Exception):
def _contains_cycle(fgraph, orderings): def _contains_cycle(fgraph, orderings):
""" """
Function to check if the given graph contains a cycle
Parameters Parameters
---------- ----------
......
...@@ -66,7 +66,6 @@ class MissingInputError(Exception): ...@@ -66,7 +66,6 @@ class MissingInputError(Exception):
class FunctionGraph(utils.object2): class FunctionGraph(utils.object2):
""" """
WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and A FunctionGraph represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies a theano function. a set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs on which the outputs depend. The inputs list should contain all the inputs on which the outputs depend.
...@@ -265,8 +264,6 @@ class FunctionGraph(utils.object2): ...@@ -265,8 +264,6 @@ class FunctionGraph(utils.object2):
""" """
Updates the list of clients of r with new_clients. Updates the list of clients of r with new_clients.
WRITEME
Parameters Parameters
---------- ----------
r r
...@@ -365,6 +362,11 @@ class FunctionGraph(utils.object2): ...@@ -365,6 +362,11 @@ class FunctionGraph(utils.object2):
""" """
Import variables to this FunctionGraph and also their apply_node, Import variables to this FunctionGraph and also their apply_node,
if those nodes are not in this graph. if those nodes are not in this graph.
Parameters:
----------
reason
reason is the name of the optimization or operation in progress.
""" """
global NullType global NullType
if NullType is None: if NullType is None:
...@@ -438,8 +440,6 @@ class FunctionGraph(utils.object2): ...@@ -438,8 +440,6 @@ class FunctionGraph(utils.object2):
""" """
Changes node.inputs[i] to new_r. Changes node.inputs[i] to new_r.
WRITEME
new_r.type == old_r.type must be True, where old_r is the new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace. current value of node.inputs[i] which we want to replace.
...@@ -483,8 +483,6 @@ class FunctionGraph(utils.object2): ...@@ -483,8 +483,6 @@ class FunctionGraph(utils.object2):
# replace # # replace #
def replace(self, r, new_r, reason=None, verbose=None): def replace(self, r, new_r, reason=None, verbose=None):
""" """
WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph. This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
...@@ -540,7 +538,7 @@ class FunctionGraph(utils.object2): ...@@ -540,7 +538,7 @@ class FunctionGraph(utils.object2):
def replace_all(self, pairs, reason=None): def replace_all(self, pairs, reason=None):
""" """
WRITEME For every node that uses r as input, makes it use new_r instead
""" """
for r, new_r in pairs: for r, new_r in pairs:
...@@ -578,8 +576,6 @@ class FunctionGraph(utils.object2): ...@@ -578,8 +576,6 @@ class FunctionGraph(utils.object2):
def remove_feature(self, feature): def remove_feature(self, feature):
""" """
WRITEME
Removes the feature from the graph. Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method Calls feature.on_detach(function_graph) if an on_detach method
...@@ -598,8 +594,6 @@ class FunctionGraph(utils.object2): ...@@ -598,8 +594,6 @@ class FunctionGraph(utils.object2):
# callback utils # # callback utils #
def execute_callbacks(self, name, *args, **kwargs): def execute_callbacks(self, name, *args, **kwargs):
""" """
WRITEME
Calls Calls
getattr(feature, name)(*args) getattr(feature, name)(*args)
for each feature which has a method called after name. for each feature which has a method called after name.
...@@ -621,8 +615,6 @@ class FunctionGraph(utils.object2): ...@@ -621,8 +615,6 @@ class FunctionGraph(utils.object2):
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
""" """
WRITEME
Returns a dictionary d such that: Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args) d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name. For each feature which has a method called after name.
...@@ -640,8 +632,6 @@ class FunctionGraph(utils.object2): ...@@ -640,8 +632,6 @@ class FunctionGraph(utils.object2):
# misc # # misc #
def toposort(self): def toposort(self):
""" """
WRITEME
Return an ordering of the graph's Apply nodes such that: Return an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node. - All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has - Satisfies the orderings provided by each feature that has
...@@ -705,8 +695,6 @@ class FunctionGraph(utils.object2): ...@@ -705,8 +695,6 @@ class FunctionGraph(utils.object2):
def check_integrity(self): def check_integrity(self):
""" """
WRITEME
Call this for a diagnosis if things go awry. Call this for a diagnosis if things go awry.
""" """
...@@ -766,7 +754,7 @@ class FunctionGraph(utils.object2): ...@@ -766,7 +754,7 @@ class FunctionGraph(utils.object2):
# clone # # clone #
def clone(self, check_integrity=True): def clone(self, check_integrity=True):
""" """
WRITEME Clone the graph and get a memo( a dict )that map old node to new node
""" """
return self.clone_get_equiv(check_integrity)[0] return self.clone_get_equiv(check_integrity)[0]
......
...@@ -701,7 +701,15 @@ def inputs(variable_list, blockers=None): ...@@ -701,7 +701,15 @@ def inputs(variable_list, blockers=None):
def variables_and_orphans(i, o): def variables_and_orphans(i, o):
""" """
WRITEME Extract list of variables between i and o nodes via
dfs traversal and chooses the orphans among them
Parameters
----------
i : list
Input variables.
o : list
Output variables.
""" """
def expand(r): def expand(r):
...@@ -716,21 +724,21 @@ def variables_and_orphans(i, o): ...@@ -716,21 +724,21 @@ def variables_and_orphans(i, o):
def ops(i, o): def ops(i, o):
""" """
WRITEME Set of Ops contained within the subgraph between i and o
Parameters Parameters
---------- ----------
i : list i : list
Input L{Variable}s. Input variables.
o : list o : list
Output L{Variable}s. Output variables.
Returns Returns
------- -------
object object
The set of ops that are contained within the subgraph that lies The set of ops that are contained within the subgraph that lies
between i and o, including the owners of the L{Variable}s in o and between i and o, including the owners of the variables in o and
intermediary ops between i and o, but not the owners of the L{Variable}s intermediary ops between i and o, but not the owners of the variables
in i. in i.
""" """
...@@ -745,14 +753,14 @@ def ops(i, o): ...@@ -745,14 +753,14 @@ def ops(i, o):
def variables(i, o): def variables(i, o):
""" """
WRITEME Extracts list of variables within input and output nodes via dfs travesal
Parameters Parameters
---------- ----------
i : list i : list
Input L{Variable}s. Input variables.
o : list o : list
Output L{Variable}s. Output variables.
Returns Returns
------- -------
...@@ -767,14 +775,15 @@ def variables(i, o): ...@@ -767,14 +775,15 @@ def variables(i, o):
def orphans(i, o): def orphans(i, o):
""" """
WRITEME Extracts list of variables within input and output nodes
via dfs travesal and returns the orphans among them
Parameters Parameters
---------- ----------
i : list i : list
Input L{Variable}s. Input Variables.
o : list o : list
Output L{Variable}s. Output Variables.
Returns Returns
------- -------
...@@ -797,9 +806,9 @@ def clone(i, o, copy_inputs=True): ...@@ -797,9 +806,9 @@ def clone(i, o, copy_inputs=True):
Parameters Parameters
---------- ----------
i : list i : list
Input L{Variable}s. Input Variables.
o : list o : list
Output L{Variable}s. Output Variables.
copy_inputs : bool copy_inputs : bool
If True, the inputs will be copied (defaults to True). If True, the inputs will be copied (defaults to True).
...@@ -959,7 +968,7 @@ def general_toposort(r_out, deps, debug_print=False, ...@@ -959,7 +968,7 @@ def general_toposort(r_out, deps, debug_print=False,
def io_toposort(inputs, outputs, orderings=None, clients=None): def io_toposort(inputs, outputs, orderings=None, clients=None):
""" """
WRITEME Perform topological sort from input and output nodes
Parameters Parameters
---------- ----------
...@@ -1218,8 +1227,8 @@ def op_as_string(i, op, ...@@ -1218,8 +1227,8 @@ def op_as_string(i, op,
leaf_formatter=default_leaf_formatter, leaf_formatter=default_leaf_formatter,
node_formatter=default_node_formatter): node_formatter=default_node_formatter):
""" """
WRITEME Op to return a string representation of the subgraph
between i and o
""" """
strs = as_string(i, op.inputs, leaf_formatter, node_formatter) strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
return node_formatter(op, strs) return node_formatter(op, strs)
...@@ -1229,7 +1238,7 @@ def as_string(i, o, ...@@ -1229,7 +1238,7 @@ def as_string(i, o,
leaf_formatter=default_leaf_formatter, leaf_formatter=default_leaf_formatter,
node_formatter=default_node_formatter): node_formatter=default_node_formatter):
""" """
WRITEME Returns a string representation of the subgraph between i and o
Parameters Parameters
---------- ----------
......
...@@ -52,16 +52,24 @@ def log_thunk_trace(value, f=sys.stderr): ...@@ -52,16 +52,24 @@ def log_thunk_trace(value, f=sys.stderr):
def thunk_hook(type, value, trace): def thunk_hook(type, value, trace):
""" """
WRITEME
This function is meant to replace excepthook and do some This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__ special work if the exception value has a __thunk_trace__
field. In that case, it retrieves the field, which should field.
In that case, it retrieves the field, which should
contain a trace as returned by L{traceback.extract_stack}, contain a trace as returned by L{traceback.extract_stack},
and prints it out on L{stderr}. and prints it out on L{stderr}.
The normal excepthook is then called. The normal excepthook is then called.
Parameters:
----------
type
Exception class
value
Exception instance
trace
Traceback object
Notes Notes
----- -----
This hook replaced by nosetests, so it does not run in nose tests. This hook replaced by nosetests, so it does not run in nose tests.
...@@ -680,8 +688,6 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None, ...@@ -680,8 +688,6 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
class LocalLinker(Linker): class LocalLinker(Linker):
""" """
WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node. a thunk associated with each node.
...@@ -707,7 +713,7 @@ class LocalLinker(Linker): ...@@ -707,7 +713,7 @@ class LocalLinker(Linker):
def gc_helper(node_list): def gc_helper(node_list):
""" """
Return the set of Variable instances which are computed by node_list.
Parameters Parameters
---------- ----------
node_list node_list
...@@ -743,8 +749,6 @@ def gc_helper(node_list): ...@@ -743,8 +749,6 @@ def gc_helper(node_list):
class PerformLinker(LocalLinker): class PerformLinker(LocalLinker):
""" """
WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{Linker.schedule}. the L{FunctionGraph} in the order given by L{Linker.schedule}.
...@@ -764,8 +768,7 @@ class PerformLinker(LocalLinker): ...@@ -764,8 +768,7 @@ class PerformLinker(LocalLinker):
Parameters Parameters
---------- ----------
fgraph fgraph
A PerformLinker can have accepted one FunctionGraph instance at a A PerformLinker can have accepted one FunctionGraph instance at a time.
time.
no_recycling no_recycling
WRITEME WRITEME
...@@ -786,13 +789,14 @@ class PerformLinker(LocalLinker): ...@@ -786,13 +789,14 @@ class PerformLinker(LocalLinker):
def make_all(self, input_storage=None, output_storage=None, storage_map=None): def make_all(self, input_storage=None, output_storage=None, storage_map=None):
""" """
Returns Function to run all nodes, list of input containers, list of outputs
Parameters Parameters
---------- ----------
input_storage input_storage
WRITEME list of storages corresponding to fgraph.inputs
output_storage output_storage
WRITEME list of storages corresponding to fgraph.outputs
Returns Returns
------- -------
...@@ -879,8 +883,6 @@ def add_clear_storage(f, computed, storage_map): ...@@ -879,8 +883,6 @@ def add_clear_storage(f, computed, storage_map):
class WrapLinker(Linker): class WrapLinker(Linker):
""" """
WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run. offers some control over how each thunk is run.
......
...@@ -791,6 +791,9 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -791,6 +791,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
self._op_use_c_code = use_c_code self._op_use_c_code = use_c_code
def _props(self): def _props(self):
"""
Tuple of properties of all attributes
"""
return tuple(getattr(self, a) for a in self.__props__) return tuple(getattr(self, a) for a in self.__props__)
def _props_dict(self): def _props_dict(self):
...@@ -924,6 +927,9 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -924,6 +927,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
def make_thunk(self, node, storage_map, compute_map, no_recycling): def make_thunk(self, node, storage_map, compute_map, no_recycling):
""" """
This function must return a thunk, that is a zero-arguments
function that encapsulates the computation to be performed
by this op on the arguments of the node.
Parameters Parameters
---------- ----------
...@@ -974,7 +980,9 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -974,7 +980,9 @@ class Op(utils.object2, PureOp, CLinkerOp):
return self.make_py_thunk(node, storage_map, compute_map, no_recycling) return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
def make_node(self, *inputs): def make_node(self, *inputs):
"""
Create a "apply" nodes for the inputs in that order.
"""
if not hasattr(self, 'itypes'): if not hasattr(self, 'itypes'):
raise NotImplementedError("You can either define itypes and otypes,\ raise NotImplementedError("You can either define itypes and otypes,\
or implement make_node") or implement make_node")
...@@ -1058,6 +1066,10 @@ def debug_error_message(msg): ...@@ -1058,6 +1066,10 @@ def debug_error_message(msg):
def debug_assert(condition, msg=None): def debug_assert(condition, msg=None):
"""
Customized assert with options to ignore the assert
with just a warning
"""
if msg is None: if msg is None:
msg = 'debug_assert failed' msg = 'debug_assert failed'
if not condition: if not condition:
...@@ -1165,12 +1177,18 @@ class OpenMPOp(Op): ...@@ -1165,12 +1177,18 @@ class OpenMPOp(Op):
self.openmp = False self.openmp = False
def c_compile_args(self): def c_compile_args(self):
"""
Return the compilation arg "fopenmp" if openMP is supported
"""
self.update_self_openmp() self.update_self_openmp()
if self.openmp: if self.openmp:
return ['-fopenmp'] return ['-fopenmp']
return [] return []
def c_headers(self): def c_headers(self):
"""
Return the header file name "omp.h" if openMP is supported
"""
self.update_self_openmp() self.update_self_openmp()
if self.openmp: if self.openmp:
return ["omp.h"] return ["omp.h"]
...@@ -1178,6 +1196,9 @@ class OpenMPOp(Op): ...@@ -1178,6 +1196,9 @@ class OpenMPOp(Op):
@staticmethod @staticmethod
def test_gxx_support(): def test_gxx_support():
"""
Check if openMP is supported
"""
code = """ code = """
#include <omp.h> #include <omp.h>
int main( int argc, const char* argv[] ) int main( int argc, const char* argv[] )
...@@ -1313,6 +1334,9 @@ class COp(Op): ...@@ -1313,6 +1334,9 @@ class COp(Op):
'and specify the func_name') 'and specify the func_name')
def load_c_code(self): def load_c_code(self):
"""
Loads the c code to perform the Op
"""
self.func_codes = [] self.func_codes = []
for func_file in self.func_files: for func_file in self.func_files:
with open(func_file, 'r') as f: with open(func_file, 'r') as f:
...@@ -1391,6 +1415,9 @@ class COp(Op): ...@@ -1391,6 +1415,9 @@ class COp(Op):
return hash(tuple(self.func_codes)) return hash(tuple(self.func_codes))
def c_init_code(self): def c_init_code(self):
"""
Get the code section for init_code
"""
if 'init_code' in self.code_sections: if 'init_code' in self.code_sections:
return [self.code_sections['init_code']] return [self.code_sections['init_code']]
else: else:
...@@ -1500,6 +1527,10 @@ class COp(Op): ...@@ -1500,6 +1527,10 @@ class COp(Op):
undef_macros.append("#undef OUTPUT_%d", (i,)) undef_macros.append("#undef OUTPUT_%d", (i,))
def c_init_code_struct(self, node, name, sub): def c_init_code_struct(self, node, name, sub):
"""
Stitches all the macros and "init_code" together
"""
if 'init_code_struct' in self.code_sections: if 'init_code_struct' in self.code_sections:
op_code = self.code_sections['init_code_struct'] op_code = self.code_sections['init_code_struct']
...@@ -1554,6 +1585,9 @@ class COp(Op): ...@@ -1554,6 +1585,9 @@ class COp(Op):
'c_code', type(self), type(self).__name__) 'c_code', type(self), type(self).__name__)
def c_code_cleanup(self, node, name, inputs, outputs, sub): def c_code_cleanup(self, node, name, inputs, outputs, sub):
"""
Stitches all the macros and "code_cleanup" together
"""
if 'code_cleanup' in self.code_sections: if 'code_cleanup' in self.code_sections:
op_code = self.code_sections['code_cleanup'] op_code = self.code_sections['code_cleanup']
......
from __future__ import absolute_import, print_function, division

import sys
from collections import defaultdict

from six import string_types
from six.moves import reduce
if 0:
    class _EquilibriumOptimizer(NavigatorOptimizer):
        """
        Dead-code sketch (guarded by ``if 0:``) of an equilibrium optimizer:
        repeatedly applies a set of local optimizers to a FunctionGraph,
        scheduling them via "tracks" (op-pattern paths), until no task is left.

        Parameters
        ----------
        local_optimizers : list
            Local optimizers to apply; each must provide ``tracks()``.
        failure_callback : callable, optional
            Forwarded to NavigatorOptimizer.
        max_depth : int, optional
            Maximum allowed track length; longer tracks raise ValueError.
        max_use_ratio : float, optional
            Bounds how many times each local optimizer may fire, as a
            ratio of the number of apply nodes in the graph.
        """

        def __init__(self,
                     local_optimizers,
                     failure_callback=None,
                     max_depth=None,
                     max_use_ratio=None):
            # BUG FIX: original code called super(EquilibriumOptimizer, ...)
            # which is a NameError — the class is named _EquilibriumOptimizer.
            super(_EquilibriumOptimizer, self).__init__(
                None,
                ignore_newtrees=False,
                failure_callback=failure_callback)
            self.local_optimizers = local_optimizers
            self.max_depth = max_depth
            self.max_use_ratio = max_use_ratio
            # tracks maps each op to every (track, position, lopt) that
            # mentions it; tracks0 only indexes position 0 (track heads).
            self.tracks = defaultdict(list)
            self.tracks0 = defaultdict(list)
            max_depth = 0
            for lopt in local_optimizers:
                tracks = lopt.tracks()
                for track in tracks:
                    max_depth = max(max_depth, len(track))
                    if self.max_depth is not None and \
                            max_depth > self.max_depth:
                        raise ValueError(
                            'One of the local optimizers exceeds the '
                            'maximal depth.')
                    for i, op in enumerate(track):
                        if i == 0:
                            self.tracks0[op].append((track, i, lopt))
                        self.tracks[op].append((track, i, lopt))

        def fetch_tracks(self, op):
            """Return all (track, i, lopt) entries matching op or wildcard."""
            return self.tracks[op] + self.tracks[None]

        def fetch_tracks0(self, op):
            """Return track-head entries matching op or wildcard."""
            return self.tracks0[op] + self.tracks0[None]

        def backtrack(self, node, tasks):
            """
            Walk the clients of `node`, matching candidate tracks level by
            level; when a full track matches, queue its local optimizer
            in `tasks[node]`.
            """
            candidates = self.fetch_tracks(node.op)
            tracks = []

            def filter(node, depth):
                # Partition candidates: matched-to-completion -> tasks,
                # partially matched -> tracks, mismatched position -> keep.
                new_candidates = []
                for candidate in candidates:
                    track, i, lopt = candidate
                    if i < depth:
                        pass
                    elif track[i - depth] in (None, node.op):
                        if i == depth:
                            tasks[node].append(lopt)
                        else:
                            tracks.append(candidate)
                    else:
                        new_candidates.append(candidate)
                return new_candidates

            depth = 0
            nodes = [node]
            while candidates:
                for node in nodes:
                    candidates = list(filter(node, depth))
                depth += 1
                _nodes = nodes
                # Advance one level: gather all client apply nodes of the
                # current frontier (string clients mark graph outputs).
                nodes = reduce(
                    list.__iadd__,
                    [reduce(list.__iadd__,
                            [[n for n, i in out.clients
                              if not isinstance(n, string_types)]
                             for out in node.outputs],
                            []) for node in nodes],
                    [])
                candidates = tracks
                tracks = []

        def apply(self, fgraph):
            """
            Apply local optimizers to `fgraph` until the task queue is
            empty, re-scheduling via graph-change callbacks.
            """
            tasks = defaultdict(list)
            if self.max_use_ratio is not None:
                max_uses = self.max_use_ratio * len(fgraph.apply_nodes)
                runs = defaultdict(int)
            else:
                runs = None

            def importer(node):
                # New node entered the graph: re-run track matching on it.
                self.backtrack(node, tasks)

            def pruner(node):
                # Node removed: drop any pending tasks for it.
                try:
                    del tasks[node]
                except KeyError:
                    pass

            def chin(node, i, r, new_r):
                if new_r.owner and not r.clients:
                    self.backtrack(new_r.owner, tasks)

            # Seed tasks with every track-head optimizer for each node.
            for node in fgraph.toposort():
                tasks[node].extend(
                    lopt for track, i, lopt in self.fetch_tracks0(node.op))

            u = self.attach_updater(fgraph, importer, pruner, chin)
            print('KEYS', [hash(t) for t in tasks.keys()])
            while tasks:
                # Pop an arbitrary pending node.
                for node in tasks:
                    todo = tasks.pop(node)
                    break
                for lopt in todo:
                    if runs is not None and runs[lopt] >= max_uses:
                        print('Warning: optimization exceeded its maximal '
                              'use ratio: %s, %s' % (lopt, max_uses),
                              file=sys.stderr)
                        continue
                    success = self.process_node(fgraph, node, lopt)
                    if success:
                        if runs is not None:
                            runs[lopt] += 1
                        break
            self.detach_updater(fgraph, u)

#     def match(self, node, candidates):
#         candidates[:] = [candidate
#                          for candidate in candidates
#                          if candidate.current.op is None
#                          or candidate.current.op == node.op]
#         for candidate in candidates:
#             if candidate.current.inputs is not None:
#                 for in1, in2 in zip(candidate.current.inputs, node.inputs):
#                     if isinstance(in1, string_types):
#                         candidate.match[in1] = in2
#         for client in node.clients:
#             op = node.op
#             patterns = self.pattern_base[(depth, op)].union(
#                 self.pattern_base[(depth, WILDCARD)])
#             if not patterns:
#                 return patterns
#             return self.match(node, depth + 1).intersection(patterns)

#     def backtrack(self, node, q):
#         for node2, i in node.clients:
#             op2 = node2.op
...@@ -268,6 +268,9 @@ def sort_schedule_fn(*cmps): ...@@ -268,6 +268,9 @@ def sort_schedule_fn(*cmps):
def key_to_cmp(key): def key_to_cmp(key):
"""
comparator function based on "key" function
"""
def key_cmp(a, b): def key_cmp(a, b):
return cmp(key(a), key(b)) return cmp(key(a), key(b))
return key_cmp return key_cmp
...@@ -114,10 +114,20 @@ class Feature(object): ...@@ -114,10 +114,20 @@ class Feature(object):
class Bookkeeper(Feature): class Bookkeeper(Feature):
def on_attach(self, fgraph): def on_attach(self, fgraph):
"""
Called by FunctionGraph.attach_feature, the method that attaches
the feature to the FunctionGraph. Since this is called after the
FunctionGraph is initially populated, this is where you should
run checks on the initial contents of the FunctionGraph.
"""
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_import(fgraph, node, "on_attach") self.on_import(fgraph, node, "on_attach")
def on_detach(self, fgraph): def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
for node in graph.io_toposort(fgraph.inputs, fgraph.outputs): for node in graph.io_toposort(fgraph.inputs, fgraph.outputs):
self.on_prune(fgraph, node, 'Bookkeeper.detach') self.on_prune(fgraph, node, 'Bookkeeper.detach')
...@@ -178,6 +188,10 @@ class History(Feature): ...@@ -178,6 +188,10 @@ class History(Feature):
fgraph.revert = partial(self.revert, fgraph) fgraph.revert = partial(self.revert, fgraph)
def on_detach(self, fgraph): def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
del fgraph.checkpoint del fgraph.checkpoint
del fgraph.revert del fgraph.revert
del self.history[fgraph] del self.history[fgraph]
...@@ -223,10 +237,19 @@ class Validator(Feature): ...@@ -223,10 +237,19 @@ class Validator(Feature):
fgraph.consistent = partial(self.consistent_, fgraph) fgraph.consistent = partial(self.consistent_, fgraph)
def on_detach(self, fgraph): def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
del fgraph.validate del fgraph.validate
del fgraph.consistent del fgraph.consistent
def validate_(self, fgraph): def validate_(self, fgraph):
"""
If the caller is replace_all_validate, just raise the
exception. replace_all_validate will print out the
verbose output. Or it has to be done here before raise.
"""
t0 = time.time() t0 = time.time()
try: try:
ret = fgraph.execute_callbacks('validate') ret = fgraph.execute_callbacks('validate')
...@@ -289,6 +312,10 @@ class ReplaceValidate(History, Validator): ...@@ -289,6 +312,10 @@ class ReplaceValidate(History, Validator):
self.replace_all_validate_remove, fgraph) self.replace_all_validate_remove, fgraph)
def on_detach(self, fgraph): def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
History.on_detach(self, fgraph) History.on_detach(self, fgraph)
Validator.on_detach(self, fgraph) Validator.on_detach(self, fgraph)
del self._nodes_removed del self._nodes_removed
...@@ -412,6 +439,10 @@ class NodeFinder(Bookkeeper): ...@@ -412,6 +439,10 @@ class NodeFinder(Bookkeeper):
Bookkeeper.on_attach(self, fgraph) Bookkeeper.on_attach(self, fgraph)
def on_detach(self, fgraph): def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.fgraph is not fgraph: if self.fgraph is not fgraph:
raise Exception("This NodeFinder instance was not attached to the" raise Exception("This NodeFinder instance was not attached to the"
" provided fgraph.") " provided fgraph.")
...@@ -461,6 +492,10 @@ class PrintListener(Feature): ...@@ -461,6 +492,10 @@ class PrintListener(Feature):
print("-- attaching to: ", fgraph) print("-- attaching to: ", fgraph)
def on_detach(self, fgraph): def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.active: if self.active:
print("-- detaching from: ", fgraph) print("-- detaching from: ", fgraph)
......
...@@ -4,7 +4,7 @@ can be "unified" if there exists an assignment to all unification variables ...@@ -4,7 +4,7 @@ can be "unified" if there exists an assignment to all unification variables
such that the two expressions are equal. such that the two expressions are equal.
For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9, For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9,
yielding [5, 5, 9]. yielding [5, 5, 9].
[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A [5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A
that satisfies the constraints. That's useful for pattern matching. that satisfies the constraints. That's useful for pattern matching.
...@@ -15,7 +15,6 @@ from copy import copy ...@@ -15,7 +15,6 @@ from copy import copy
from functools import partial from functools import partial
from theano.gof.utils import ANY_TYPE, comm_guard, FALL_THROUGH, iteritems from theano.gof.utils import ANY_TYPE, comm_guard, FALL_THROUGH, iteritems
################################ ################################
...@@ -135,8 +134,6 @@ class Unification: ...@@ -135,8 +134,6 @@ class Unification:
""" """
This class represents a possible unification of a group of variables This class represents a possible unification of a group of variables
with each other or with tangible values. with each other or with tangible values.
Parameters Parameters
---------- ----------
inplace : bool inplace : bool
...@@ -229,7 +226,7 @@ def unify_walk(a, b, U): ...@@ -229,7 +226,7 @@ def unify_walk(a, b, U):
return False return False
@comm_guard(FreeVariable, ANY_TYPE) @comm_guard(FreeVariable, ANY_TYPE) # noqa
def unify_walk(fv, o, U): def unify_walk(fv, o, U):
""" """
FreeV is unified to BoundVariable(other_object). FreeV is unified to BoundVariable(other_object).
...@@ -239,7 +236,7 @@ def unify_walk(fv, o, U): ...@@ -239,7 +236,7 @@ def unify_walk(fv, o, U):
return U.merge(v, fv) return U.merge(v, fv)
@comm_guard(BoundVariable, ANY_TYPE) @comm_guard(BoundVariable, ANY_TYPE) # noqa
def unify_walk(bv, o, U): def unify_walk(bv, o, U):
""" """
The unification succeed iff BV.value == other_object. The unification succeed iff BV.value == other_object.
...@@ -251,7 +248,7 @@ def unify_walk(bv, o, U): ...@@ -251,7 +248,7 @@ def unify_walk(bv, o, U):
return False return False
@comm_guard(OrVariable, ANY_TYPE) @comm_guard(OrVariable, ANY_TYPE) # noqa
def unify_walk(ov, o, U): def unify_walk(ov, o, U):
""" """
The unification succeeds iff other_object in OrV.options. The unification succeeds iff other_object in OrV.options.
...@@ -264,7 +261,7 @@ def unify_walk(ov, o, U): ...@@ -264,7 +261,7 @@ def unify_walk(ov, o, U):
return False return False
@comm_guard(NotVariable, ANY_TYPE) @comm_guard(NotVariable, ANY_TYPE) # noqa
def unify_walk(nv, o, U): def unify_walk(nv, o, U):
""" """
The unification succeeds iff other_object not in NV.not_options. The unification succeeds iff other_object not in NV.not_options.
...@@ -277,7 +274,7 @@ def unify_walk(nv, o, U): ...@@ -277,7 +274,7 @@ def unify_walk(nv, o, U):
return U.merge(v, nv) return U.merge(v, nv)
@comm_guard(FreeVariable, Variable) @comm_guard(FreeVariable, Variable) # noqa
def unify_walk(fv, v, U): def unify_walk(fv, v, U):
""" """
Both variables are unified. Both variables are unified.
...@@ -287,7 +284,7 @@ def unify_walk(fv, v, U): ...@@ -287,7 +284,7 @@ def unify_walk(fv, v, U):
return U.merge(v, fv) return U.merge(v, fv)
@comm_guard(BoundVariable, Variable) @comm_guard(BoundVariable, Variable) # noqa
def unify_walk(bv, v, U): def unify_walk(bv, v, U):
""" """
V is unified to BV.value. V is unified to BV.value.
...@@ -296,13 +293,13 @@ def unify_walk(bv, v, U): ...@@ -296,13 +293,13 @@ def unify_walk(bv, v, U):
return unify_walk(v, bv.value, U) return unify_walk(v, bv.value, U)
@comm_guard(OrVariable, OrVariable) @comm_guard(OrVariable, OrVariable) # noqa
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2)) OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
""" """
opt = intersection(a.options, b.options) opt = a.options.intersection(b.options)
if not opt: if not opt:
return False return False
elif len(opt) == 1: elif len(opt) == 1:
...@@ -312,18 +309,18 @@ def unify_walk(a, b, U): ...@@ -312,18 +309,18 @@ def unify_walk(a, b, U):
return U.merge(v, a, b) return U.merge(v, a, b)
@comm_guard(NotVariable, NotVariable) @comm_guard(NotVariable, NotVariable) # noqa
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
NV(list1) == NV(list2) == NV(union(list1, list2)) NV(list1) == NV(list2) == NV(union(list1, list2))
""" """
opt = union(a.not_options, b.not_options) opt = a.not_options.union(b.not_options)
v = NotVariable("?", opt) v = NotVariable("?", opt)
return U.merge(v, a, b) return U.merge(v, a, b)
@comm_guard(OrVariable, NotVariable) @comm_guard(OrVariable, NotVariable) # noqa
def unify_walk(o, n, U): def unify_walk(o, n, U):
""" """
OrV(list1) == NV(list2) == OrV(list1 \ list2) OrV(list1) == NV(list2) == OrV(list1 \ list2)
...@@ -339,7 +336,7 @@ def unify_walk(o, n, U): ...@@ -339,7 +336,7 @@ def unify_walk(o, n, U):
return U.merge(v, o, n) return U.merge(v, o, n)
@comm_guard(VariableInList, (list, tuple)) @comm_guard(VariableInList, (list, tuple)) # noqa
def unify_walk(vil, l, U): def unify_walk(vil, l, U):
""" """
Unifies VIL's inner Variable to OrV(list). Unifies VIL's inner Variable to OrV(list).
...@@ -350,7 +347,7 @@ def unify_walk(vil, l, U): ...@@ -350,7 +347,7 @@ def unify_walk(vil, l, U):
return unify_walk(v, ov, U) return unify_walk(v, ov, U)
@comm_guard((list, tuple), (list, tuple)) @comm_guard((list, tuple), (list, tuple)) # noqa
def unify_walk(l1, l2, U): def unify_walk(l1, l2, U):
""" """
Tries to unify each corresponding pair of elements from l1 and l2. Tries to unify each corresponding pair of elements from l1 and l2.
...@@ -365,7 +362,7 @@ def unify_walk(l1, l2, U): ...@@ -365,7 +362,7 @@ def unify_walk(l1, l2, U):
return U return U
@comm_guard(dict, dict) @comm_guard(dict, dict) # noqa
def unify_walk(d1, d2, U): def unify_walk(d1, d2, U):
""" """
Tries to unify values of corresponding keys. Tries to unify values of corresponding keys.
...@@ -379,7 +376,7 @@ def unify_walk(d1, d2, U): ...@@ -379,7 +376,7 @@ def unify_walk(d1, d2, U):
return U return U
@comm_guard(ANY_TYPE, ANY_TYPE) @comm_guard(ANY_TYPE, ANY_TYPE) # noqa
def unify_walk(a, b, U): def unify_walk(a, b, U):
""" """
Checks for the existence of the __unify_walk__ method for one of Checks for the existence of the __unify_walk__ method for one of
...@@ -394,7 +391,7 @@ def unify_walk(a, b, U): ...@@ -394,7 +391,7 @@ def unify_walk(a, b, U):
return FALL_THROUGH return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE) @comm_guard(Variable, ANY_TYPE) # noqa
def unify_walk(v, o, U): def unify_walk(v, o, U):
""" """
This simply checks if the Var has an unification in U and uses it This simply checks if the Var has an unification in U and uses it
...@@ -429,27 +426,27 @@ def unify_merge(a, b, U): ...@@ -429,27 +426,27 @@ def unify_merge(a, b, U):
return a return a
@comm_guard(Variable, ANY_TYPE) @comm_guard(Variable, ANY_TYPE) # noqa
def unify_merge(v, o, U): def unify_merge(v, o, U):
return v return v
@comm_guard(BoundVariable, ANY_TYPE) @comm_guard(BoundVariable, ANY_TYPE) # noqa
def unify_merge(bv, o, U): def unify_merge(bv, o, U):
return bv.value return bv.value
@comm_guard(VariableInList, (list, tuple)) @comm_guard(VariableInList, (list, tuple)) # noqa
def unify_merge(vil, l, U): def unify_merge(vil, l, U):
return [unify_merge(x, x, U) for x in l] return [unify_merge(x, x, U) for x in l]
@comm_guard((list, tuple), (list, tuple)) @comm_guard((list, tuple), (list, tuple)) # noqa
def unify_merge(l1, l2, U): def unify_merge(l1, l2, U):
return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)] return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)]
@comm_guard(dict, dict) @comm_guard(dict, dict) # noqa
def unify_merge(d1, d2, U): def unify_merge(d1, d2, U):
d = d1.__class__() d = d1.__class__()
for k1, v1 in iteritems(d1): for k1, v1 in iteritems(d1):
...@@ -463,12 +460,12 @@ def unify_merge(d1, d2, U): ...@@ -463,12 +460,12 @@ def unify_merge(d1, d2, U):
return d return d
@comm_guard(FVar, ANY_TYPE) @comm_guard(FVar, ANY_TYPE) # noqa
def unify_merge(vs, o, U): def unify_merge(vs, o, U):
return vs(U) return vs(U)
@comm_guard(ANY_TYPE, ANY_TYPE) @comm_guard(ANY_TYPE, ANY_TYPE) # noqa
def unify_merge(a, b, U): def unify_merge(a, b, U):
if (not isinstance(a, Variable) and if (not isinstance(a, Variable) and
not isinstance(b, Variable) and not isinstance(b, Variable) and
...@@ -478,7 +475,7 @@ def unify_merge(a, b, U): ...@@ -478,7 +475,7 @@ def unify_merge(a, b, U):
return FALL_THROUGH return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE) @comm_guard(Variable, ANY_TYPE) # noqa
def unify_merge(v, o, U): def unify_merge(v, o, U):
""" """
This simply checks if the Var has an unification in U and uses it This simply checks if the Var has an unification in U and uses it
......
...@@ -27,6 +27,9 @@ logger = logging.getLogger(__name__) ...@@ -27,6 +27,9 @@ logger = logging.getLogger(__name__)
def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
dependencies): dependencies):
"""
WRITEME : explain the parameters
"""
reallocated_info = {} reallocated_info = {}
viewed_by = {} viewed_by = {}
for var in fgraph.variables: for var in fgraph.variables:
...@@ -189,7 +192,9 @@ class VM(object): ...@@ -189,7 +192,9 @@ class VM(object):
raise NotImplementedError('override me') raise NotImplementedError('override me')
def update_profile(self, profile): def update_profile(self, profile):
# accumulate into the profile object """
Accumulate into the profile object
"""
for node, thunk, t, c in zip(self.nodes, self.thunks, for node, thunk, t, c in zip(self.nodes, self.thunks,
self.call_times, self.call_counts): self.call_times, self.call_counts):
profile.apply_time.setdefault(node, 0.0) profile.apply_time.setdefault(node, 0.0)
...@@ -723,6 +728,9 @@ class VM_Linker(link.LocalLinker): ...@@ -723,6 +728,9 @@ class VM_Linker(link.LocalLinker):
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
""" """
Check if fgraph is the first FunctionGraph that has ever been
associated to self, else, create a new VM_Linker
associated to fgraph
Parameters Parameters
---------- ----------
......
...@@ -126,13 +126,10 @@ whitelist_flake8 = [ ...@@ -126,13 +126,10 @@ whitelist_flake8 = [
"sparse/sandbox/sp2.py", "sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py", "sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py", "sparse/sandbox/sp.py",
"gof/unify.py",
"gof/__init__.py", "gof/__init__.py",
"gof/sandbox/equilibrium.py",
"d3viz/__init__.py", "d3viz/__init__.py",
"d3viz/tests/__init__.py", "d3viz/tests/__init__.py",
"gof/tests/__init__.py", "gof/tests/__init__.py",
] ]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论