Commit 5e536853 authored by abergeron

Merge pull request #3299 from harlouci/numpydoc_gof

Numpydoc gof
"""
gof.py
gof stands for Graph Optimization Framework
gof stands for Graph Optimization Framework.
The gof submodule of theano implements a framework
for manipulating programs described as graphs. The
......@@ -9,13 +9,13 @@ gof module defines basic theano graph concepts:
-Apply nodes, which represent the application
of an Op to Variables. Together these make up a
graph.
-The Type, needed for Variables to make sense
-The Type, needed for Variables to make sense.
-The FunctionGraph, which defines how a subgraph
should be interpreted to implement a function
should be interpreted to implement a function.
-The Thunk, a callable object that becomes part
of the executable emitted by theano
of the executable emitted by theano.
-Linkers/VMs, the objects that call Thunks in
sequence in order to execute a theano program
sequence in order to execute a theano program.
Conceptually, gof is intended to be sufficiently abstract
that it could be used to implement a language other than
......@@ -32,9 +32,9 @@ functionality. Ideally this should be refactored into
a different submodule.
For more details and discussion, see the theano-dev
e-mail thread "What is gof?"
"""
e-mail thread "What is gof?".
"""
from theano.gof.cc import \
CLinker, OpWiseCLinker, DualLinker, HideC
......
"""
Defines Linkers that deal with C implementations.
"""
from __future__ import print_function
......@@ -45,8 +46,13 @@ run_cthunk = None # Will be imported only when needed.
def get_module_cache(init_args=None):
"""
:param init_args: If not None, the (k, v) pairs in this dictionary will
be forwarded to the ModuleCache constructor as keyword arguments.
Parameters
----------
init_args
If not None, the (k, v) pairs in this dictionary will be forwarded to
the ModuleCache constructor as keyword arguments.
"""
return cmodule.get_module_cache(config.compiledir, init_args=init_args)
......@@ -63,25 +69,31 @@ def get_persistent_module_cache():
class CodeBlock:
"""WRITEME
"""
WRITEME
Represents a computation unit composed of declare, behavior, and cleanup.
@ivar declare: C code that declares variables for use by the computation
@ivar behavior: C code that performs the computation
@ivar cleanup: C code that cleans up things allocated or incref-ed
in behavior
The constructor initializes a L{CodeBlock} with templatized declare,
behavior and cleanup. The sub parameter will be used in the other
arguments' templates. sub should contain a key called 'id' that maps to an
identifier for this block. The identifier will be used to determine the
failure code and a label to jump to. It should also contain a key called
'failure_var' that contains the name of the variable that contains the error
code.
Parameters
----------
declare
C code that declares variables for use by the computation.
behavior
C code that performs the computation.
cleanup
C code that cleans up things allocated or incref-ed in behavior.
"""
def __init__(self, declare, behavior, cleanup, sub):
"""
Initialize a L{CodeBlock} with templatized declare, behavior
and cleanup. The sub parameter will be used in the other
arguments' templates. sub should contain a key called 'id'
that maps to an identifier for this block.
The identifier will be used to determine the failure code and
a label to jump to. It should also contain a key called
'failure_var' that contains the name of the variable that
contains the error code.
"""
self.declare = declare
self.behavior = behavior
# the dummy is because gcc throws an error when a label's
......@@ -94,10 +106,12 @@ class CodeBlock:
def failure_code(sub):
"""Code contained in sub['fail'], usually substituted for %(fail)s.
"""
Code contained in sub['fail'], usually substituted for %(fail)s.
It sets information about current error, then goto the code
actually handling the failure, which is defined in struct_gen().
"""
return '''{
%(failure_var)s = %(id)s;
......@@ -110,7 +124,10 @@ def failure_code(sub):
def failure_code_init(sub):
"Code for failure in the struct init."
"""
Code for failure in the struct init.
"""
return '''{
if (!PyErr_Occurred()) {
PyErr_SetString(PyExc_RuntimeError,
......@@ -122,10 +139,13 @@ def failure_code_init(sub):
def code_gen(blocks):
"""WRITEME From a list of L{CodeBlock} instances, returns a string
"""
WRITEME
From a list of L{CodeBlock} instances, returns a string
that executes them all in sequence. eg for C{(decl1, task1,
cleanup1)} and C{(decl2, task2, cleanup2)} the returned string
will be of the form::
will be of the form:
decl1
decl2
......@@ -137,8 +157,8 @@ def code_gen(blocks):
}
cleanup1
}
"""
"""
decl = ""
head = ""
tail = ""
......@@ -150,27 +170,39 @@ def code_gen(blocks):
def struct_gen(args, struct_builders, blocks, sub):
"""WRITEME
Generates a struct conforming to the following specifications:
* args -> all of the PyObject* type, stored in the struct
they represent the storage and must be length 1 python lists.
* struct_builders -> list of L{CodeBlock} instances such that
* declarations are in the struct
* behavior is in the constructor
* cleanup is in the destructor
* blocks -> list of CodeBlock instances such that
* declarations, behavior and cleanup are in the run()
method of the struct
* sub -> dictionary used to template the struct.
* failure_var -> must contain a variable name to use for
the failure code.
In a nutshell, this returns code for a struct that represents
a function with state. The state's initialization and destruction
are handled by struct_builders and the actual behavior of the
function is handled by blocks.
"""
WRITEME
Generates a struct conforming to the following specifications:
Parameters
----------
args
All of the PyObject* type, stored in the struct;
they represent the storage and must be length-1 Python lists.
struct_builders
List of L{CodeBlock} instances such that
* declarations are in the struct
* behavior is in the constructor
* cleanup is in the destructor
blocks
List of CodeBlock instances such that
* declarations, behavior and cleanup are in the run()
method of the struct
sub
Dictionary used to template the struct.
* failure_var -> must contain a variable name to use for
the failure code.
Returns
-------
object
In a nutshell, this returns code for a struct that represents
a function with state. The state's initialization and destruction
are handled by struct_builders and the actual behavior of the
function is handled by blocks.
"""
struct_decl = ""
struct_init_head = ""
struct_init_tail = ""
......@@ -276,12 +308,18 @@ def struct_gen(args, struct_builders, blocks, sub):
# with handling of the py_<name> variable.
def get_nothing(r, name, sub):
"""WRITEME"""
"""
WRITEME
"""
return ""
def get_c_declare(r, name, sub):
"""Wrapper around c_declare that declares py_name"""
"""
Wrapper around c_declare that declares py_name.
"""
# The declaration will be used by the Apply node that
# is computing it (`r.owner`), and by each of the clients.
# If some of these have `check_input=True` in their `.op`,
......@@ -302,7 +340,10 @@ def get_c_declare(r, name, sub):
def get_c_init(r, name, sub):
"""Wrapper around c_init that initializes py_name to Py_None"""
"""
Wrapper around c_init that initializes py_name to Py_None.
"""
pre = "" """
py_%(name)s = Py_None;
{Py_XINCREF(py_%(name)s);}
......@@ -311,7 +352,10 @@ def get_c_init(r, name, sub):
def get_c_extract(r, name, sub):
"""Wrapper around c_extract that initializes py_name from storage."""
"""
Wrapper around c_extract that initializes py_name from storage.
"""
# `c_extract` is called when getting the value of an apply node's
# input from the compute map, before being used by its clients.
# If one of the clients has `check_input=True`, we need to perform
......@@ -346,7 +390,10 @@ def get_c_extract(r, name, sub):
def get_c_extract_out(r, name, sub):
"""Wrapper around c_extract_out that initializes py_name from storage."""
"""
Wrapper around c_extract_out that initializes py_name from storage.
"""
# `c_extract_out` is used to extract an output variable from
# the compute map, to be used as pre-allocated memory for `r`
# before its value gets computed.
......@@ -376,7 +423,10 @@ def get_c_extract_out(r, name, sub):
def get_c_cleanup(r, name, sub):
"""Wrapper around c_cleanup that decrefs py_name"""
"""
Wrapper around c_cleanup that decrefs py_name.
"""
post = """
{Py_XDECREF(py_%(name)s);}
""" % locals()
......@@ -384,7 +434,10 @@ def get_c_cleanup(r, name, sub):
def get_c_sync(r, name, sub):
"""Wrapper around c_sync that syncs py_name with storage."""
"""
Wrapper around c_sync that syncs py_name with storage.
"""
return """
if (!%(failure_var)s) {
%(sync)s
......@@ -397,11 +450,21 @@ def get_c_sync(r, name, sub):
def apply_policy(policy, r, name, sub):
"""WRITEME
@param policy: list of functions that map a L{Variable} to a string,
or a single such function
@type r: L{Variable}
@return: C{policy[0](r) + policy[1](r) + ...}
"""
WRITEME
Parameters
----------
policy
List of functions that map a L{Variable} to a string,
or a single such function.
r: L{Variable}
Returns
-------
object
C{policy[0](r) + policy[1](r) + ...}.
"""
if isinstance(policy, (list, tuple)):
ret = ""
......@@ -412,22 +475,27 @@ def apply_policy(policy, r, name, sub):
def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
"""WRITEME
variable -> a Variable
policies -> a pair of tuples ((declare_policy, behavior_policy,
cleanup_policy), -- at construction
(declare_policy, behavior_policy,
cleanup_policy)) -- at execution
the first list will produce an element of the
'struct_builders' argument in struct_gen the second
list will produce an element of the 'blocks' argument
in struct_gen
id -> the id assigned to this variable's task in the computation
symbol_table -> a dict that maps variables to variable names. It
is not read by this function but a variable name for the
variable is computed and added to the table.
sub -> dictionary for use by L{CodeBlock}.
"""
WRITEME
Parameters
----------
variable : a Variable
policies : a pair of tuples
(declare_policy, behavior_policy, cleanup_policy) -- at construction.
(declare_policy, behavior_policy, cleanup_policy)) -- at execution.
The first list will produce an element of the 'struct_builders' argument
in struct_gen. The second list will produce an element of the 'blocks'
argument in struct_gen.
id
The id assigned to this variable's task in the computation.
symbol_table
A dict that maps variables to variable names. It is not read by this
function but a variable name for the variable is computed and added to
the table.
sub
Dictionary for use by L{CodeBlock}.
"""
name = "V%i" % id
......@@ -453,7 +521,8 @@ def struct_variable_codeblocks(variable, policies, id, symbol_table, sub):
class CLinker(link.Linker):
"""WRITEME
"""
WRITEME
Creates C code for an fgraph, compiles it and returns callables
through make_thunk and make_function that make use of the compiled
......@@ -462,6 +531,7 @@ class CLinker(link.Linker):
no_recycling can contain a list of Variables that belong to the fgraph.
If a Variable is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
def __init__(self, schedule=None):
......@@ -470,7 +540,10 @@ class CLinker(link.Linker):
self.schedule = schedule
def accept(self, fgraph, no_recycling=None):
"""WRITEME"""
"""
WRITEME
"""
if no_recycling is None:
no_recycling = []
if self.fgraph is not None and self.fgraph is not fgraph:
......@@ -483,9 +556,12 @@ class CLinker(link.Linker):
return self
def fetch_variables(self):
"""WRITEME
Fills the inputs, outputs, variables, orphans,
temps and node_order fields.
"""
WRITEME
Fills the inputs, outputs, variables, orphans, temps and node_order
fields.
"""
fgraph = self.fgraph
self.inputs = fgraph.inputs
......@@ -527,7 +603,9 @@ class CLinker(link.Linker):
self.consts = []
def code_gen(self):
"""WRITEME
"""
WRITEME
Generates code for a struct that does the computation of the fgraph and
stores it in the struct_code field of the instance.
......@@ -538,6 +616,7 @@ class CLinker(link.Linker):
is avoided.
This method caches its computations.
"""
if getattr(self, 'struct_code', False):
......@@ -804,12 +883,15 @@ class CLinker(link.Linker):
return self.struct_code
def support_code(self):
"""WRITEME
"""
WRITEME
Returns a list of support code strings that are needed by
one or more Variables or Ops. The support code from Variables is
added before the support code from Ops.
This might contain duplicates.
"""
ret = []
# generic support code
......@@ -822,11 +904,14 @@ class CLinker(link.Linker):
return ret
def compile_args(self):
"""WRITEME
"""
WRITEME
Returns a list of compile args that are needed by one
or more Variables or Ops.
This might contain duplicates.
"""
ret = ["-O3"]
# this is the param the -ffast-math activate. I put the explicitly as
......@@ -871,11 +956,14 @@ class CLinker(link.Linker):
return ret
def headers(self):
"""WRITEME
"""
WRITEME
Returns a list of headers that are needed by one
or more Types or Ops.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
......@@ -890,7 +978,9 @@ class CLinker(link.Linker):
"""
Return a list of code snippets that have to be inserted
in the module initialization code.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
......@@ -923,11 +1013,14 @@ class CLinker(link.Linker):
return c_compiler
def header_dirs(self):
"""WRITEME
"""
WRITEME
Returns a list of lib directories that are needed by one
or more Types or Ops.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
......@@ -939,11 +1032,14 @@ class CLinker(link.Linker):
return utils.uniq(ret)
def libraries(self):
"""WRITEME
"""
WRITEME
Returns a list of libraries that are needed by one
or more Types or Ops.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
......@@ -955,11 +1051,14 @@ class CLinker(link.Linker):
return utils.uniq(ret)
def lib_dirs(self):
"""WRITEME
"""
WRITEME
Returns a list of lib directories that are needed by one
or more Types or Ops.
The return value will not contain duplicates.
"""
ret = []
for x in [y.type for y in self.variables] + [
......@@ -972,18 +1071,26 @@ class CLinker(link.Linker):
def __compile__(self, input_storage=None,
output_storage=None, keep_lock=False):
"""WRITEME
"""
WRITEME
Compiles this linker's fgraph.
@type input_storage: list or None
@param input_storage: list of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the variables of the computation in these
lists. If None, storage will be allocated.
Parameters
----------
input_storage: list or None
List of lists of length 1. In order to use the thunk returned
by __compile__, the inputs must be put in that storage.
If None, storage will be allocated.
output_storage: list of lists of length 1
The thunk returned by __compile__ will put the variables of the
computation in these lists. If None, storage will be allocated.
Returns
-------
object
Thunk, input_storage, output_storage, error_storage.
Returns: thunk, input_storage, output_storage, error_storage
"""
error_storage = [None, None, None]
if input_storage is None:
......@@ -1037,27 +1144,34 @@ class CLinker(link.Linker):
def make_thunk(self, input_storage=None, output_storage=None,
keep_lock=False):
"""WRITEME
"""
WRITEME
Compiles this linker's fgraph and returns a function to perform the
computations, as well as lists of storage cells for both the
inputs and outputs.
computations, as well as lists of storage cells for both the inputs
and outputs.
@type input_storage: list or None
@param input_storage: list of lists of length 1. In order to use
Parameters
----------
input_storage: list or None
List of lists of length 1. In order to use
the thunk returned by __compile__, the inputs must be put in
that storage. If None, storage will be allocated.
@param output_storage: list of lists of length 1. The thunk returned
by __compile__ will put the variables of the computation in these
lists. If None, storage will be allocated.
Returns: thunk, input_storage, output_storage
The return values can be used as follows:
f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input
istor[1].data = second_input
f()
first_output = ostor[0].data
output_storage: list of lists of length 1
The thunk returned by __compile__ will put the variables of the
computation in these lists. If None, storage will be allocated.
Returns
-------
object
Thunk, input_storage, output_storage.
The return values can be used as follows:
f, istor, ostor = clinker.make_thunk()
istor[0].data = first_input
istor[1].data = second_input
f()
first_output = ostor[0].data
"""
init_tasks, tasks = self.get_init_tasks()
cthunk, in_storage, out_storage, error_storage = self.__compile__(
......@@ -1069,7 +1183,8 @@ class CLinker(link.Linker):
return res, in_storage, out_storage
def cmodule_key(self):
"""Return a complete hashable signature of the module we compiled.
"""
Return a complete hashable signature of the module we compiled.
This function must have the property that no two programs that
compute different things yield the same key.
......@@ -1090,8 +1205,8 @@ class CLinker(link.Linker):
The outer tuple has a brief header, containing the compilation options
passed to the compiler, the libraries to link against, an md5 hash
of theano.config (for all config options where "in_c_key" is True).
It is followed by elements for every node in the
topological ordering of `self.fgraph`.
It is followed by elements for every node in the topological ordering
of `self.fgraph`.
If the Op of any Apply in the FunctionGraph does not have
c_code_cache_ok()==True, then this function raises a KeyError
......@@ -1101,7 +1216,7 @@ class CLinker(link.Linker):
---------------
Each input signature is a tuple with an element for each input
to the corresponding Apply node. Each element identifies the
to the corresponding Apply node. Each element identifies the
type of the node input, and the nature of that input in the
graph.
......@@ -1116,7 +1231,6 @@ class CLinker(link.Linker):
If a variable is also a graph output, then its position in the
outputs list is also bundled with this tuple (after the b).
The nature of a Constant instance is defined as its signature,
together with two integers: the topological position of the
first Apply using that Constant instance, and the lowest index
......@@ -1141,6 +1255,7 @@ class CLinker(link.Linker):
booleans, indicating whether each output is in the
no_recycling set. Older versions of compiled modules only have the
no_recycle list.
"""
return self.cmodule_key_(self.fgraph, self.no_recycling,
compile_args=self.compile_args(),
......@@ -1154,7 +1269,8 @@ class CLinker(link.Linker):
c_compiler=None):
"""
Do the actual computation of cmodule_key in a static method
to allow it to be reused in scalar.Composite.__eq__
to allow it to be reused in scalar.Composite.__eq__.
"""
if compile_args is None:
compile_args = []
......@@ -1311,6 +1427,7 @@ class CLinker(link.Linker):
"""
This compiles the source code for this linker and returns a
loaded module.
"""
if location is None:
location = cmodule.dlimport_workdir(config.compiledir)
......@@ -1353,11 +1470,12 @@ class CLinker(link.Linker):
return module
def get_dynamic_module(self):
"""Return a cmodule.DynamicModule instance full of the code
for our fgraph.
"""
Return a cmodule.DynamicModule instance full of the code for our fgraph.
This method is cached on the first call so it can be called
multiple times without penalty.
"""
if not hasattr(self, '_mod'):
self.code_gen()
......@@ -1412,16 +1530,24 @@ class CLinker(link.Linker):
def cthunk_factory(self, error_storage, in_storage, out_storage,
keep_lock=False):
"""WRITEME
error_storage -> list of length 3
in_storage -> list of lists of length 1, one per input
out_storage -> list of lists of length 1, one per output
Returns a thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph. That thunk,
when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage.
"""
WRITEME
Parameters
----------
error_storage : list of length 3
in_storage : list of lists of length 1, one per input
out_storage : list of lists of length 1, one per output
Returns
-------
object
A thunk that points to an instance of a C struct that
can carry on the computation of this linker's fgraph. That thunk,
when executed, will fetch its inputs from in_storage, put its
outputs in out_storage and if an error occurs will put the
type, value and traceback of the exception in error_storage.
"""
try:
key = self.cmodule_key()
......@@ -1481,18 +1607,22 @@ class CLinker(link.Linker):
class _CThunk(object):
"""
A thunk with a C implementation
A thunk with a C implementation.
Parameters
----------
cthunk
The CObject pointer used by run_cthunk.
init_tasks
WRITEME
tasks
WRITEME
error_storage
WRITEME
"""
def __init__(self, cthunk, init_tasks, tasks, error_storage):
"""
Parameters
----------
cthunk: the CObject pointer used by run_cthunk
init_tasks: WRITEME
tasks: WRITEME
error_storage: WRITEME
"""
global run_cthunk
if run_cthunk is None:
# Lazy import to avoid compilation when importing theano.
......@@ -1505,6 +1635,7 @@ class _CThunk(object):
def find_task(self, failure_code):
"""
Maps a failure code to the task that is associated to it.
"""
failure_code -= 1
n = len(self.init_tasks)
......@@ -1540,7 +1671,9 @@ class _CThunk(object):
class OpWiseCLinker(link.LocalLinker):
"""WRITEME
"""
WRITEME
Uses CLinker on the individual Ops that comprise an fgraph and loops
over them in Python. The variable is slower than a compiled version of
the whole fgraph, but saves on compilation time because small changes
......@@ -1554,10 +1687,12 @@ class OpWiseCLinker(link.LocalLinker):
If a Variable is in no_recycling, CLinker will clear the output storage
associated to it prior to computation (to avoid reusing it).
:note: This is in a sense the 'default' linker for Theano. The
Notes
-----
This is in a sense the 'default' linker for Theano. The
overhead of using the OpWiseCLinker as compared with the CLinker
is only noticeable for graphs of very small tensors (such as 20
elements or less)
elements or less).
"""
......@@ -1676,9 +1811,12 @@ class OpWiseCLinker(link.LocalLinker):
def _default_checker(x, y):
"""WRITEME
"""
WRITEME
Default checker for DualLinker. This checks that the
variables contain the same data using ==.
"""
if x[0] != y[0]:
raise Exception("Output mismatch.",
......@@ -1686,7 +1824,9 @@ def _default_checker(x, y):
class DualLinker(link.Linker):
"""WRITEME
"""
WRITEME
Runs the fgraph in parallel using PerformLinker and CLinker.
The thunk/function produced by DualLinker uses PerformLinker as the
......@@ -1695,6 +1835,7 @@ class DualLinker(link.Linker):
the fgraph on which it runs OpWiseCLinker. At each step, the variables
of perform and of the C implementation are verified using a checker
function.
"""
def __init__(self, checker=_default_checker, schedule=None):
......@@ -1719,6 +1860,7 @@ class DualLinker(link.Linker):
no_recycling can contain a list of Variables that belong to the fgraph.
If a Variable is in no_recycling, CLinker will clear the output storage
associated to it during the computation (to avoid reusing it).
"""
self.fgraph = None
self.checker = checker
......
"""Generate and compile C modules for Python,
"""
Generate and compile C modules for Python.
"""
from __future__ import print_function
import atexit
......@@ -81,16 +82,23 @@ import_time = 0
class MissingGXX(Exception):
"""
This error is raised when we try to generate c code,
but g++ is not available
but g++ is not available.
"""
pass
def debug_counter(name, every=1):
"""Debug counter to know how often we go through some piece of code.
"""
Debug counter to know how often we go through some piece of code.
This is a utility function one may use when debugging.
Example
-------
debug_counter('I want to know how often I run this line')
This is a utility function one may use when debugging. Usage example:
debug_counter('I want to know how often I run this line')
"""
setattr(debug_counter, name, getattr(debug_counter, name, 0) + 1)
n = getattr(debug_counter, name)
......@@ -99,27 +107,36 @@ def debug_counter(name, every=1):
class ExtFunction(object):
"""A C function to put into a DynamicModule """
"""
A C function to put into a DynamicModule.
"""
name = ""
"""string - function's name"""
"""
str - function's name.
"""
code_block = ""
"""string - the entire code for the function.
"""
str - the entire code for the function.
Has the form ``static PyObject* <name>([...]){ ... }``.
See Python's C API Reference for how to write c functions for python
modules.
"""
method = ""
"""
str - calling method for this function (i.e. 'METH_VARARGS', 'METH_NOARGS')
method = ""
"""
str - calling method for this function (i.e. 'METH_VARARGS', 'METH_NOARGS').
"""
doc = ""
"""str - documentation string for this function"""
"""
str - documentation string for this function.
"""
def __init__(self, name, code_block, method, doc="undocumented"):
self.name = name
......@@ -132,6 +149,7 @@ class ExtFunction(object):
Returns the signature for this function.
It goes into the DynamicModule's method table.
"""
return '\t{"%s", %s, %s, "%s"}' % (
self.name, self.name, self.method, self.doc)
......@@ -244,7 +262,10 @@ static struct PyModuleDef moduledef = {{
return rval
def list_code(self, ofile=sys.stdout):
"""Print out the code with line numbers to `ofile` """
"""
Print out the code with line numbers to `ofile`.
"""
for i, line in enumerate(self.code().split('\n')):
print(('%4i' % (i + 1)), line, file=ofile)
ofile.flush()
......@@ -253,15 +274,21 @@ static struct PyModuleDef moduledef = {{
def dlimport(fullpath, suffix=None):
"""Dynamically load a .so, .pyd, .dll, or .py file
"""
Dynamically load a .so, .pyd, .dll, or .py file.
:type fullpath: string
:param fullpath: a fully-qualified path to a compiled python module
:param suffix: a suffix to strip from the end of fullpath to get the
import name
:type suffix: string
Parameters
----------
fullpath : str
A fully-qualified path to a compiled Python module.
suffix : str
A suffix to strip from the end of fullpath to get the
import name.
:returns: the dynamically loaded module (from __import__)
Returns
-------
object
The dynamically loaded module (from __import__).
"""
if not os.path.isabs(fullpath):
......@@ -310,7 +337,8 @@ def dlimport_workdir(basedir):
"""
Return a directory where you should put your .so file for dlimport
to be able to load it, given a basedir which should normally be
config.compiledir
config.compiledir.
"""
return tempfile.mkdtemp(dir=basedir)
......@@ -319,6 +347,7 @@ def last_access_time(path):
"""
Return the number of seconds since the epoch of the last access of a
given file.
"""
return os.stat(path)[stat.ST_ATIME]
......@@ -327,6 +356,7 @@ def module_name_from_dir(dirname, err=True, files=None):
"""
Scan the contents of a cache directory and return full path of the
dynamic lib in it.
"""
if files is None:
files = os.listdir(dirname)
......@@ -349,6 +379,7 @@ def is_same_entry(entry_1, entry_2):
- They are equal.
- Their real paths are equal.
- They share the same temporary work directory and module file name.
"""
if entry_1 == entry_2:
return True
......@@ -372,6 +403,7 @@ def get_module_hash(src_code, key):
3. The compiler options defined in `key` (command line parameters and
libraries to link against).
4. The NumPy ABI version.
"""
# `to_hash` will contain any element such that we know for sure that if
# it changes, then the module hash should be different.
......@@ -425,6 +457,7 @@ def get_safe_part(key):
It is used to reduce the amount of key comparisons one has to go through
in order to find broken keys (i.e. keys with bad implementations of __eq__
or __hash__).
"""
version = key[0]
# This function should only be called on versioned keys.
......@@ -442,35 +475,43 @@ def get_safe_part(key):
class KeyData(object):
"""
Used to store the key information in the cache.
Parameters
----------
keys
Set of keys that are associated to the exact same module.
module_hash
Hash identifying the module (it should hash both the code and the
compilation options).
key_pkl
Path to the file in which this KeyData object should be
pickled.
"""Used to store the key information in the cache."""
"""
def __init__(self, keys, module_hash, key_pkl, entry):
"""
Constructor.
:param keys: Set of keys that are associated to the exact same module.
:param module_hash: Hash identifying the module (it should hash both
the code and the compilation options).
:param key_pkl: Path to the file in which this KeyData object should be
pickled.
"""
self.keys = keys
self.module_hash = module_hash
self.key_pkl = key_pkl
self.entry = entry
def add_key(self, key, save_pkl=True):
"""Add a key to self.keys, and update pickled file if asked to."""
"""
Add a key to self.keys, and update pickled file if asked to.
"""
assert key not in self.keys
self.keys.add(key)
if save_pkl:
self.save_pkl()
def remove_key(self, key, save_pkl=True):
"""Remove a key from self.keys, and update pickled file if asked to."""
"""
Remove a key from self.keys, and update pickled file if asked to.
"""
self.keys.remove(key)
if save_pkl:
self.save_pkl()
......@@ -481,6 +522,7 @@ class KeyData(object):
May raise a cPickle.PicklingError if such an exception is raised at
pickle time (in which case a warning is also displayed).
"""
# Note that writing in binary mode is important under Windows.
try:
......@@ -493,7 +535,10 @@ class KeyData(object):
raise
def get_entry(self):
"""Return path to the module file."""
"""
Return path to the module file.
"""
# TODO This method may be removed in the future (e.g. in 0.5) since
# its only purpose is to make sure that old KeyData objects created
# before the 'entry' field was added are properly handled.
......@@ -508,6 +553,7 @@ class KeyData(object):
Note that broken keys will not appear in the keys field, so we also
manually look for keys associated to the same entry, unless
do_manual_check is False.
"""
entry = self.get_entry()
for key in self.keys:
......@@ -522,7 +568,8 @@ class KeyData(object):
class ModuleCache(object):
"""Interface to the cache of dynamically compiled modules on disk
"""
Interface to the cache of dynamically compiled modules on disk.
Note that this interface does not assume exclusive use of the cache
directory. It is built to handle the case where multiple programs are also
......@@ -569,43 +616,58 @@ class ModuleCache(object):
- They share the same C code.
These three elements uniquely identify a module, and are summarized
in a single "module hash".
Parameters
----------
check_for_broken_eq
A bad __eq__ implementation can break this cache mechanism.
This option turns on a not-too-expensive sanity check every
time a new key is added to the cache.
do_refresh : bool
If True, then the ``refresh`` method will be called
in the constructor.
"""
dirname = ""
"""The working directory that is managed by this interface"""
"""
The working directory that is managed by this interface.
"""
module_from_name = {}
"""maps a module filename to the loaded module object"""
"""
Maps a module filename to the loaded module object.
"""
entry_from_key = {}
"""Maps keys to the filename of a .so/.pyd.
"""
Maps keys to the filename of a .so/.pyd.
"""
similar_keys = {}
"""Maps a part-of-key to all keys that share this same part."""
"""
Maps a part-of-key to all keys that share this same part.
"""
module_hash_to_key_data = {}
"""Maps a module hash to its corresponding KeyData object."""
"""
Maps a module hash to its corresponding KeyData object.
"""
stats = []
"""
A list with counters for the number of hits, loads, compiles issued by
module_from_key()
"""
module_from_key().
"""
loaded_key_pkl = set()
"""set of all key.pkl files that have been loaded.
"""
Set of all key.pkl files that have been loaded.
def __init__(self, dirname, check_for_broken_eq=True, do_refresh=True):
"""
:param check_for_broken_eq: A bad __eq__ implementation can break this
cache mechanism. This option turns on a not-too-expensive sanity check
every time a new key is added to the cache.
"""
:param do_refresh: If True, then the ``refresh`` method will be called
in the constructor.
"""
def __init__(self, dirname, check_for_broken_eq=True, do_refresh=True):
self.dirname = dirname
self.module_from_name = dict(self.module_from_name)
self.entry_from_key = dict(self.entry_from_key)
......@@ -624,11 +686,13 @@ class ModuleCache(object):
The default age threshold (in seconds) for cache files we want to use.
Older modules will be deleted in ``clear_old``.
"""
def _get_module(self, name):
"""
Fetch a compiled module from the loaded cache or the disk.
"""
if name not in self.module_from_name:
_logger.debug('loading name %s', name)
......@@ -641,25 +705,31 @@ class ModuleCache(object):
def refresh(self, age_thresh_use=None, delete_if_problem=False,
cleanup=True):
"""Update cache data by walking the cache directory structure.
"""
Update cache data by walking the cache directory structure.
Load key.pkl files that have not been loaded yet.
Remove entries which have been removed from the filesystem.
Also, remove malformed cache directories.
:param age_thresh_use: Do not use modules olther than this.
Defaults to self.age_thresh_use.
:param delete_if_problem: If True, cache entries that meet one
of those two conditions are deleted:
Parameters
----------
age_thresh_use
Do not use modules other than this. Defaults to self.age_thresh_use.
delete_if_problem : bool
If True, cache entries that meet one of those two conditions are
deleted:
- Those for which unpickling the KeyData file fails with
an unknown exception.
- Duplicated modules, regardless of their age.
cleanup : bool
Do a cleanup of the cache removing expired and broken modules.
:param cleanup: Do a cleanup of the cache removing expired and
broken modules.
Returns
-------
list
A list of modules of age higher than age_thresh_use.
:returns: a list of modules of age higher than age_thresh_use.
"""
if age_thresh_use is None:
age_thresh_use = self.age_thresh_use
......@@ -935,6 +1005,7 @@ class ModuleCache(object):
and None otherwise.
May raise ValueError if the key is malformed.
"""
name = None
if key is not None:
......@@ -993,6 +1064,7 @@ class ModuleCache(object):
def _add_to_cache(self, module, key, module_hash):
"""
This function expects the compile lock to be held.
"""
name = module.__file__
_logger.debug("Adding module to cache %s %s",
......@@ -1036,18 +1108,19 @@ class ModuleCache(object):
"""
Return a module from the cache, compiling it if necessary.
:param key: The key object associated with the module. If this
hits a match, we avoid compilation.
:param lnk: Usually a CLinker instance, but it can be any
object that defines the `get_src_code()` and
`compile_cmodule(location)` functions. The first
one returns the source code of the module to
load/compile and the second performs the actual
compilation.
Parameters
----------
key
The key object associated with the module. If this hits a match,
we avoid compilation.
lnk
Usually a CLinker instance, but it can be any object that defines
the `get_src_code()` and `compile_cmodule(location)` functions. The
first one returns the source code of the module to load/compile and
the second performs the actual compilation.
keep_lock : bool
If True, the compilation lock will not be released if taken.
:param keep_lock: If True, the compilation lock will not be
released if taken.
"""
# Is the module in the cache?
module = self._get_from_key(key)
......@@ -1123,8 +1196,13 @@ class ModuleCache(object):
"""
Perform checks to detect broken __eq__ / __hash__ implementations.
:param key: The key to be checked.
:param key_pkl: Its associated pickled file containing a KeyData.
Parameters
----------
key
The key to be checked.
key_pkl
Its associated pickled file containing a KeyData.
"""
start_time = time.time()
# Verify that when we reload the KeyData from the pickled file, the
......@@ -1177,18 +1255,24 @@ class ModuleCache(object):
age_thresh_del = 60 * 60 * 24 * 31 # 31 days
age_thresh_del_unversioned = 60 * 60 * 24 * 7 # 7 days
"""
The default age threshold for `clear_old` (in seconds).
"""The default age threshold for `clear_old` (in seconds)
"""
def clear_old(self, age_thresh_del=None, delete_if_problem=False):
"""
Delete entries from the filesystem for cache entries that are too old.
:param age_thresh_del: Dynamic modules whose last access time is more
than ``age_thresh_del`` seconds ago will be erased. Defaults to 31-day
age if not provided.
Parameters
----------
age_thresh_del
Dynamic modules whose last access time is more than
``age_thresh_del`` seconds ago will be erased.
Defaults to 31-day age if not provided.
delete_if_problem
See help of refresh() method.
:param delete_if_problem: See help of refresh() method.
"""
if age_thresh_del is None:
age_thresh_del = self.age_thresh_del
......@@ -1232,16 +1316,19 @@ class ModuleCache(object):
"""
Clear all elements in the cache.
:param unversioned_min_age: Forwarded to `clear_unversioned`. In
particular, you can set it to -1 in order to delete all unversioned
cached modules regardless of their age.
Parameters
----------
unversioned_min_age
Forwarded to `clear_unversioned`. In particular, you can set it
to -1 in order to delete all unversioned cached modules regardless
of their age.
clear_base_files : bool
If True, then delete base directories 'cuda_ndarray', 'cutils_ext',
'lazylinker_ext' and 'scan_perform' if they are present.
If False, those directories are left intact.
delete_if_problem
See help of refresh() method.
:param clear_base_files: If True, then delete base directories
'cuda_ndarray', 'cutils_ext', 'lazylinker_ext' and 'scan_perform'
if they are present.
If False, those directories are left intact.
:param delete_if_problem: See help of refresh() method.
"""
with compilelock.lock_ctx():
self.clear_old(
......@@ -1260,6 +1347,7 @@ class ModuleCache(object):
some systems due to these modules being currently in use. Instead we
rename them with the '.delete.me' extension, to mark them to be deleted
next time we clear the cache.
"""
with compilelock.lock_ctx():
for base_dir in ('cuda_ndarray', 'cutils_ext', 'lazylinker_ext',
......@@ -1287,8 +1375,12 @@ class ModuleCache(object):
They are deleted both from the internal dictionaries and from the
filesystem.
:param min_age: Minimum age to be deleted, in seconds. Defaults to
7-day age if not provided.
Parameters
----------
min_age
Minimum age to be deleted, in seconds. Defaults to
7-day age if not provided.
"""
if min_age is None:
min_age = self.age_thresh_del_unversioned
......@@ -1409,8 +1501,13 @@ _module_cache = None
def get_module_cache(dirname, init_args=None):
"""
:param init_args: If not None, the (k, v) pairs in this dictionary will
be forwarded to the ModuleCache constructor as keyword arguments.
Parameters
----------
init_args
If not None, the (k, v) pairs in this dictionary will be forwarded to
the ModuleCache constructor as keyword arguments.
"""
global _module_cache
if init_args is None:
......@@ -1429,7 +1526,10 @@ def get_module_cache(dirname, init_args=None):
def get_lib_extension():
"""Return the platform-dependent extension for compiled modules."""
"""
Return the platform-dependent extension for compiled modules.
"""
if sys.platform in ['win32', 'cygwin']:
return 'pyd'
else:
......@@ -1437,7 +1537,10 @@ def get_lib_extension():
def get_gcc_shared_library_arg():
"""Return the platform-dependent GCC argument for shared libraries."""
"""
Return the platform-dependent GCC argument for shared libraries.
"""
if sys.platform == 'darwin':
return '-dynamiclib'
else:
......@@ -1534,9 +1637,11 @@ def gcc_version():
def gcc_llvm():
""" Detect if the g++ version used is the llvm one or not.
"""
Detect if the g++ version used is the llvm one or not.
It don't support all g++ parameters even if it support many of them.
"""
if gcc_llvm.is_llvm is None:
try:
......@@ -1558,12 +1663,15 @@ gcc_llvm.is_llvm = None
class Compiler(object):
"""
Meta compiler that offer some generic function
Meta compiler that offer some generic function.
"""
@staticmethod
def _try_compile_tmp(src_code, tmp_prefix='', flags=(),
try_run=False, output=False, compiler=None):
"""Try to compile (and run) a test program.
"""
Try to compile (and run) a test program.
This is useful in various occasions, to check if libraries
or compilers are behaving as expected.
......@@ -1574,6 +1682,7 @@ class Compiler(object):
If try_run is False, returns the compilation status.
If try_run is True, returns a (compile_status, run_status) pair.
If output is there, we append the stdout and stderr to the output.
"""
if not compiler:
return False
......@@ -1631,12 +1740,13 @@ class Compiler(object):
@staticmethod
def _try_flags(flag_list, preambule="", body="",
try_run=False, output=False, compiler=None):
'''
"""
Try to compile a dummy file with these flags.
Returns True if compilation was successful, False if there
were errors.
'''
"""
if not compiler:
return False
......@@ -1933,33 +2043,38 @@ class GCC_compiler(Compiler):
include_dirs=None, lib_dirs=None, libs=None,
preargs=None, py_module=True, hide_symbols=True):
"""
:param module_name: string (this has been embedded in the src_code
:param src_code: a complete c or c++ source listing for the module
:param location: a pre-existing filesystem directory where the
cpp file and .so will be written
:param include_dirs: a list of include directory names (each
gets prefixed with -I)
:param lib_dirs: a list of library search path directory names
(each gets prefixed with -L)
:param libs: a list of libraries to link with (each gets
prefixed with -l)
:param preargs: a list of extra compiler arguments
:param py_module: if False, compile to a shared library, but do not
import it as a Python module.
:param hide_symbols: if True (the default) all symbols will be
hidden from the library symbol table (which means that other
objects can't use them.
Parameters
----------
module_name : str
This has been embedded in the src_code.
src_code
A complete c or c++ source listing for the module.
location
A pre-existing filesystem directory where the cpp file and .so will
be written.
include_dirs
A list of include directory names (each gets prefixed with -I).
lib_dirs
A list of library search path directory names (each gets prefixed
with -L).
libs
A list of libraries to link with (each gets prefixed with -l).
preargs
A list of extra compiler arguments.
py_module
If False, compile to a shared library, but do not import it as a
Python module.
hide_symbols
If True (the default) all symbols will be hidden from the library
symbol table (which means that other objects can't use them).
Returns
-------
object
Dynamically-imported python module of the compiled code (unless
py_module is False, in that case returns None).
:returns: dynamically-imported python module of the compiled code.
(unless py_module is False, in that case returns None.)
"""
# TODO: Do not do the dlimport in this function
......
......@@ -32,10 +32,11 @@ except OSError:
def local_bitwidth():
"""
Return 32 for 32bit arch, 64 for 64bit arch
Return 32 for 32bit arch, 64 for 64bit arch.
By "architecture", we mean the size of memory pointers (size_t in C),
*not* the size of long int, as it can be different.
"""
# Note that according to Python documentation, `platform.architecture()` is
# not reliable on OS X with universal binaries.
......@@ -49,6 +50,7 @@ def python_int_bitwidth():
Return the bit width of Python int (C long int).
Note that it can be different from the size of a memory pointer.
"""
# 'l' denotes a C long int, and the size is expressed in bytes.
return struct.calcsize('l') * 8
......@@ -67,7 +69,8 @@ compiledir_format_dict = {
def short_platform(r=None, p=None):
"""Return a safe shorter version of platform.platform().
"""
Return a safe shorter version of platform.platform().
The old default Theano compiledir used platform.platform in
it. This use the platform.version() as a substring. This is too
......@@ -103,13 +106,11 @@ def short_platform(r=None, p=None):
compiledir_Linux-2.6.32-220.7.1.el6.x86_64-x86_64-with-redhat-6.2-Santiago-x86_64-2.6.6
compiledir_Linux-2.6.32-220.4.1.el6.x86_64-x86_64-with-redhat-6.2-Santiago-x86_64-2.6.6
We suppose the version are ``X.Y[.*]-(digit)*(anything)*``. We
keep ``X.Y`` and don't keep less important digit in the part
before ``-`` and we remove the leading digit after the first
``-``.
We suppose the version are ``X.Y[.*]-(digit)*(anything)*``. We keep ``X.Y``
and don't keep less important digit in the part before ``-`` and we remove
the leading digit after the first ``-``.
If the information don't fit that pattern, we do not modify
platform.
If the information don't fit that pattern, we do not modify platform.
"""
if r is None:
......@@ -214,6 +215,7 @@ def filter_compiledir(path):
def get_home_dir():
"""
Return location of the user's home directory.
"""
home = os.getenv('HOME')
if home is None:
......@@ -269,6 +271,7 @@ def cleanup():
3) They do not have a compile version string
If there is no key left for a compiled module, we delete the module.
"""
compiledir = theano.config.compiledir
for directory in os.listdir(compiledir):
......
......@@ -47,6 +47,7 @@ hostname = socket.gethostname()
def force_unlock():
"""
Delete the compilation lock if someone else has it.
"""
get_lock(min_wait=0, max_wait=0.001, timeout=0)
release_lock()
......@@ -67,10 +68,16 @@ def _get_lock(lock_dir=None, **kw):
"""
Obtain lock on compilation directory.
:param kw: Additional arguments to be forwarded to the `lock` function when
acquiring the lock.
Parameters
----------
kw
Additional arguments to be forwarded to the `lock` function when
acquiring the lock.
Notes
-----
We can lock only on 1 directory at a time.
:note: We can lock only on 1 directory at a time.
"""
if lock_dir is None:
lock_dir = os.path.join(config.compiledir, 'lock_dir')
......@@ -125,6 +132,7 @@ get_lock = _get_lock
def release_lock():
"""
Release lock on compilation directory.
"""
get_lock.n_lock -= 1
assert get_lock.n_lock >= 0
......@@ -140,8 +148,11 @@ def set_lock_status(use_lock):
by default). Disabling may make compilation slightly faster (but is not
recommended for parallel execution).
:param use_lock: whether to use the compilation lock or not
:type use_lock: bool
Parameters
----------
use_lock : bool
Whether to use the compilation lock or not.
"""
get_lock.lock_is_enabled = use_lock
......@@ -169,22 +180,22 @@ def lock(tmp_dir, timeout=notset, min_wait=None, max_wait=None, verbosity=1):
displayed each time we re-check for the presence of the lock. Otherwise it
is displayed only when we notice the lock's owner has changed.
:param str tmp_dir: lock directory that will be created when
acquiring the lock
:param timeout: time (in seconds) to wait before replacing an
existing lock (default config 'compile.timeout')
:type timeout: int or None
:param int min_wait: minimum time (in seconds) to wait before
trying again to get the lock
(default config 'compile.wait')
Parameters
----------
tmp_dir : str
Lock directory that will be created when acquiring the lock.
timeout : int or None
Time (in seconds) to wait before replacing an existing lock (default
config 'compile.timeout').
min_wait: int
Minimum time (in seconds) to wait before trying again to get the lock
(default config 'compile.wait').
max_wait: int
Maximum time (in seconds) to wait before trying again to get the lock
(default 2 * min_wait).
verbosity : int
Amount of feedback displayed to screen (default 1).
:param int max_wait: maximum time (in seconds) to wait before
trying again to get the lock
(default 2 * min_wait)
:param int verbosity: amount of feedback displayed to screen (default 1)
"""
if min_wait is None:
min_wait = config.compile.wait
......@@ -321,6 +332,7 @@ def refresh_lock(lock_file):
"""
'Refresh' an existing lock by re-writing the file containing the owner's
unique id, using a new (randomly generated) id, which is also returned.
"""
unique_id = '%s_%s_%s' % (
os.getpid(),
......@@ -348,19 +360,22 @@ class Unlocker(object):
Class wrapper around release mechanism so that the lock is automatically
released when the program exits (even when crashing or being interrupted),
using the __del__ class method.
"""
def __init__(self, tmp_dir):
self.tmp_dir = tmp_dir
def unlock(self, force=False):
"""Remove current lock.
"""
Remove current lock.
This function does not crash if it is unable to properly
delete the lock file and directory. The reason is that it
should be allowed for multiple jobs running in parallel to
unlock the same directory at the same time (e.g. when reaching
their timeout limit).
"""
# If any error occurs, we assume this is because someone else tried to
# unlock this directory at the same time.
......
......@@ -194,7 +194,10 @@ fail:
def compile_cutils():
"""Do just the compilation of cutils_ext"""
"""
Do just the compilation of cutils_ext.
"""
code = ("""
#include <Python.h>
#include "numpy/arrayobject.h"
......
"""
Classes and functions for validating graphs that contain view
and inplace operations.
"""
from collections import deque
......@@ -17,35 +18,41 @@ from six.moves.queue import Queue
class ProtocolError(Exception):
"""Raised when FunctionGraph calls DestroyHandler callbacks in
"""
Raised when FunctionGraph calls DestroyHandler callbacks in
an invalid way, for example, pruning or changing a node that has
never been imported.
"""
pass
def _contains_cycle(fgraph, orderings):
"""
fgraph - the FunctionGraph to check for cycles
Parameters
----------
fgraph
The FunctionGraph to check for cycles.
orderings
Dictionary specifying extra dependencies besides those encoded in
Variable.owner / Apply.inputs.
orderings - dictionary specifying extra dependencies besides
those encoded in Variable.owner / Apply.inputs
If orderings[my_apply] == dependencies, then my_apply is an Apply
instance, dependencies is a set of Apply instances, and every member
of dependencies must be executed before my_apply.
If orderings[my_apply] == dependencies,
The dependencies are typically used to prevent
inplace apply nodes from destroying their input before
other apply nodes with the same input access it.
then my_apply is an Apply instance,
dependencies is a set of Apply instances,
and every member of dependencies must be executed
before my_apply.
Returns
-------
bool
True if the graph contains a cycle, False otherwise.
The dependencies are typically used to prevent
inplace apply nodes from destroying their input before
other apply nodes with the same input access it.
Returns True if the graph contains a cycle, False otherwise.
"""
# These are lists of Variable instances
outputs = fgraph.outputs
......@@ -227,10 +234,15 @@ def _build_droot_impact(destroy_handler):
def fast_inplace_check(inputs):
""" Return the variables in inputs that are posible candidate for as inputs of inplace operation
"""
Return the variables in inputs that are posible candidate for as inputs of
inplace operation.
Parameters
----------
inputs : list
Inputs Variable that you want to use as inplace destination.
:type inputs: list
:param inputs: inputs Variable that you want to use as inplace destination
"""
fgraph = inputs[0].fgraph
Supervisor = theano.compile.function_module.Supervisor
......@@ -249,38 +261,42 @@ if 0:
# old, non-incremental version of the DestroyHandler
class DestroyHandler(toolbox.Bookkeeper):
"""
The DestroyHandler class detects when a graph is impossible to evaluate because of
aliasing and destructive operations.
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
Several data structures are used to do this.
When an Op uses its view_map property to declare that an output may be aliased
to an input, then if that output is destroyed, the input is also considering to be
destroyed. The view_maps of several Ops can feed into one another and form a directed graph.
The consequence of destroying any variable in such a graph is that all variables in the graph
must be considered to be destroyed, because they could all be refering to the same
underlying storage. In the current implementation, that graph is a tree, and the root of
that tree is called the foundation. The `droot` property of this class maps from every
graph variable to its foundation. The `impact` property maps backward from the foundation
to all of the variables that depend on it. When any variable is destroyed, this class marks
the foundation of that variable as being destroyed, with the `root_destroyer` property.
When an Op uses its view_map property to declare that an output may be
aliased to an input, then if that output is destroyed, the input is also
considering to be destroyed. The view_maps of several Ops can feed into
one another and form a directed graph. The consequence of destroying any
variable in such a graph is that all variables in the graph must be
considered to be destroyed, because they could all be refering to the
same underlying storage. In the current implementation, that graph is a
tree, and the root of that tree is called the foundation. The `droot`
property of this class maps from every graph variable to its foundation.
The `impact` property maps backward from the foundation to all of the
variables that depend on it. When any variable is destroyed, this class
marks the foundation of that variable as being destroyed, with the
`root_destroyer` property.
"""
droot = {}
"""
destroyed view + nonview variables -> foundation
"""
destroyed view + nonview variables -> foundation.
impact = {}
"""
destroyed nonview variable -> it + all views of it
impact = {}
"""
destroyed nonview variable -> it + all views of it.
root_destroyer = {}
"""
root -> destroyer apply
root_destroyer = {}
"""
root -> destroyer apply.
"""
def __init__(self, do_imports_on_attach=True):
self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach
......@@ -295,8 +311,8 @@ if 0:
compilation to be slower.
TODO: WRITEME: what does this do besides the checks?
"""
"""
# Do the checking #
already_there = False
if self.fgraph not in [None, fgraph]:
......@@ -363,8 +379,10 @@ if 0:
self.fgraph = None
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
"""
Add Apply instance to set which must be computed.
"""
# if app in self.debug_all_apps: raise ProtocolError("double import")
# self.debug_all_apps.add(app)
# print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
......@@ -395,7 +413,10 @@ if 0:
self.stale_droot = True
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
"""
Remove Apply instance from set which must be computed.
"""
# if app not in self.debug_all_apps: raise ProtocolError("prune without import")
# self.debug_all_apps.remove(app)
......@@ -427,7 +448,10 @@ if 0:
self.stale_droot = True
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""app.inputs[i] changed from old_r to new_r """
"""
app.inputs[i] changed from old_r to new_r.
"""
if app == 'output':
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
......@@ -466,14 +490,14 @@ if 0:
self.stale_droot = True
def validate(self, fgraph):
"""Return None
"""
Return None.
Raise InconsistencyError when
a) orderings() raises an error
b) orderings cannot be topologically sorted.
"""
if self.destroyers:
ords = self.orderings(fgraph)
......@@ -487,7 +511,8 @@ if 0:
return True
def orderings(self, fgraph):
"""Return orderings induced by destructive operations.
"""
Return orderings induced by destructive operations.
Raise InconsistencyError when
a) attempting to destroy indestructable variable, or
......@@ -637,6 +662,7 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
The following data structures remain to be converted:
<unknown>
"""
pickle_rm_attr = ["destroyers"]
......@@ -644,38 +670,48 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach
"""maps every variable in the graph to its "foundation" (deepest
ancestor in view chain)
TODO: change name to var_to_vroot"""
"""
Maps every variable in the graph to its "foundation" (deepest
ancestor in view chain).
TODO: change name to var_to_vroot.
"""
self.droot = OrderedDict()
"""maps a variable to all variables that are indirect or direct views of it
(including itself)
essentially the inverse of droot
TODO: do all variables appear in this dict, or only those that are foundations?
TODO: do only destroyed variables go in here? one old docstring said so
TODO: rename to x_to_views after reverse engineering what x is"""
"""
Maps a variable to all variables that are indirect or direct views of it
(including itself) essentially the inverse of droot.
TODO: do all variables appear in this dict, or only those that are
foundations?
TODO: do only destroyed variables go in here? one old docstring said so.
TODO: rename to x_to_views after reverse engineering what x is
"""
self.impact = OrderedDict()
"""if a var is destroyed, then this dict will map
"""
If a var is destroyed, then this dict will map
droot[var] to the apply node that destroyed var
TODO: rename to vroot_to_destroyer"""
TODO: rename to vroot_to_destroyer
"""
self.root_destroyer = OrderedDict()
def on_attach(self, fgraph):
"""
When attaching to a new fgraph, check that
1) This DestroyHandler wasn't already attached to some fgraph
(its data structures are only set up to serve one)
(its data structures are only set up to serve one).
2) The FunctionGraph doesn't already have a DestroyHandler.
This would result in it validating everything twice, causing
compilation to be slower.
Give the FunctionGraph instance:
1) A new method "destroyers(var)"
TODO: what does this do exactly?
TODO: what does this do exactly?
2) A new attribute, "destroy_handler"
TODO: WRITEME: what does this do besides the checks?
"""
# Do the checking #
......@@ -723,9 +759,9 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
def refresh_droot_impact(self):
"""
Makes sure self.droot, self.impact, and self.root_destroyer are
up to date, and returns them.
(see docstrings for these properties above)
Makes sure self.droot, self.impact, and self.root_destroyer are up to
date, and returns them (see docstrings for these properties above).
"""
if self.stale_droot:
self.droot, self.impact, self.root_destroyer =\
......@@ -747,7 +783,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.fgraph = None
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
"""
Add Apply instance to set which must be computed.
"""
if app in self.debug_all_apps:
raise ProtocolError("double import")
......@@ -780,7 +819,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.stale_droot = True
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
"""
Remove Apply instance from set which must be computed.
"""
if app not in self.debug_all_apps:
raise ProtocolError("prune without import")
self.debug_all_apps.remove(app)
......@@ -814,7 +856,10 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.stale_droot = True
def on_change_input(self, fgraph, app, i, old_r, new_r, reason):
"""app.inputs[i] changed from old_r to new_r """
"""
app.inputs[i] changed from old_r to new_r.
"""
if app == 'output':
# app == 'output' is special key that means FunctionGraph is redefining which nodes are being
# considered 'outputs' of the graph.
......@@ -854,14 +899,14 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
self.stale_droot = True
def validate(self, fgraph):
"""Return None
"""
Return None.
Raise InconsistencyError when
a) orderings() raises an error
b) orderings cannot be topologically sorted.
"""
if self.destroyers:
ords = self.orderings(fgraph)
......@@ -882,7 +927,8 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
return True
def orderings(self, fgraph):
"""Return orderings induced by destructive operations.
"""
Return orderings induced by destructive operations.
Raise InconsistencyError when
a) attempting to destroy indestructable variable, or
......
"""
fg.py: fg stands for FunctionGraph
Contains the FunctionGraph class and exception
types that it can raise
types that it can raise.
"""
from __future__ import print_function
import sys
......@@ -23,10 +24,13 @@ NullType = None
class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
"""
An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
"""
pass
......@@ -34,24 +38,28 @@ class InconsistencyError(Exception):
"""
This exception should be thrown by listeners to FunctionGraph when the
graph's state is invalid.
"""
pass
class MissingInputError(Exception):
"""
A symbolic input needed to compute the outputs is missing.
"""
pass
class FunctionGraph(utils.object2):
""" WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and a
set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs
on which the outputs depend. Variables of type Constant are
not counted as inputs.
"""
WRITEME
A FunctionGraph represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies a theano function.
The inputs list should contain all the inputs on which the outputs depend.
Variables of type Constant are not counted as inputs.
The FunctionGraph supports the replace operation which allows to replace a
variable in the subgraph by another, e.g. replace (x + x).out by (2
......@@ -74,28 +82,35 @@ class FunctionGraph(utils.object2):
Historically, the FunctionGraph was called an Env. Keep this in mind
while reading out-of-date documentation, e-mail support threads, etc.
"""
The constructor creates a FunctionGraph which operates on the subgraph
bound by the inputs and outputs sets.
def __init__(self, inputs, outputs, features=None, clone=True):
"""
Create an FunctionGraph which operates on the subgraph bound by the inputs and
outputs sets.
This class keeps a pointer to the inputs and outputs, and also modifies
them.
This class keeps a pointer to the inputs and outputs, and also modifies
them.
#TODO: document what variables are[not] set in the FunctionGraph when a
feature is added via the constructor. How constructed is the
FunctionGraph?
#TODO: document what variables are[not] set in the FunctionGraph when a feature
is added via the constructor. How constructed is the FunctionGraph?
Parameters
----------
inputs
Inputs nodes of the graph, usually declared by the user.
outputs
Outputs nodes of the graph.
clone
If true, we will clone the graph. This is useful to remove the constant
cache problem.
Note: the intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
Notes
-----
The intermediate nodes between 'inputs' and 'outputs' are not explicitely
passed.
:param inputs: inputs nodes of the graph, usually declared by the user
:param outputs: outputs nodes of the graph.
:param clone: If true, we will clone the graph. This is
useful to remove the constant cache problem.
"""
def __init__(self, inputs, outputs, features=None, clone=True):
"""
if clone:
inputs, outputs = graph.clone(inputs, outputs)
......@@ -180,15 +195,17 @@ class FunctionGraph(utils.object2):
# self.execute_callbacks('on_setup_node', node)
def disown(self):
""" WRITEME
Cleans up all of this FunctionGraph's nodes and variables so they are not
associated with this FunctionGraph anymore.
"""
WRITEME
Cleans up all of this FunctionGraph's nodes and variables so they are
not associated with this FunctionGraph anymore.
The FunctionGraph should not be used anymore after disown is called.
This may not clean everything this FunctionGraph's features set in the
nodes and variables. If there are no features, this should set
them back to what they were originally.
"""
for apply_node in self.apply_nodes:
del apply_node.fgraph
......@@ -205,18 +222,25 @@ class FunctionGraph(utils.object2):
def clients(self, r):
    """
    Set of all the (node, i) pairs such that node.inputs[i] is r.

    Told differently, a list of (node, i) such that each node has
    r as input at index i.
    """
    return r.clients
def __add_clients__(self, r, new_clients):
""" WRITEME
r -> variable
new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
"""
Updates the list of clients of r with new_clients.
WRITEME
Parameters
----------
r
Variable.
new_clients
List of (node, i) pairs such that node.inputs[i] is r.
"""
if set(r.clients).intersection(set(new_clients)):
print('ERROR: clients intersect!', file=sys.stderr)
......@@ -229,11 +253,18 @@ class FunctionGraph(utils.object2):
def __remove_clients__(self, r, clients_to_remove,
prune=True, reason=None):
""" WRITEME
r -> variable
clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
"""
Removes all from the clients list of r.
WRITEME
Parameters
----------
r
Variable.
clients_to_remove
List of (op, i) pairs such that node.inputs[i] is not r anymore.
"""
for entry in clients_to_remove:
r.clients.remove(entry)
......@@ -286,11 +317,14 @@ class FunctionGraph(utils.object2):
if config.exception_verbosity == 'high':
def find_path_to(output_var, input_var):
""" Returns a list of each variable on a (not necessarily unique)
path from input_var to output_var, where each variable in the
list has the preceding variable as one of its inputs.
Returns None if no path exists"""
"""
Returns a list of each variable on a (not
necessarily unique) path from input_var to
output_var, where each variable in the list has
the preceding variable as one of its inputs.
Returns None if no path exists.
"""
# If output and input are the same we have a singleton path
if output_var is input_var:
return [output_var]
......@@ -376,12 +410,13 @@ class FunctionGraph(utils.object2):
# prune #
def __prune_r__(self, variable, reason=None):
"""Should be called for variable that aren't used anymore:
len(var.clients) == 0
"""
Should be called for variable that aren't used anymore:
len(var.clients) == 0.
This does not mean we will remove it from fgraph.variables. If
the owner stay in the fgraph as other outputs are still used,
the variable will be stay in fgraph.variables.
the variable will stay in fgraph.variables.
"""
# Prunes the owners of the variables.
......@@ -409,7 +444,8 @@ class FunctionGraph(utils.object2):
del variable.fgraph
def __prune__(self, apply_node, reason=None):
"""Always called on owner of pruned variable from the graph.
"""
Always called on owner of pruned variable from the graph.
This does not mean we will remove it from the graph. If other
outputs are still used, we will keep the node in the graph.
......@@ -433,14 +469,17 @@ class FunctionGraph(utils.object2):
# change input #
def change_input(self, node, i, new_r, reason=None):
"""WRITEME
"""
Changes node.inputs[i] to new_r.
WRITEME
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
feature.on_change_input(function_graph, node, i, old_r, new_r, reason)
"""
# TODO: ERROR HANDLING FOR LISTENERS (should it complete the change or revert it?)
if node == 'output':
......@@ -478,9 +517,12 @@ class FunctionGraph(utils.object2):
# replace #
def replace(self, r, new_r, reason=None, verbose=None):
""" WRITEME
"""
WRITEME
This is the main interface to manipulate the subgraph in FunctionGraph.
For every node that uses r as input, makes it use new_r instead.
"""
if verbose is None:
verbose = config.optimizer_verbose
......@@ -532,16 +574,19 @@ class FunctionGraph(utils.object2):
# print >> sys.stderr, "WARNING: CLIENTS LEFT AFTER REPLACE", r, r.clients
def replace_all(self, pairs, reason=None):
    """
    Apply `replace` to each (r, new_r) pair in `pairs`.

    Parameters
    ----------
    pairs
        Iterable of (r, new_r) Variable pairs; every r is replaced
        by its corresponding new_r.
    reason
        Forwarded unchanged to each `replace` call (used for
        logging/debugging by features).
    """
    for r, new_r in pairs:
        self.replace(r, new_r, reason=reason)
def attach_feature(self, feature):
"""
Adds a gof.toolbox.Feature to this function_graph
and triggers its on_attach callback
"""
Adds a gof.toolbox.Feature to this function_graph and triggers its
on_attach callback.
"""
# Filter out literally identical features
if feature in self._features:
return # the feature is already present
......@@ -567,7 +612,9 @@ class FunctionGraph(utils.object2):
self._features.append(feature)
def remove_feature(self, feature):
"""WRITEME
"""
WRITEME
Removes the feature from the graph.
Calls feature.on_detach(function_graph) if an on_detach method
......@@ -585,10 +632,13 @@ class FunctionGraph(utils.object2):
# callback utils #
def execute_callbacks(self, name, *args, **kwargs):
"""WRITEME
"""
WRITEME
Calls
getattr(feature, name)(*args)
for each feature which has a method called after name.
"""
t0 = time.time()
for feature in self._features:
......@@ -605,10 +655,13 @@ class FunctionGraph(utils.object2):
self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args):
"""WRITEME
"""
WRITEME
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
"""
d = {}
for feature in self._features:
......@@ -621,16 +674,19 @@ class FunctionGraph(utils.object2):
# misc #
def toposort(self):
"""WRITEME
Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
"""
WRITEME
Return an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with
this FunctionGraph as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
if len(self.apply_nodes) < 2:
# optimization
......@@ -652,11 +708,12 @@ class FunctionGraph(utils.object2):
before node itself can be evaluated.
This is used primarily by the destroy_handler feature to ensure that
all clients of any destroyed inputs have already computed their
outputs.
all clients of any destroyed inputs have already computed their outputs.
:note: This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
Notes
-----
This only calls the orderings() fct on all features. It does not
take care of computing dependencies by itself.
"""
ords = OrderedDict()
......@@ -682,8 +739,11 @@ class FunctionGraph(utils.object2):
return ords
def check_integrity(self):
"""WRITEME
"""
WRITEME
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs)
if self.apply_nodes != nodes:
......@@ -740,11 +800,17 @@ class FunctionGraph(utils.object2):
# clone #
def clone(self, check_integrity=True):
    """
    Clone the graph and return the clone.

    Parameters
    ----------
    check_integrity : bool
        Forwarded to `clone_get_equiv`; when True the graph's
        integrity is checked while cloning.

    Returns
    -------
    object
        The cloned graph: the first element of the pair returned by
        `clone_get_equiv` (the second element is the equivalence
        mapping, discarded here).
    """
    return self.clone_get_equiv(check_integrity)[0]
def clone_get_equiv(self, check_integrity=True):
"""WRITEME"""
"""
WRITEME
"""
equiv = graph.clone_get_equiv(self.inputs, self.outputs)
if check_integrity:
self.check_integrity()
......@@ -757,8 +823,10 @@ class FunctionGraph(utils.object2):
return e, equiv
def __getstate__(self):
"""This is needed as some feature introduce instancemethod and
this is not picklable.
"""
This is needed as some features introduce instance methods.
This is not picklable.
"""
d = self.__dict__.copy()
for feature in self._features:
......
......@@ -26,68 +26,74 @@ NoContext = object()
class Node(utils.object2):
"""A Node in a theano graph.
Graphs contain two kinds of Nodes--
Variable and Apply.
"""
A Node in a theano graph.
Graphs contain two kinds of Nodes -- Variable and Apply.
Edges in the graph are not explicitly represented.
Instead each Node keeps track of its parents via
Variable.owner / Apply.inputs and its children
via Variable.clients / Apply.outputs.
"""
def get_parents(self):
    """
    Return a list of the parents of this node.

    Should return a copy--i.e., modifying the return
    value should not modify the graph structure.

    Raises
    ------
    NotImplementedError
        Always; subclasses (Variable, Apply) must override this.
    """
    raise NotImplementedError()
class Apply(Node):
"""
An :term:`Apply` instance is a node in an expression graph which represents the application
of an `Op` to some input `Variable` nodes, producing some output `Variable` nodes.
An :term:`Apply` instance is a node in an expression graph which represents
the application of an `Op` to some input `Variable` nodes, producing some
output `Variable` nodes.
This class is typically instantiated by an Op's make_node() function, which is typically
called by that Op's __call__() function.
This class is typically instantiated by an Op's make_node() function, which
is typically called by that Op's __call__() function.
An Apply instance serves as a simple structure with three important attributes:
An Apply instance serves as a simple structure with three important
attributes:
- :literal:`inputs` : a list of `Variable` nodes that represent the arguments of the expression,
- :literal:`inputs` : a list of `Variable` nodes that represent the
arguments of the expression,
- :literal:`outputs` : a list of `Variable` nodes that represent the variable of the expression, and
- :literal:`outputs` : a list of `Variable` nodes that represent the
variable of the expression, and
- :literal:`op` : an `Op` instance that determines the nature of the expression being applied.
- :literal:`op` : an `Op` instance that determines the nature of the
expression being applied.
The driver `compile.function` uses Apply's inputs attribute together with Variable's owner
attribute to search the expression graph and determine which inputs are necessary to
compute the function's outputs.
The driver `compile.function` uses Apply's inputs attribute together with
Variable's owner attribute to search the expression graph and determine
which inputs are necessary to compute the function's outputs.
A `Linker` uses the Apply instance's `op` field to compute the variables.
Comparing with the Python language, an `Apply` instance is theano's version of a function
call (or expression instance) whereas `Op` is theano's version of a function definition.
"""
Comparing with the Python language, an `Apply` instance is theano's version
of a function call (or expression instance) whereas `Op` is theano's version
of a function definition.
def __init__(self, op, inputs, outputs):
"""Initialize attributes
Parameters
----------
op : `Op` instance
inputs : list of Variable instances
outputs : list of Variable instances
:Parameters:
`op` : `Op` instance
initialize self.op
`inputs` : list of Variable instances
initialize self.inputs
`outputs` : list of Variable instances
initialize self.outputs
Notes
-----
The owner field of each output in the outputs list will be set to self.
:note:
The owner field of each output in the outputs list will be set to self.
If an output element has an owner that is neither None nor self, then a
ValueError exception will be raised.
:note:
If an output element has an owner that is neither None nor self, then a ValueError
exception will be raised.
"""
"""
def __init__(self, op, inputs, outputs):
self.op = op
self.inputs = []
self.tag = utils.scratchpad()
......@@ -118,27 +124,29 @@ class Apply(Node):
raise TypeError("The 'outputs' argument to Apply must contain Variable instances with no owner, not %s" % output)
def run_context(self):
    """
    Return the context for the node, or NoContext if no context is set.

    The context is obtained by calling ``self.op.get_context(self)``
    when the op provides a ``get_context`` method.
    """
    if hasattr(self.op, 'get_context'):
        return self.op.get_context(self)
    return NoContext
def default_output(self):
"""Returns the default output for this node.
:rtype:
Variable instance
"""
Returns the default output for this node.
:return:
an element of self.outputs, typically self.outputs[0].
Returns
-------
Variable instance
An element of self.outputs, typically self.outputs[0].
:note:
may raise AttributeError self.op.default_output is out of range, or if there are
multiple outputs and self.op.default_output does not exist.
Notes
-----
May raise AttributeError self.op.default_output is out of range, or if
there are multiple outputs and self.op.default_output does not exist.
"""
do = getattr(self.op, 'default_output', None)
if do is None:
if len(self.outputs) == 1:
......@@ -156,7 +164,10 @@ class Apply(Node):
out = property(default_output,
doc="alias for self.default_output()")
"""Alias for self.default_output()"""
"""
Alias for self.default_output().
"""
def __str__(self):
    """Return a string rendering of this apply node via `op_as_string`."""
    return op_as_string(self.inputs, self)
......@@ -168,13 +179,18 @@ class Apply(Node):
return self
def clone(self):
"""Duplicate this Apply instance with inputs = self.inputs.
"""
Duplicate this Apply instance with inputs = self.inputs.
:return:
a new Apply instance (or subclass instance) with new outputs.
Returns
-------
object
A new Apply instance (or subclass instance) with new outputs.
Notes
-----
Tags are copied from self to the returned instance.
:note:
tags are copied from self to the returned instance.
"""
cp = self.__class__(self.op, self.inputs,
[output.clone() for output in self.outputs])
......@@ -182,13 +198,14 @@ class Apply(Node):
return cp
def clone_with_new_inputs(self, inputs, strict=True):
"""Duplicate this Apply instance in a new graph.
:param inputs: list of Variable instances to use as inputs.
:type strict: Bool
"""
Duplicate this Apply instance in a new graph.
:param strict:
Parameters
----------
inputs
List of Variable instances to use as inputs.
strict : bool
If True, the type fields of all the inputs must be equal
to the current ones (or compatible, for instance Tensor /
CudaNdarray of the same dtype and broadcastable patterns,
......@@ -198,7 +215,10 @@ class Apply(Node):
clone's outputs will have the same types as self.outputs,
and cloning may not even be possible (it depends on the Op).
:returns: an Apply instance with the same op but different outputs.
Returns
-------
object
An Apply instance with the same op but different outputs.
"""
assert isinstance(inputs, (list, tuple))
......@@ -224,62 +244,90 @@ class Apply(Node):
# convenience properties
nin = property(lambda self: len(self.inputs), doc='same as len(self.inputs)')
"""property: Number of inputs"""
"""
Property: Number of inputs.
"""
nout = property(lambda self: len(self.outputs), doc='same as len(self.outputs)')
"""property: Number of outputs"""
"""
Property: Number of outputs.
"""
context_type = property(lambda self: self.op.context_type, doc='type to use for the context')
class Variable(Node):
"""
A :term:`Variable` is a node in an expression graph that represents a variable.
The inputs and outputs of every `Apply` (theano.gof.Apply) are `Variable` instances.
The input and output arguments to create a `function` are also `Variable` instances.
A `Variable` is like a strongly-typed variable in some other languages; each `Variable` contains a
reference to a `Type` instance that defines the kind of value the `Variable` can take in a
A :term:`Variable` is a node in an expression graph that represents a
variable.
The inputs and outputs of every `Apply` (theano.gof.Apply) are `Variable`
instances. The input and output arguments to create a `function` are also
`Variable` instances. A `Variable` is like a strongly-typed variable in
some other languages; each `Variable` contains a reference to a `Type`
instance that defines the kind of value the `Variable` can take in a
computation.
A `Variable` is a container for four important attributes:
- :literal:`type` a `Type` instance defining the kind of value this `Variable` can have,
- :literal:`type` a `Type` instance defining the kind of value this
`Variable` can have,
- :literal:`owner` either None (for graph roots) or the `Apply` instance of which `self` is an output,
- :literal:`owner` either None (for graph roots) or the `Apply` instance
of which `self` is an output,
- :literal:`index` the integer such that :literal:`owner.outputs[index] is this_variable` (ignored if `owner` is None)
- :literal:`index` the integer such that :literal:`owner.outputs[index] is
this_variable` (ignored if `owner` is None),
- :literal:`name` a string to use in pretty-printing and debugging.
There are a few kinds of Variables to be aware of: A Variable which is the output of a symbolic
computation has a reference to the Apply instance to which it belongs (property: owner) and
the position of itself in the owner's output list (property: index).
There are a few kinds of Variables to be aware of: A Variable which is the
output of a symbolic computation has a reference to the Apply instance to
which it belongs (property: owner) and the position of itself in the owner's
output list (property: index).
- `Variable` (this base type) is typically the output of a symbolic computation,
- `Variable` (this base type) is typically the output of a symbolic
computation.
- `Constant` (a subclass) which adds a default and un-replaceable :literal:`value`, and
requires that owner is None
- `Constant` (a subclass) which adds a default and un-replaceable
:literal:`value`, and requires that owner is None.
- `TensorVariable` subclass of Variable that represents a numpy.ndarray object
- `TensorVariable` subclass of Variable that represents a numpy.ndarray
object.
- `TensorSharedVariable` Shared version of TensorVariable
- `TensorSharedVariable` Shared version of TensorVariable.
- `SparseVariable` subclass of Variable that represents a scipy.sparse.{csc,csr}_matrix object
- `SparseVariable` subclass of Variable that represents
a scipy.sparse.{csc,csr}_matrix object.
- `CudaNdarrayVariable` subclass of Variable that represents our object on the GPU that is a subset of numpy.ndarray
- `CudaNdarrayVariable` subclass of Variable that represents our object on
the GPU that is a subset of numpy.ndarray.
- `RandomVariable`
- `RandomVariable`.
A Variable which is the output of a symbolic computation will have an owner
not equal to None.
Using the Variables' owner field and the Apply nodes' inputs fields, one can navigate a graph
from an output all the way to the inputs. The opposite direction is not possible until an
FunctionGraph has annotated the Variables with the clients field, ie, before the compilation process
has begun a Variable does not know which Apply nodes take it as input.
Using the Variables' owner field and the Apply nodes' inputs fields, one can
navigate a graph from an output all the way to the inputs. The opposite
direction is not possible until a FunctionGraph has annotated the Variables
with the clients field, ie, before the compilation process has begun a
Variable does not know which Apply nodes take it as input.
**Code Example**
Parameters
----------
type : a Type instance
The type governs the kind of data that can be associated with this
variable.
owner : None or Apply instance
The Apply instance which computes the value for this variable.
index : None or int
The position of this Variable in owner.outputs.
name : None or str
A string for pretty-printing and debugging.
Examples
--------
.. code-block:: python
......@@ -303,32 +351,20 @@ class Variable(Node):
e = d + b
theano.function([d,b], [e]) # this works. d's default value of 1.5 is ignored.
The python variables :literal:`a,b,c` all refer to instances of type `Variable`.
The `Variable` refered to by `a` is also an instance of `Constant`.
The python variables :literal:`a,b,c` all refer to instances of type
`Variable`. The `Variable` referred to by `a` is also an instance of
`Constant`.
`compile.function` uses each `Apply` instance's `inputs` attribute together
with each Variable's `owner` field to determine which inputs are necessary
to compute the function's outputs.
`compile.function` uses each `Apply` instance's `inputs` attribute
together with each Variable's `owner` field to determine which inputs are necessary to compute the function's outputs.
"""
# __slots__ = ['type', 'owner', 'index', 'name']
__count__ = count(0)
def __init__(self, type, owner=None, index=None, name=None):
"""Initialize type, owner, index, name.
:type type: a Type instance
:param type:
the type governs the kind of data that can be associated with this variable
:type owner: None or Apply instance
:param owner: the Apply instance which computes the value for this variable
:type index: None or int
:param index: the position of this Variable in owner.outputs
:type name: None or str
:param name: a string for pretty-printing and debugging
"""
super(Variable, self).__init__()
self.tag = utils.scratchpad()
......@@ -345,7 +381,10 @@ class Variable(Node):
self.auto_name = 'auto_' + str(next(self.__count__))
def __str__(self):
"""WRITEME"""
"""
WRITEME
"""
if self.name is not None:
return self.name
if self.owner is not None:
......@@ -361,13 +400,21 @@ class Variable(Node):
return str(self)
def clone(self):
"""Return a new Variable like self.
"""
Return a new Variable like self.
Returns
-------
Variable instance
A new Variable instance (or subclass instance) with no owner or
index.
Notes
-----
Tags are copied to the returned instance.
:rtype: Variable instance
:return: a new Variable instance (or subclass instance) with no owner or index.
Name is copied to the returned instance.
:note: tags are copied to the returned instance.
:note: name is copied to the returned instance.
"""
# return copy(self)
cp = self.__class__(self.type, None, None, self.name)
......@@ -396,9 +443,14 @@ class Variable(Node):
return []
def eval(self, inputs_to_values=None):
""" Evaluates this variable.
"""
Evaluates this variable.
Parameters
----------
inputs_to_values
A dictionary mapping theano Variables to values.
inputs_to_values: a dictionary mapping theano Variables to values.
"""
if inputs_to_values is None:
......@@ -424,21 +476,23 @@ class Variable(Node):
class Constant(Variable):
"""
A :term:`Constant` is a `Variable` with a `value` field that cannot be changed at runtime.
A :term:`Constant` is a `Variable` with a `value` field that cannot be
changed at runtime.
Constant nodes make eligible numerous optimizations: constant inlining in C code, constant folding, etc.
"""
# __slots__ = ['data']
def __init__(self, type, data, name=None):
"""Initialize self.
Constant nodes make eligible numerous optimizations: constant inlining in
C code, constant folding, etc.
:note:
The data field is filtered by what is provided in the constructor for the Constant's
type field.
Notes
-----
The data field is filtered by what is provided in the constructor for the
Constant's type field.
WRITEME
WRITEME
"""
"""
# __slots__ = ['data']
def __init__(self, type, data, name=None):
    """
    Initialize this Constant with no owner and no index.

    `data` is passed through ``type.filter``, so it must be acceptable
    to the given `type`.
    """
    Variable.__init__(self, type, None, None, name)
    self.data = type.filter(data)
......@@ -463,18 +517,23 @@ class Constant(Variable):
def clone(self):
    """
    Return a copy of this Constant.

    We don't clone the data, to lower the memory requirement; we
    suppose that the data will never change. The tag is
    shallow-copied onto the new instance.
    """
    cp = self.__class__(self.type, self.data, self.name)
    cp.tag = copy(self.tag)
    return cp
def __set_owner(self, value):
    """
    Setter that forbids giving a Constant an owner.

    Raises
    ------
    ValueError
        If `value` is not `None`.
    """
    if value is not None:
        raise ValueError("Constant instances cannot have an owner.")
......@@ -486,20 +545,26 @@ class Constant(Variable):
def stack_search(start, expand, mode='bfs', build_inv=False):
"""Search through a graph, either breadth- or depth-first
:type start: deque
:param start: search from these nodes
:type expand: callable
:param expand:
when we get to a node, add expand(node) to the list of nodes to visit.
This function should return a list, or None
:rtype: list of `Variable` or `Apply` instances (depends on `expend`)
:return: the list of nodes in order of traversal.
"""
Search through a graph, either breadth- or depth-first.
:note:
a node will appear at most once in the return value, even if it
appears multiple times in the start parameter.
Parameters
----------
start : deque
Search from these nodes.
expand : callable
When we get to a node, add expand(node) to the list of nodes to visit.
This function should return a list, or None.
Returns
-------
list of `Variable` or `Apply` instances (depends on `expand`)
The list of nodes in order of traversal.
Notes
-----
A node will appear at most once in the return value, even if it
appears multiple times in the start parameter.
:postcondition: every element of start is transferred to the returned list.
:postcondition: start is empty.
......@@ -533,15 +598,20 @@ def stack_search(start, expand, mode='bfs', build_inv=False):
def ancestors(variable_list, blockers=None):
"""Return the variables that contribute to those in variable_list (inclusive).
"""
Return the variables that contribute to those in variable_list (inclusive).
:type variable_list: list of `Variable` instances
:param variable_list:
output `Variable` instances from which to search backward through owners
:rtype: list of `Variable` instances
:returns:
all input nodes, in the order found by a left-recursive depth-first search
started at the nodes in `variable_list`.
Parameters
----------
variable_list : list of `Variable` instances
Output `Variable` instances from which to search backward through
owners.
Returns
-------
list of `Variable` instances
All input nodes, in the order found by a left-recursive depth-first
search started at the nodes in `variable_list`.
"""
def expand(r):
......@@ -552,15 +622,20 @@ def ancestors(variable_list, blockers=None):
def inputs(variable_list, blockers=None):
"""Return the inputs required to compute the given Variables.
"""
Return the inputs required to compute the given Variables.
:type variable_list: list of `Variable` instances
:param variable_list:
output `Variable` instances from which to search backward through owners
:rtype: list of `Variable` instances
:returns:
input nodes with no owner, in the order found by a left-recursive depth-first search
started at the nodes in `variable_list`.
Parameters
----------
variable_list : list of `Variable` instances
Output `Variable` instances from which to search backward through
owners.
Returns
-------
list of `Variable` instances
Input nodes with no owner, in the order found by a left-recursive
depth-first search started at the nodes in `variable_list`.
"""
vlist = ancestors(variable_list, blockers)
......@@ -569,7 +644,9 @@ def inputs(variable_list, blockers=None):
def variables_and_orphans(i, o):
"""WRITEME
"""
WRITEME
"""
def expand(r):
if r.owner and r not in i:
......@@ -582,17 +659,24 @@ def variables_and_orphans(i, o):
def ops(i, o):
""" WRITEME
"""
WRITEME
:type i: list
:param i: input L{Variable}s
:type o: list
:param o: output L{Variable}s
Parameters
----------
i : list
Input L{Variable}s.
o : list
Output L{Variable}s.
Returns
-------
object
The set of ops that are contained within the subgraph that lies
between i and o, including the owners of the L{Variable}s in o and
intermediary ops between i and o, but not the owners of the L{Variable}s
in i.
:returns:
the set of ops that are contained within the subgraph that lies between i and o,
including the owners of the L{Variable}s in o and intermediary ops between i and o, but
not the owners of the L{Variable}s in i.
"""
ops = set()
variables, orphans = variables_and_orphans(i, o)
......@@ -604,33 +688,48 @@ def ops(i, o):
def variables(i, o):
    """
    Extract the Variables within the subgraph between i and o.

    Parameters
    ----------
    i : list
        Input L{Variable}s.
    o : list
        Output L{Variable}s.

    Returns
    -------
    object
        The set of Variables that are involved in the subgraph that lies
        between i and o. This includes i, o, orphans(i, o) and all values
        of all intermediary steps from i to o.
    """
    return variables_and_orphans(i, o)[0]
def orphans(i, o):
    """
    Extract the Variables the subgraph between i and o depends on that
    are neither in i nor between i and o.

    Parameters
    ----------
    i : list
        Input L{Variable}s.
    o : list
        Output L{Variable}s.

    Returns
    -------
    object
        The set of Variables which one or more Variables in o depend
        on but are neither in i nor in the subgraph that lies between
        i and o.

    Examples
    --------
    orphans([x], [(x+y).out]) => [y]
    """
    return variables_and_orphans(i, o)[1]
......@@ -639,14 +738,20 @@ def clone(i, o, copy_inputs=True):
"""
Copies the subgraph contained between i and o.
:type i: list
:param i: input L{Variable}s
:type o: list
:param o: output L{Variable}s
:type copy_inputs: bool
:param copy_inputs: if True, the inputs will be copied (defaults to True)
Parameters
----------
i : list
Input L{Variable}s.
o : list
Output L{Variable}s.
copy_inputs : bool
If True, the inputs will be copied (defaults to True).
Returns
-------
object
The inputs and outputs of that copy.
Returns the inputs and outputs of that copy.
"""
equiv = clone_get_equiv(i, o, copy_inputs)
return [equiv[input] for input in i], [equiv[output] for output in o]
......@@ -662,20 +767,19 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
Parameters
----------
inputs: a list of Variables
outputs: a list of Variables
copy_inputs_and_orphans: bool
inputs : a list of Variables
outputs : a list of Variables
copy_inputs_and_orphans : bool
True means to create the cloned graph from new input and constant
nodes (the bottom of a feed-upward graph),
nodes (the bottom of a feed-upward graph).
False means to clone a graph that is rooted at the original input
nodes.
memo: None or dict
nodes.
memo : None or dict
Optionally start with a partly-filled dictionary for the return value.
If a dictionary is passed, this function will work in-place on that
dictionary and return it.
"""
if memo is None:
memo = {}
......@@ -714,29 +818,33 @@ def clone_get_equiv(inputs, outputs, copy_inputs_and_orphans=True, memo=None):
def general_toposort(r_out, deps, debug_print=False,
compute_deps_cache=None, deps_cache=None):
"""WRITEME
"""
WRITEME
:note:
deps(i) should behave like a pure function (no funny business with internal state)
Parameters
----------
deps
A python function that takes a node as input and returns its dependence.
compute_deps_cache : optional
If provided deps_cache should also be provided. This is a function like
deps, but that also cache its results in a dict passed as deps_cache.
deps_cache : dict
Must be used with compute_deps_cache.
:note:
deps(i) will be cached by this function (to be fast)
Notes
-----
deps(i) should behave like a pure function (no funny business with
internal state).
:note:
The order of the return value list is determined by the order of nodes returned by the deps() function.
deps(i) will be cached by this function (to be fast).
:param deps: a python function that take a node as input and
return its dependence.
:param compute_deps_cache: Optional,
if provided deps_cache should also be provided. This is a
function like deps, but that also cache its results in a dict
passed as deps_cache.
:param deps_cache: a dict. Must be used with compute_deps_cache.
The order of the return value list is determined by the order of nodes
returned by the deps() function.
:note: deps should be provided or can be None and the caller
provide compute_deps_cache and deps_cache. The second option
remove a Python function call, and allow for more specialized
code, so it can be faster.
deps should be provided or can be None and the caller provides
compute_deps_cache and deps_cache. The second option removes a Python
function call, and allows for more specialized code, so it can be
faster.
"""
if compute_deps_cache is None:
......@@ -788,18 +896,17 @@ def general_toposort(r_out, deps, debug_print=False,
def io_toposort(inputs, outputs, orderings=None):
"""WRITEME
inputs: a list or tuple of Variable instances
outputs: a list or tuple of Apply instances
orderings: a dictionary
key: Apply instance
value: list of Apply instance
"""
WRITEME
it is important that the value be
a container with a deterministic iteration
order. no sets allowed!
Parameters
----------
inputs : list or tuple of Variable instances
outputs : list or tuple of Apply instances
orderings : dict
Key: Apply instance. Value: list of Apply instances.
It is important that the value be a container with a deterministic
iteration order. No sets allowed!
"""
# the inputs are used only here in the function that decides what 'predecessors' to explore
......@@ -864,9 +971,9 @@ def default_node_formatter(op, argstrings):
def io_connection_pattern(inputs, outputs):
"""
Returns the connection pattern of a subgraph defined by given
inputs and outputs
"""
inputs and outputs.
"""
inner_nodes = io_toposort(inputs, outputs)
# Initialize 'connect_pattern_by_var' by establishing each input as
......@@ -941,22 +1048,26 @@ def is_same_graph(var1, var2, givens=None, debug=False):
return the same output. The goal is to verify this assumption, to
eventually get rid of one of them in the future.
:param var1: The first Variable to compare.
:param var2: The second Variable to compare.
:param givens: Similar to the `givens` argument of `theano.function`, it
can be used to perform substitutions in the computational graph of `var1`
and `var2`. This argument is associated to neither `var1` nor `var2`:
substitutions may affect both graphs if the substituted variable is present
in both.
:param debug: If True, then an exception is raised when we are in a
situation where the `equal_computations` implementation cannot be called.
This parameter is intended to be used in tests only, to make sure we
properly test both implementations.
Examples:
Parameters
----------
var1
The first Variable to compare.
var2
The second Variable to compare.
givens
Similar to the `givens` argument of `theano.function`, it can be used
to perform substitutions in the computational graph of `var1` and
`var2`. This argument is associated to neither `var1` nor `var2`:
substitutions may affect both graphs if the substituted variable
is present in both.
debug : bool
If True, then an exception is raised when we are in a situation where
the `equal_computations` implementation cannot be called.
This parameter is intended to be used in tests only, to make sure we
properly test both implementations.
Examples
--------
====== ====== ====== ======
var1 var2 givens output
......@@ -965,6 +1076,7 @@ def is_same_graph(var1, var2, givens=None, debug=False):
x + 1 y + 1 {} False
x + 1 y + 1 {x: y} True
====== ====== ====== ======
"""
# Lazy import.
if givens is None:
......@@ -1040,7 +1152,10 @@ def is_same_graph(var1, var2, givens=None, debug=False):
def op_as_string(i, op,
leaf_formatter=default_leaf_formatter,
node_formatter=default_node_formatter):
"""WRITEME"""
"""
WRITEME
"""
strs = as_string(i, op.inputs, leaf_formatter, node_formatter)
return node_formatter(op, strs)
......@@ -1048,28 +1163,32 @@ def op_as_string(i, op,
def as_string(i, o,
leaf_formatter=default_leaf_formatter,
node_formatter=default_node_formatter):
"""WRITEME
:type i: list
:param i: input `Variable` s
:type o: list
:param o: output `Variable` s
:type leaf_formatter: function
:param leaf_formatter: takes a `Variable` and returns a string to describe it
:type node_formatter: function
:param node_formatter:
takes an `Op` and the list of strings corresponding to its arguments and returns a
string to describe it
:rtype: str
:returns:
Returns a string representation of the subgraph between i and o. If the same op is used
by several other ops, the first occurrence will be marked as :literal:`*n ->
description` and all subsequent occurrences will be marked as :literal:`*n`, where n is
an id number (ids are attributed in an unspecified order and only exist for viewing
convenience).
"""
WRITEME
Parameters
----------
i : list
Input `Variable` s.
o : list
Output `Variable` s.
leaf_formatter : function
Takes a `Variable` and returns a string to describe it.
node_formatter : function
Takes an `Op` and the list of strings corresponding to its arguments
and returns a string to describe it.
Returns
-------
str
Returns a string representation of the subgraph between i and o. If the
same op is used by several other ops, the first occurrence will be
marked as :literal:`*n -> description` and all subsequent occurrences
will be marked as :literal:`*n`, where n is an id number (ids are
attributed in an unspecified order and only exist for viewing
convenience).
"""
i = set(i)
orph = orphans(i, o)
......@@ -1126,6 +1245,7 @@ def view_roots(r):
consecutive view_map()s.
WRITEME
"""
owner = r.owner
if owner is not None:
......@@ -1147,7 +1267,10 @@ def view_roots(r):
def list_of_nodes(inputs, outputs):
""" Return the apply nodes of the graph between inputs and outputs """
"""
Return the apply nodes of the graph between inputs and outputs.
"""
return stack_search(
deque([o.owner for o in outputs]),
lambda o: [inp.owner for inp in o.inputs
......
"""WRITEME"""
"""
WRITEME
"""
from __future__ import print_function
from copy import copy, deepcopy
from sys import getsizeof
......@@ -20,8 +23,10 @@ __excepthook = sys.excepthook
def log_thunk_trace(value, f=sys.stderr):
"""Log Theano's diagnostic stack trace for an exception
"""
Log Theano's diagnostic stack trace for an exception
raised by raise_with_op.
"""
# in future, consider accepting `write` as arg rather than file
# to support writing to a logger
......@@ -46,7 +51,9 @@ def log_thunk_trace(value, f=sys.stderr):
def thunk_hook(type, value, trace):
"""WRITEME
"""
WRITEME
This function is meant to replace excepthook and do some
special work if the exception value has a __thunk_trace__
field. In that case, it retrieves the field, which should
......@@ -55,7 +62,10 @@ def thunk_hook(type, value, trace):
The normal excepthook is then called.
:note: This hook replaced by nosetests, so it does not run in nose tests.
Notes
-----
This hook is replaced by nosetests, so it does not run in nose tests.
"""
log_thunk_trace(value)
__excepthook(type, value, trace)
......@@ -82,7 +92,6 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
Notes
-----
This re-raises the exception described by `exc_info` (or the last
one raised, if `exc_info` is omitted) and annotates the exception
object with several new members which may be helpful for debugging
......@@ -96,6 +105,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
to this op in `op.fgraph.toposort()`.
The exception is not annotated if it is of type `KeyboardInterrupt`.
"""
if exc_info is None:
exc_info = sys.exc_info()
......@@ -298,7 +308,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
class Linker(object):
"""WRITEME"""
"""
WRITEME
"""
def clone(self, allow_gc=undef):
new = copy(self)
......@@ -308,22 +321,25 @@ class Linker(object):
def make_thunk(self):
"""
This function must return a triplet (function, input_variables, output_variables)
where function is a thunk that operates on the returned variables. If inplace
is True, the input_variables and output_variables lists will be the same as the
inputs and outputs of the graph provided to the L{Linker}. Else, independent
This function must return a triplet (function, input_variables,
output_variables) where function is a thunk that operates on the
returned variables. If inplace is True, the input_variables and
output_variables lists will be the same as the inputs and outputs
of the graph provided to the L{Linker}. Else, independent
variables will be returned.
Example::
x, y = Variable(Double), Variable(Double)
e = x + y
fgraph = FunctionGraph([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(fgraph).make_thunk(inplace)
new_x.data = 1.0
new_y.data = 2.0
fn()
print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
Examples
--------
x, y = Variable(Double), Variable(Double)
e = x + y
fgraph = FunctionGraph([x, y], [e])
fn, (new_x, new_y), (new_e, ) = MyLinker(fgraph).make_thunk(inplace)
new_x.data = 1.0
new_y.data = 2.0
fn()
print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
"""
raise utils.MethodNotDefined("make_thunk", type(self),
self.__class__.__name__)
......@@ -332,21 +348,23 @@ class Linker(object):
def make_function(self, unpack_single=True, **kwargs):
"""
Returns a function that takes values corresponding to the inputs of the
fgraph used by this L{Linker} and returns values corresponding the the outputs
of that fgraph. If inplace is True, the calculations will operate in the
same storage the fgraph uses, else independent storage will be allocated
for the function.
Example::
e = x + y
fgraph = FunctionGraph([x, y], [e])
fn = MyLinker(fgraph).make_function(inplace)
print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
fgraph used by this L{Linker} and returns values corresponding to the
outputs of that fgraph. If inplace is True, the calculations will
operate in the same storage the fgraph uses, else independent storage
will be allocated for the function.
Example
-------
e = x + y
fgraph = FunctionGraph([x, y], [e])
fn = MyLinker(fgraph).make_function(inplace)
print fn(1.0, 2.0) # 3.0
print e.data # 3.0 iff inplace == True (else unknown)
If unpack_single is True (default) and that the function has only one
output, then that output will be returned. Else, a list or tuple of
length 1 will be returned.
"""
thunk, inputs, outputs = self.make_thunk(**kwargs)
......@@ -376,24 +394,31 @@ class Linker(object):
# TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object):
"""This class joins a variable with its computed value.
"""
This class joins a variable with its computed value.
It is used in linkers, especially for the inputs and outputs of a Function.
Parameters
----------
r : a Variable or a Type
storage
A list of length 1, whose element is the value for `r`.
readonly : bool
True indicates that this should not be settable by Function[r] = val.
strict : bool
If True, we don't allow type casting.
allow_downcast
If True (and `strict` is False), allow upcasting of type, but not
downcasting. If False, prevent it. If None (default), allows only
downcasting of float to floatX scalar.
name : str
A string (for pretty-printing?)
"""
def __init__(self, r, storage, readonly=False, strict=False,
allow_downcast=None, name=None):
"""WRITEME
:Parameters:
`r`: a Variable or a Type
`storage`: a list of length 1, whose element is the value for `r`
`readonly`: True indicates that this should not be setable by Function[r] = val
`strict`: if True, we don't allow type casting.
`allow_downcast`: if True (and `strict` is False), allow upcasting
of type, but not downcasting. If False, prevent it. If None
(default), allows only downcasting of float to floatX scalar.
`name`: A string (for pretty-printing?)
"""
if not isinstance(storage, list) or not len(storage) >= 1:
raise TypeError("storage must be a list of length at least one")
# self.r = r
......@@ -472,23 +497,38 @@ class Container(object):
def map_storage(fgraph, order, input_storage, output_storage):
"""Ensure there is storage (a length-1 list) for inputs, outputs, and interior nodes.
:param fgraph: The current fgraph. This function uses the inputs and outputs attributes.
:param order: an iterable over Apply instances (in program running order)
:param input_storage: None or existing input storage (see below)
:param output_storage: None or existing output storage (see below)
:rtype: 3-tuple
:returns: (list of storage for inputs, list of storage for outputs, and the `storage_map`)
"""
Ensure there is storage (a length-1 list) for inputs, outputs, and
interior nodes.
Parameters
----------
fgraph
The current fgraph. This function uses the inputs and outputs
attributes.
order
An iterable over Apply instances (in program running order).
input_storage
None or existing input storage (see below).
output_storage
None or existing output storage (see below).
Returns
-------
3-tuple
List of storage for inputs, list of storage for outputs, and
the `storage_map`.
Extended summary
----------------
This function iterates over the nodes in `order` and ensures that for every
input and output `Variable`, there is a unique storage container. This is
returned as a dictionary Variable->storage called the `storage_map`.
input and output `Variable`, there is a unique storage container. This is
returned as a dictionary Variable -> storage called the `storage_map`.
This function also returns `input_storage` which is a list of storages corresponding to fgraph.inputs.
This function also returns `output_storage` which is a list of storages corresponding to fgraph.outputs.
This function also returns `input_storage`, which is a list of storages
corresponding to fgraph.inputs.
This function also returns `output_storage`, which is a list of storages
corresponding to fgraph.outputs.
"""
# each Apply argument's data is stored in a list of length 1 (these lists act like pointers)
......@@ -531,23 +571,28 @@ def map_storage(fgraph, order, input_storage, output_storage):
def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
no_recycling=None, nice_errors=True):
"""WRITEME
:param fgraph:
:param thunks: the list of program instructions
:param order: the list of apply instances that gave rise to the thunks (same order as thunks)
:param post_thunk_old_storage: a list (corresponding to thunks, order) whose elements are
lists of storage cells, that should be cleared after running the corresponding thunk. A
value of None disables this functionality
"""
WRITEME
:param no_recycling: storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
Parameters
----------
fgraph
thunks
The list of program instructions.
order
The list of apply instances that gave rise to the thunks
(same order as thunks).
post_thunk_old_storage
A list (corresponding to thunks, order) whose elements are lists of
storage cells, that should be cleared after running the corresponding
thunk. A value of None disables this functionality.
no_recycling
Storage elements that cannot be 'recycled' by repeatedly executing the
program. These storage elements are cleared before re-running.
nice_errors
Run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
:param nice_errors: run in such a way that the double-traceback is printed. This costs a
bit of performance in the inner python loop.
"""
if no_recycling is None:
no_recycling = []
......@@ -597,9 +642,12 @@ def streamline(fgraph, thunks, order, post_thunk_old_storage=None,
class LocalLinker(Linker):
"""WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run a
thunk associated with each node.
"""
WRITEME
Useful base class for L{Linker}s which keep all nodes in the graph, and run
a thunk associated with each node.
"""
def make_thunk(self, input_storage=None, output_storage=None):
......@@ -621,16 +669,25 @@ class LocalLinker(Linker):
def gc_helper(node_list):
"""
:param node_list: list of Apply instances in program execution order
:rtype: a 2-tuple
:returns: FIRST, the set of Variable instances which are computed by node_list, and SECOND a
dictionary that maps each Variable instance to a the last node to use Variable as an input.
Parameters
----------
node_list
List of Apply instances in program execution order.
Returns
-------
2-tuple
FIRST, the set of Variable instances which are computed by node_list,
and SECOND a dictionary that maps each Variable instance to the last
node to use Variable as an input.
Extended Summary
----------------
This is used to allow garbage collection within graphs.
It ignore view_map and destroy_map. This isn't needed as python
have referecence count. In Theano gc, we should not take into
It ignores view_map and destroy_map. This isn't needed as Python
has reference counting. In Theano gc, we should not take into
account view_map and destroy_map as if the thunk decided to create
a new output, we would delay uselessly its gc by Python.
......@@ -647,10 +704,12 @@ def gc_helper(node_list):
class PerformLinker(LocalLinker):
"""WRITEME
"""
WRITEME
Basic L{Linker} subclass that calls the perform method on each L{Op} in
the L{FunctionGraph} in the order given by L{Linker.schedule}.
"""
def __init__(self, allow_gc=None, schedule=None):
......@@ -663,11 +722,20 @@ class PerformLinker(LocalLinker):
def accept(self, fgraph, no_recycling=None):
"""
:param fgraph: a PerformLinker can have accepted one FunctionGraph instance at a time.
:param no_recycling: WRITEME
Parameters
----------
fgraph
A PerformLinker can have accepted one FunctionGraph instance at a
time.
no_recycling
WRITEME
Returns
-------
object
self (TODO: WHY? Who calls this function?)
:returns: self (TODO: WHY? Who calls this function?)
"""
if no_recycling is None:
no_recycling = []
......@@ -680,10 +748,20 @@ class PerformLinker(LocalLinker):
def make_all(self, input_storage=None, output_storage=None):
"""
:param input_storage: WRITEME
:param output_storage: WRITEME
:returns: function to run all nodes, list of input containers, list of output containers, list of thunks (for all of program), list of nodes (for all of program)
Parameters
----------
input_storage
WRITEME
output_storage
WRITEME
Returns
-------
object
Function to run all nodes, list of input containers, list of output
containers, list of thunks (for all programs), list of nodes
(for all programs).
"""
fgraph = self.fgraph
......@@ -764,41 +842,39 @@ def add_clear_storage(f, computed, storage_map):
class WrapLinker(Linker):
"""
WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
A wrapper function must be provided, and it can be used to execute the
thunks, inspect the nodes, print stuff out, etc.
@note:
The outputs of the first linker will be returned.
@note:
This linker ensures that each linker has its own storage for
inputs and outputs and intermediate variables. There is no interference
between linkers.
"""
The constructor initializes a WrapLinker.
def __init__(self, linkers, wrapper):
"""
Initialize a WrapLinker.
@type linkers: list of L{LocalLinker} subclasses, whose make_all()
method returns thunks in the same order.
@param linkers: for each node in the graph, each linker will provide a
Parameters
----------
linkers : list of L{LocalLinker} subclasses, whose make_all() method returns
thunks in the same order.
For each node in the graph, each linker will provide a
thunk. This class makes it possible to iterate over each linker's
program in parallel.
wrapper : lambda (i, i_node, i_thunk1, i_thunk2, ...) : None
Does some user-defined action for the i'th element of the program.
i_thunk<n> is the thunk returned by the n'th linker. (If you want
to run the program, make sure to call the necessary thunks in this
function.)
Notes
-----
The outputs of the first linker will be returned.
@type wrapper: lambda (i, i_node, i_thunk1, i_thunk2, ...) : None
This linker ensures that each linker has its own storage for inputs and
outputs and intermediate variables. There is no interference between
linkers.
@param wrapper: do some user-defined action for the i'th element of the
program. i_thunk<n> is the thunk returned by the n'th linker. (If you
want to run the program, make sure to call the necessary thunks in this
function.)
"""
"""
def __init__(self, linkers, wrapper):
self.fgraph = None
self.linkers = linkers
self.wrapper = wrapper
......@@ -807,12 +883,16 @@ class WrapLinker(Linker):
"""
Shallow copy of a WrapLinker.
@returns: A copy of self, where each of the linkers in self.linkers
Returns
-------
object
A copy of self, where each of the linkers in self.linkers
have been shallow-copied.
It is useful because in FunctionMaker, copy.copy is called on the
Mode's linker, so that it is not modified inplace when linker.accept()
is called. In this case, we want the wrapped linkers to be copied too.
"""
other = self.__class__(
linkers=[copy(l) for l in self.linkers],
......@@ -826,14 +906,15 @@ class WrapLinker(Linker):
def accept(self, fgraph, no_recycling=None):
"""
@type fgraph: gof.FunctionGraph
@param fgraph: the fgraph which we will link
@type no_recycling: a list of Variables that belong to fgraph.
@param no_recycling: If a Variable is in no_recycling, L{WrapLinker} will clear
the output storage associated to it (for each linker in linkers) during
the computation to avoid reusing it.
Parameters
----------
fgraph : gof.FunctionGraph
The fgraph which we will link.
no_recycling : a list of Variables that belong to fgraph.
If a Variable is in no_recycling, L{WrapLinker} will clear
the output storage associated to it (for each linker in linkers)
during the computation to avoid reusing it.
"""
if no_recycling is None:
......@@ -905,6 +986,7 @@ def WrapLinkerMany(linkers, wrappers):
"""
Variant on WrapLinker that runs a series of wrapper functions instead of
just one.
"""
def wrapper(*args):
for f in wrappers:
......
......@@ -3,19 +3,21 @@ from theano.gof.type import Type
class NullType(Type):
"""
A type that allows no values.
A type that allows no values. Used to represent expressions
Used to represent expressions
that are undefined, either because they do not exist mathematically
or because the code to generate the expression has not been
implemented yet.
Parameters
----------
why_null : str
A string explaining why this variable can't take on any values.
"""
def __init__(self, why_null='(no explanation given)'):
"""
why_null: A string explaining why this variable
can't take on any values
"""
self.why_null = why_null
def filter(self, data, strict=False, allow_downcast=None):
......
"""Defines base classes `Op`, `PureOp`, and `CLinkerOp`
"""
Defines base classes `Op`, `PureOp`, and `CLinkerOp`.
The `Op` class is the base interface for all operations
compatible with `gof`'s :doc:`graph` routines.
"""
"""
import inspect
import logging
import numpy
......@@ -33,122 +34,159 @@ __docformat__ = "restructuredtext en"
class CLinkerObject(object):
"""Standard elements of an Op or Type used with the CLinker
"""
Standard elements of an Op or Type used with the CLinker.
"""
def c_headers(self):
"""Optional: Return a list of header files required by code returned by
"""
Optional: Return a list of header files required by code returned by
this class.
For example: return ['<iostream>', '<math.h>', '/full/path/to/header.h']
Examples
--------
return ['<iostream>', '<math.h>', '/full/path/to/header.h']
These strings will be prefixed with "#include " and inserted at the beginning of the c
source code.
These strings will be prefixed with "#include " and inserted at the
beginning of the c source code.
Strings in this list that start neither with '<' nor '"' will be enclosed in
double-quotes.
Strings in this list that start neither with '<' nor '"' will be
enclosed in double-quotes.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_headers", type(self), self.__class__.__name__)
def c_header_dirs(self):
"""Optional: Return a list of header search paths required by code returned by
this class.
"""
Optional: Return a list of header search paths required by code
returned by this class.
For example: return ['/usr/local/include', '/opt/weirdpath/src/include'].
Examples
--------
return ['/usr/local/include', '/opt/weirdpath/src/include']
Provide search paths for headers, in addition to those in any relevant environment
variables.
Provides search paths for headers, in addition to those in any relevant
environment variables.
Hint: for unix compilers, these are the things that get '-I' prefixed in the compiler
cmdline.
Hint: for unix compilers, these are the things that get '-I' prefixed
in the compiler cmdline.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_header_dirs", type(self), self.__class__.__name__)
def c_libraries(self):
"""Optional: Return a list of libraries required by code returned by
"""
Optional: Return a list of libraries required by code returned by
this class.
For example: return ['gsl', 'gslcblas', 'm', 'fftw3', 'g2c'].
Examples
--------
return ['gsl', 'gslcblas', 'm', 'fftw3', 'g2c'].
The compiler will search the directories specified by the environment
variable LD_LIBRARY_PATH in addition to any returned by `c_lib_dirs`.
Hint: for unix compilers, these are the things that get '-l' prefixed in the compiler
cmdline.
Hint: for unix compilers, these are the things that get '-l' prefixed
in the compiler cmdline.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_libraries", type(self), self.__class__.__name__)
def c_lib_dirs(self):
"""Optional: Return a list of library search paths required by code returned by
this class.
"""
Optional: Return a list of library search paths required by code
returned by this class.
For example: return ['/usr/local/lib', '/opt/weirdpath/build/libs'].
Examples
--------
return ['/usr/local/lib', '/opt/weirdpath/build/libs'].
Provide search paths for libraries, in addition to those in any relevant environment
variables (e.g. LD_LIBRARY_PATH).
Provides search paths for libraries, in addition to those in any
relevant environment variables (e.g. LD_LIBRARY_PATH).
Hint: for unix compilers, these are the things that get '-L' prefixed in the compiler
cmdline.
Hint: for unix compilers, these are the things that get '-L' prefixed
in the compiler cmdline.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_lib_dirs", type(self), self.__class__.__name__)
def c_support_code(self):
"""Optional: Return utility code for use by a `Variable` or `Op` to be
"""
Optional: Return utility code for use by a `Variable` or `Op` to be
included at global scope prior to the rest of the code for this class.
QUESTION: How many times will this support code be emitted for a graph
with many instances of the same type?
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_support_code", type(self), self.__class__.__name__)
def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Op.
"""
Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be cached between processes.
An empty tuple indicates an 'unversioned' Op that will not be cached
between processes.
The cache mechanism may erase cached modules that have been superceded by newer
versions. See `ModuleCache` for details.
The cache mechanism may erase cached modules that have been superseded
by newer versions. See `ModuleCache` for details.
See Also
--------
c_code_cache_version_apply()
:note: See also `c_code_cache_version_apply()`
"""
return ()
def c_compile_args(self):
"""Optional: Return a list of compile args recommended to compile the
"""
Optional: Return a list of compile args recommended to compile the
code returned by other methods in this class.
Example: return ['-ffast-math']
Example
-------
return ['-ffast-math']
Compiler arguments related to headers, libraries and search paths should be provided
via the functions `c_headers`, `c_libraries`, `c_header_dirs`, and `c_lib_dirs`.
Compiler arguments related to headers, libraries and search paths should
be provided via the functions `c_headers`, `c_libraries`,
`c_header_dirs`, and `c_lib_dirs`.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_compile_args", type(self), self.__class__.__name__)
def c_no_compile_args(self):
"""Optional: Return a list of incompatible gcc compiler arguments.
"""
Optional: return a list of incompatible gcc compiler arguments.
We will remove those arguments from the command line of gcc. So if
another Op adds a compile arg in the graph that is incompatible
......@@ -159,8 +197,10 @@ class CLinkerObject(object):
WRITEME
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
Raises
------
MethodNotDefined
The subclass does not override this method.
"""
raise utils.MethodNotDefined("c_no_compile_args", type(self), self.__class__.__name__)
......@@ -170,8 +210,11 @@ class CLinkerObject(object):
Optional: return a list of code snippets to be inserted in module
initialization.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
Raises
------
MethodNotDefined
The subclass does not override this method.
"""
raise utils.MethodNotDefined("c_init_code", type(self),
self.__class__.__name__)
......@@ -183,50 +226,56 @@ class CLinkerOp(CLinkerObject):
A subclass should implement WRITEME.
WRITEME: structure of automatically generated C code. Put this in doc/code_structure.txt
WRITEME: structure of automatically generated C code.
Put this in doc/code_structure.txt
"""
def c_code(self, node, name, inputs, outputs, sub):
"""Required: Return the C implementation of an Op.
"""
Required: return the C implementation of an Op.
Returns C code that does the computation associated to this `Op`,
given names for the inputs and outputs.
:Parameters:
`node` : Apply instance
The node for which we are compiling the current c_code.
Parameters
----------
node : Apply instance
The node for which we are compiling the current c_code.
The same Op may be used in more than one node.
`name` : A string
A name that is automatically assigned and guaranteed to be
unique.
`inputs` : list of strings
There is a string for each input of the function, and the
string is the name of a C variable pointing to that input.
The type of the variable depends on the declared type of
the input. There is a corresponding python variable that
can be accessed by prepending "py_" to the name in the
list.
`outputs` : list of strings
Each string is the name of a C variable where the Op should
store its output. The type depends on the declared type of
the output. There is a corresponding python variable that
can be accessed by prepending "py_" to the name in the
list. In some cases the outputs will be preallocated and
the value of the variable may be pre-filled. The value for
an unallocated output is type-dependent.
`sub` : dict of strings
extra symbols defined in `CLinker` sub symbols (such as 'fail').
WRITEME
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
name : str
A name that is automatically assigned and guaranteed to be
unique.
inputs : list of strings
There is a string for each input of the function, and the
string is the name of a C variable pointing to that input.
The type of the variable depends on the declared type of
the input. There is a corresponding python variable that
can be accessed by prepending "py_" to the name in the
list.
outputs : list of strings
Each string is the name of a C variable where the Op should
store its output. The type depends on the declared type of
the output. There is a corresponding python variable that
can be accessed by prepending "py_" to the name in the
list. In some cases the outputs will be preallocated and
the value of the variable may be pre-filled. The value for
an unallocated output is type-dependent.
sub : dict of strings
Extra symbols defined in `CLinker` sub symbols (such as 'fail').
WRITEME
Raises
------
MethodNotDefined
The subclass does not override this method.
"""
raise utils.MethodNotDefined('%s.c_code' % self.__class__.__name__)
def c_code_cache_version_apply(self, node):
"""Return a tuple of integers indicating the version of this Op.
"""
Return a tuple of integers indicating the version of this Op.
An empty tuple indicates an 'unversioned' Op that will not be
cached between processes.
......@@ -234,69 +283,82 @@ class CLinkerOp(CLinkerObject):
The cache mechanism may erase cached modules that have been
superseded by newer versions. See `ModuleCache` for details.
:note: See also `c_code_cache_version()`
See Also
--------
c_code_cache_version()
Notes
-----
This function overrides `c_code_cache_version` unless it explicitly
calls `c_code_cache_version`. The default implementation simply
calls `c_code_cache_version` and ignores the `node` argument.
:note: This function overrides `c_code_cache_version` unless
it explicitly calls `c_code_cache_version`. The
default implementation simply calls
`c_code_cache_version` and ignores the `node` argument.
"""
return self.c_code_cache_version()
def c_code_cleanup(self, node, name, inputs, outputs, sub):
"""
Optional: Return C code to run after c_code, whether it failed
or not.
Optional: return C code to run after c_code, whether it failed or not.
This is a convenient place to clean up things allocated by c_code().
:Parameters:
`node` : Apply instance
WRITEME
`name` : A string
A name that is automatically assigned and guaranteed to be
unique.
`inputs` : list of strings
There is a string for each input of the function, and the
string is the name of a C variable pointing to that input.
The type of the variable depends on the declared type of
the input. There is a corresponding python variable that
can be accessed by prepending "py_" to the name in the
list.
`outputs` : list of strings
Each string is the name of a C variable corresponding to
one of the outputs of the Op. The type depends on the
declared type of the output. There is a corresponding
python variable that can be accessed by prepending "py_" to
the name in the list.
`sub` : dict of strings
extra symbols defined in `CLinker` sub symbols (such as 'fail').
WRITEME
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
Parameters
----------
node : Apply instance
WRITEME
name : str
A name that is automatically assigned and guaranteed to be
unique.
inputs : list of strings
There is a string for each input of the function, and the
string is the name of a C variable pointing to that input.
The type of the variable depends on the declared type of
the input. There is a corresponding python variable that
can be accessed by prepending "py_" to the name in the
list.
outputs : list of strings
Each string is the name of a C variable corresponding to
one of the outputs of the Op. The type depends on the
declared type of the output. There is a corresponding
python variable that can be accessed by prepending "py_" to
the name in the list.
sub : dict of strings
extra symbols defined in `CLinker` sub symbols (such as 'fail').
WRITEME
Raises
------
MethodNotDefined
The subclass does not override this method.
"""
raise utils.MethodNotDefined('%s.c_code_cleanup' %
self.__class__.__name__)
def c_support_code_apply(self, node, name):
"""Optional: Return utility code for use by an `Op` that will be
"""
Optional: return utility code for use by an `Op` that will be
inserted at global scope, that can be specialized for the
support of a particular `Apply` node.
:param node: an Apply instance in the graph being compiled
:param name: a string or number that serves to uniquely
identify this node. Symbol names defined by this
support code should include the name, so that
they can be called from the c_code, and so that
they do not cause name collisions.
:note: This function is called in addition to c_support_code
and will supplement whatever is returned from there.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Parameters
----------
node: an Apply instance in the graph being compiled
name: str
A string or number that serves to uniquely identify this node.
Symbol names defined by this support code should include the name,
so that they can be called from the c_code, and so that they do not
cause name collisions.
Notes
-----
This function is called in addition to c_support_code and will
supplement whatever is returned from there.
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_support_code_apply",
......@@ -307,19 +369,25 @@ class CLinkerOp(CLinkerObject):
Optional: return a code string specific to the apply
to be inserted in the module initialization code.
:param node: an Apply instance in the graph being compiled
Parameters
----------
node : an Apply instance in the graph being compiled
name : str
A string or number that serves to uniquely identify this node.
Symbol names defined by this support code should include the name,
so that they can be called from the c_code, and so that they do not
cause name collisions.
Notes
-----
This function is called in addition to c_init_code and will supplement
whatever is returned from there.
Raises
------
MethodNotDefined
The subclass does not override this method.
:param name: a string or number that serves to uniquely
identify this node. Symbol names defined by this
support code should include the name, so that
they can be called from the c_code, and so that
they do not cause name collisions.
:note: This function is called in addition to c_init_code
and will supplement whatever is returned from there.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_init_code_apply", type(self),
self.__class__.__name__)
......@@ -329,34 +397,42 @@ class CLinkerOp(CLinkerObject):
Optional: return a code string specific to the apply
to be inserted in the struct initialization code.
:param node: an Apply instance in the graph being compiled
:param name: a unique name to distinguish your variables from
those of other nodes.
:param sub: a dictionary of values to substitute in the code.
Most notably it contains a 'fail' entry that you
should place in your code after setting a python
exception to indicate an error.
Parameters
----------
node : an Apply instance in the graph being compiled
name : str
A unique name to distinguish variables from those of other nodes.
sub
A dictionary of values to substitute in the code.
Most notably it contains a 'fail' entry that you should place in
your code after setting a python exception to indicate an error.
Raises
------
MethodNotDefined
The subclass does not override this method.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_init_code_apply", type(self),
self.__class__.__name__)
def c_support_code_struct(self, node, name):
"""Optional: Return utility code for use by an `Op` that will be
"""
Optional: return utility code for use by an `Op` that will be
inserted at struct scope, that can be specialized for the
support of a particular `Apply` node.
:param node: an Apply instance in the graph being compiled
:param name: a unique name to distinguish your variables from
those of other nodes.
Parameters
----------
node : an Apply instance in the graph being compiled
name : str
A unique name to distinguish your variables from those of other
nodes.
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise utils.MethodNotDefined("c_support_code_struct",
......@@ -367,13 +443,17 @@ class CLinkerOp(CLinkerObject):
Optional: return a code string specific to the apply to be
inserted in the struct cleanup code.
:param node: an Apply instance in the graph being compiled
Parameters
----------
node : an Apply instance in the graph being compiled
name : str
A unique name to distinguish variables from those of other nodes.
:param name: a unique name to distinguish your variables from
those of other nodes.
Raises
------
MethodNotDefined
The subclass does not override this method.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
"""
raise utils.MethodNotDefined("c_cleanup_code_struct", type(self),
self.__class__.__name__)
......@@ -383,31 +463,33 @@ class PureOp(object):
"""
An :term:`Op` is a type of operation.
`Op` is an abstract class that documents the interface for theano's data transformations.
It has many subclasses, such as
`Op` is an abstract class that documents the interface for theano's data
transformations. It has many subclasses, such as
`sparse dot <http://pylearn.org/epydoc/theano.sparse.Dot-class.html>`__,
and `Shape <http://pylearn.org/epydoc/theano.tensor.Shape-class.html>`__.
These subclasses are meant to be instantiated.
An instance has several responsibilities:
- making `Apply` instances, which mean "apply this type of operation to some particular inputs" (via `make_node`),
- making `Apply` instances, which mean "apply this type of operation to some
particular inputs" (via `make_node`),
- performing the calculation of outputs from given inputs (via the `perform`),
- performing the calculation of outputs from given inputs
(via the `perform`),
- [optionally] building gradient-calculating graphs (via `grad`).
To see how `Op`, `Type`, `Variable`, and `Apply` fit together see the page
on :doc:`graph`.
To see how `Op`, `Type`, `Variable`, and `Apply` fit together see the page on :doc:`graph`.
For more specifications on how these methods should behave: see the `Op Contract` in the
sphinx docs (advanced tutorial on Op-making).
For more specifications on how these methods should behave: see the
`Op Contract` in the sphinx docs (advanced tutorial on Op-making).
"""
default_output = None
"""
configuration variable for `__call__`
Configuration variable for `__call__`.
A subclass should not change this class variable, but instead over-ride it with a subclass
variable or an instance variable.
......@@ -425,8 +507,9 @@ class PureOp(object):
All subclasses should over-ride this function.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
Raises
------
MethodNotDefined : the subclass does not override this method.
"""
raise utils.MethodNotDefined("make_node", type(self), self.__class__.__name__)
......@@ -434,11 +517,13 @@ class PureOp(object):
@classmethod
def _get_test_value(cls, v):
"""
Extract test value from variable v. Raises AttributeError if there is none.
Extract test value from variable v.
Raises AttributeError if there is none.
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
......@@ -481,7 +566,8 @@ class PureOp(object):
raise AttributeError('%s has no test value' % v)
def __call__(self, *inputs, **kwargs):
"""Optional: Return some or all output[s] of `make_node`.
"""
Optional: return some or all output[s] of `make_node`.
It is called by code such as:
......@@ -489,21 +575,26 @@ class PureOp(object):
x = tensor.matrix()
# tensor.exp is an Op instance, calls Op.__call__(self=<instance of exp>, inputs=(x,))
# tensor.exp is an Op instance, calls
# Op.__call__(self=<instance of exp>, inputs=(x,))
y = tensor.exp(x)
This class implements a convenience function (for graph-building) which uses
`default_output`, but subclasses are free to override this function and ignore
`default_output`.
:param inputs: The Op's inputs, forwarded to the call to `make_node()`.
This class implements a convenience function (for graph-building) which
uses `default_output`, but subclasses are free to override this function
and ignore `default_output`.
:param kwargs: Additional keyword arguments to be forwarded to
Parameters
----------
inputs
The Op's inputs, forwarded to the call to `make_node()`.
kwargs
Additional keyword arguments to be forwarded to
`make_node()` *except* for optional argument `return_list` (which
defaults to False). If `return_list` is True, then the returned
value is always a list. Otherwise it is either a single Variable
when the output of `make_node()` contains a single element, or this
output (unchanged) when it contains multiple elements.
"""
return_list = kwargs.pop('return_list', False)
node = self.make_node(*inputs, **kwargs)
......@@ -594,25 +685,26 @@ class PureOp(object):
def R_op(self, inputs, eval_points):
"""
This method is primarily used by tensor.Rop
Suppose the op outputs
[ f_1(inputs), ..., f_n(inputs) ]
inputs: a Variable or list of Variables
eval_points: a Variable or list of Variables with
the same length as inputs. Each element
of eval_points specifies the value of
the corresponding input at the point
where the R op is to be evaluated.
returns: a list of n elements
rval[i] should be Rop(f=f_i(inputs),
wrt=inputs,
eval_points=eval_points)
Parameters
----------
inputs : a Variable or list of Variables
eval_points
A Variable or list of Variables with the same length as inputs.
Each element of eval_points specifies the value of the corresponding
input at the point where the R op is to be evaluated.
Returns
-------
list of n elements
rval[i] should be Rop(f=f_i(inputs),
wrt=inputs,
eval_points=eval_points)
"""
raise NotImplementedError(
......@@ -624,17 +716,21 @@ class PureOp(object):
def perform(self, node, inputs, output_storage):
"""
Required: Calculate the function on the inputs and put the variables in the
output storage. Return None.
:Parameters:
`node` : Apply instance
contains the symbolic inputs and outputs
`inputs` : list
sequence of inputs (immutable)
`output_storage` : list
list of mutable 1-element lists (do not change the length of these lists)
Required: Calculate the function on the inputs and put the variables in
the output storage. Return None.
Parameters
----------
node : Apply instance
Contains the symbolic inputs and outputs.
inputs : list
Sequence of inputs (immutable).
output_storage : list
List of mutable 1-element lists (do not change the length of
these lists)
Notes
-----
The `output_storage` list might contain data. If an element of
output_storage is not None, it has to be of the right type,
for instance, for a TensorVariable, it has to be a Numpy ndarray,
......@@ -644,8 +740,10 @@ class PureOp(object):
could be allocated by another Op impl is free to reuse it as it
sees fit, or to discard it and allocate new memory.
:Exceptions:
- `MethodNotDefined`: the subclass does not override this method
Raises
------
MethodNotDefined
The subclass does not override this method.
"""
raise utils.MethodNotDefined("perform", type(self), self.__class__.__name__)
......@@ -657,12 +755,16 @@ class PureOp(object):
choose where it puts its memory/speed trade-off. Also, it
could make things faster as constants can't be used for inplace
operations (see *IncSubtensor).
"""
return True
class Op(utils.object2, PureOp, CLinkerOp):
"""Convenience class to bundle `PureOp` and `CLinkerOp`"""
"""
Convenience class to bundle `PureOp` and `CLinkerOp`.
"""
def __new__(cls, *args, **kwargs):
# this function exists to silently and transparently ensure that all
# existing Ops get a _op_use_c_code attribute
......@@ -704,6 +806,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
def make_c_thunk(self, node, storage_map, compute_map, no_recycling):
"""
Like make_thunk, but will only try to make a C thunk.
"""
logger = logging.getLogger('theano.gof.op.Op')
......@@ -747,6 +850,7 @@ class Op(utils.object2, PureOp, CLinkerOp):
def make_py_thunk(self, node, storage_map, compute_map, no_recycling):
"""
Like make_thunk() but only makes python thunks.
"""
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
......@@ -780,24 +884,31 @@ class Op(utils.object2, PureOp, CLinkerOp):
def make_thunk(self, node, storage_map, compute_map, no_recycling):
"""
:param node: something previously returned by self.make_node
:param storage_map: dict variable -> one-element-list where a computed
value for this variable may be found.
:param compute_map: dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
Parameters
----------
node
Something previously returned by self.make_node.
storage_map
dict variable -> one-element-list where a computed
value for this variable may be found.
compute_map
dict variable -> one-element-list where a boolean
value will be found. The boolean indicates whether the
variable's storage_map container contains a valid value (True)
or if it has not been computed yet (False).
no_recycling
List of variables for which it is forbidden to reuse memory
allocated by a previous call.
Notes
-----
If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
:param no_recycling: list of variables for which it is forbidden to
reuse memory allocated by a previous call.
:note: If the thunk consults the storage_map on every call, it is safe
for it to ignore the no_recycling argument, because elements of the
no_recycling list will have a value of None in the storage map. If
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
logger = logging.getLogger('theano.gof.op.Op')
......@@ -823,6 +934,7 @@ def get_test_value(v):
For a Constant, the test value is v.value.
For a Shared variable, it is the internal value.
For another Variable, it is the content of v.tag.test_value.
"""
if not isinstance(v, graph.Variable):
v_var = theano.tensor.as_tensor_variable(v)
......@@ -832,14 +944,20 @@ def get_test_value(v):
def missing_test_message(msg):
""" Displays msg, a message saying that some test_value is missing,
"""
Displays msg, a message saying that some test_value is missing,
in the appropriate form based on config.compute_test_value:
off: the interactive debugger is off, so we do nothing
ignore: the interactive debugger is set to ignore missing inputs,
so do nothing
warn: display msg as a warning
raise: raise an AttributeError with msg as the exception text
off: The interactive debugger is off, so we do nothing.
ignore: The interactive debugger is set to ignore missing inputs,
so do nothing.
warn: Display msg as a warning.
Raises
------
AttributeError
With msg as the exception text.
"""
action = config.compute_test_value
if action == 'raise':
......@@ -851,10 +969,12 @@ def missing_test_message(msg):
def debug_error_message(msg):
""" Displays a message saying that an error was found in some
"""
Displays a message saying that an error was found in some
test_values. Becomes a warning or a ValueError depending on
config.compute_test_value"""
config.compute_test_value.
"""
action = config.compute_test_value
# this message should never be called when the debugger is off
......@@ -900,7 +1020,7 @@ def get_debug_values(*args):
3. If the interactive debugger is on, and some variable does
not have a debug value, issue a missing_test_message about
the variable, and, if still in control of execution, return
an empty list
an empty list.
"""
......@@ -938,11 +1058,13 @@ self.fn, the value will be 'fn'.
We need that to be able not to run debug checks a number of times that is
exponential in the nesting level of those ops.
For instance, Scan will be registered here.
"""
class OpenMPOp(Op):
"""All op using OpenMP code should inherit from this Op.
"""
All op using OpenMP code should inherit from this Op.
This op will check that the compiler support correctly OpenMP code.
If not, it will print a warning and disable openmp for this Op.
......@@ -954,9 +1076,11 @@ class OpenMPOp(Op):
We also add the correct compiler flags in c_compile_args.
"""
gxx_support_openmp = None
"""
True/False after we tested this.
"""
def __init__(self, openmp=None):
......@@ -1004,7 +1128,8 @@ int main( int argc, const char* argv[] )
def update_self_openmp(self):
"""
Make sure self.openmp is not True if there is no support in gxx
Make sure self.openmp is not True if there is no support in gxx.
"""
if self.openmp:
if OpenMPOp.gxx_support_openmp is None:
......@@ -1056,13 +1181,16 @@ def apply_meth(tag):
class COp(Op):
""" Class to allow an op to have an external C implementation.
"""
Class to allow an op to have an external C implementation.
An op can use this class by inheriting from it and calling its
__init__() method, providing it with a path to an external file containing
the C implementation and the name of the function, in that file, to call
to perform the computations for the op.
"""
section_re = re.compile(r'^#section ([a-zA-Z0-9_]+)$', re.MULTILINE)
backward_re = re.compile(r'^THEANO_(APPLY|SUPPORT)_CODE_SECTION$', re.MULTILINE)
# This is the set of allowed markers
......@@ -1078,6 +1206,7 @@ class COp(Op):
Convert a path relative to the location of the class file into
an absolute path. Paths that are already absolute are passed
through unchanged.
"""
if not os.path.isabs(f):
class_file = inspect.getfile(cls)
......@@ -1089,6 +1218,7 @@ class COp(Op):
"""
Sections are loaded from files in order with sections in later
files overriding sections in previous files.
"""
if not isinstance(func_files, list):
func_files = [func_files]
......@@ -1181,6 +1311,7 @@ class COp(Op):
The names must be strings that are not a C keyword and the
values must be strings of literal C representations.
"""
return []
......
"""
Defines the base class for optimizations as well as a certain
amount of useful generic optimization tools.
"""
from __future__ import print_function
......@@ -35,10 +36,13 @@ def _list_of_nodes(fgraph):
class Optimizer(object):
"""WRITEME
"""
WRITEME
An L{Optimizer} can be applied to an L{FunctionGraph} to transform it.
It can represent an optimization or in general any kind
of transformation you could apply to an L{FunctionGraph}.
"""
def __hash__(self):
......@@ -58,19 +62,25 @@ class Optimizer(object):
return id(self) != id(other)
def apply(self, fgraph):
"""WRITEME
"""
WRITEME
Applies the optimization to the provided L{FunctionGraph}. It may
use all the methods defined by the L{FunctionGraph}. If the
L{Optimizer} needs to use a certain tool, such as an
L{InstanceFinder}, it can do so in its L{add_requirements} method.
"""
pass
def optimize(self, fgraph, *args, **kwargs):
"""WRITEME
This is meant as a shortcut to::
"""
WRITEME
This is meant as a shortcut to:
opt.add_requirements(fgraph)
opt.apply(fgraph)
"""
self.add_requirements(fgraph)
try:
......@@ -82,18 +92,24 @@ class Optimizer(object):
return ret
def __call__(self, fgraph):
"""WRITEME
Same as self.optimize(fgraph)
"""
WRITEME
Same as self.optimize(fgraph).
"""
return self.optimize(fgraph)
def add_requirements(self, fgraph):
"""WRITEME
"""
WRITEME
Add features to the fgraph that are required to apply the optimization.
For example:
fgraph.attach_feature(History())
fgraph.attach_feature(MyFeature())
etc.
"""
pass
......@@ -111,7 +127,10 @@ class Optimizer(object):
class FromFunctionOptimizer(Optimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, fn, requirements=()):
self.apply = fn
self.requirements = requirements
......@@ -134,14 +153,20 @@ class FromFunctionOptimizer(Optimizer):
def optimizer(f):
"""decorator for FromFunctionOptimizer"""
"""
Decorator for FromFunctionOptimizer.
"""
rval = FromFunctionOptimizer(f)
rval.__name__ = f.__name__
return rval
def inplace_optimizer(f):
"""decorator for FromFunctionOptimizer"""
"""
Decorator for FromFunctionOptimizer.
"""
dh_handler = dh.DestroyHandler
requirements = (lambda fgraph:
fgraph.attach_feature(dh_handler()),)
......@@ -152,13 +177,18 @@ def inplace_optimizer(f):
class SeqOptimizer(Optimizer, list):
# inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
"""
WRITEME
Takes a list of L{Optimizer} instances and applies them
sequentially.
"""
@staticmethod
def warn(exc, self, optimizer):
"""Default failure_callback for SeqOptimizer
"""
Default failure_callback for SeqOptimizer.
"""
_logger.error("SeqOptimizer apply %s" % str(optimizer))
_logger.error("Traceback:")
......@@ -169,15 +199,21 @@ class SeqOptimizer(Optimizer, list):
pdb.post_mortem(sys.exc_info()[2])
def __init__(self, *opts, **kw):
"""WRITEME"""
"""
WRITEME
"""
if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0]
self[:] = opts
self.failure_callback = kw.pop('failure_callback', None)
def apply(self, fgraph):
"""WRITEME
"""
WRITEME
Applies each L{Optimizer} in self in turn.
"""
l = []
if fgraph.profile:
......@@ -286,6 +322,7 @@ class SeqOptimizer(Optimizer, list):
def merge_profile(prof1, prof2):
"""
Merge two profiles returned by this class's apply() function.
"""
new_t = []
new_l = []
......@@ -354,7 +391,11 @@ class SeqOptimizer(Optimizer, list):
class _metadict:
"""WRITEME"""
"""
WRITEME
"""
# dict that accepts unhashable keys
# uses an associative list
# for internal use only
......@@ -430,6 +471,7 @@ class MergeFeature(object):
That way, the MergeOptimizer can remember the result of the last merge
pass on the fgraph.
"""
def on_attach(self, fgraph):
assert not hasattr(fgraph, 'merge_feature')
......@@ -493,7 +535,10 @@ class MergeFeature(object):
self.seen_constants.discard(id(c))
def process_constant(self, fgraph, c):
"""Check if a constant can be merged, and queue that replacement"""
"""
Check if a constant can be merged, and queue that replacement.
"""
if id(c) in self.seen_constants:
return
sig = c.merge_signature()
......@@ -511,7 +556,10 @@ class MergeFeature(object):
self.seen_constants.add(id(c))
def process_node(self, fgraph, node):
"""Check if a node can be merged, and queue that replacement."""
"""
Check if a node can be merged, and queue that replacement.
"""
if node in self.nodes_seen:
return
......@@ -667,6 +715,7 @@ class MergeOptimizer(Optimizer):
The first step of merging is constant-merging, so that all clients of an
int(1) for example, are transferred to a particular instance of int(1).
"""
def add_requirements(self, fgraph):
......@@ -807,6 +856,7 @@ def is_same_graph_with_merge(var1, var2, givens=None):
Merge-based implementation of `theano.gof.graph.is_same_graph`.
See help on `theano.gof.graph.is_same_graph` for additional documentation.
"""
if givens is None:
givens = {}
......@@ -847,13 +897,15 @@ def pre_constant_merge(vars):
`vars` is a list of nodes, and we want to merge together nodes
that are constant inputs used to compute nodes in that list.
:note: This function will ignore nodes that are in an fgraph.
It is used to pre-merge nodes generated inside an optimization,
before it is inserted in the fgraph.
It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
"""
Notes
-----
This function will ignore nodes that are in an fgraph.
It is used to pre-merge nodes generated inside an optimization,
before it is inserted in the fgraph.
It is useful if there are many such replacements to make,
so that DebugMode will not check each of them.
"""
seen_var = set()
# signature -> variable (for constants)
const_sig_inv = {}
......@@ -896,10 +948,12 @@ def pre_constant_merge(vars):
########################
class LocalOptimizer(object):
"""A class for node-based optimizations.
"""
A class for node-based optimizations.
Instances should implement the transform function,
and be passed to configure a fgraph-based Optimizer instance.
"""
def __hash__(self):
......@@ -913,11 +967,13 @@ class LocalOptimizer(object):
Return the list of op classes that this opt applies to.
Return None to apply to all nodes.
"""
return None
def transform(self, node):
"""Transform a subgraph whose output is `node`.
"""
Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
......@@ -929,7 +985,9 @@ class LocalOptimizer(object):
- dict(old variables -> new variables). A dictionary that map
from old variables to new variables to replace.
:type node: an Apply instance
Parameters
----------
node : an Apply instance
"""
......@@ -939,8 +997,8 @@ class LocalOptimizer(object):
def add_requirements(self, fgraph):
"""
If this local optimization wants to add some requirements to the
fgraph,
This is the place to do it.
fgraph, this is the place to do it.
"""
# Added by default
# fgraph.attach_feature(toolbox.ReplaceValidate())
......@@ -959,8 +1017,11 @@ theano.configparser.AddConfigVar(
class LocalMetaOptimizer(LocalOptimizer):
"""Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest"""
"""
Base class for meta-optimizers that try a set of LocalOptimizers
to replace a node and choose the one that executes the fastest.
"""
def __init__(self, tracks=None, optimizers=()):
self._tracks = tracks
......@@ -1036,9 +1097,12 @@ class LocalMetaOptimizer(LocalOptimizer):
return
def provide_inputs(self, node, inputs):
"""If implemented, returns a dictionary mapping all symbolic variables
in ``inputs`` to SharedVariable instances of suitable dummy values. The
``node`` can be inspected to infer required input shapes."""
"""
If implemented, returns a dictionary mapping all symbolic variables
in ``inputs`` to SharedVariable instances of suitable dummy values.
The ``node`` can be inspected to infer required input shapes.
"""
raise NotImplementedError()
def time_call(self, fn):
......@@ -1048,7 +1112,10 @@ class LocalMetaOptimizer(LocalOptimizer):
class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, fn, tracks=None, requirements=()):
self.transform = fn
self._tracks = tracks
......@@ -1074,7 +1141,10 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def local_optimizer(tracks, inplace=False):
def decorator(f):
"""WRITEME"""
"""
WRITEME
"""
if tracks is not None:
if len(tracks) is 0:
raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
......@@ -1093,7 +1163,10 @@ def local_optimizer(tracks, inplace=False):
class LocalOptGroup(LocalOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, *optimizers):
if len(optimizers) == 1 and isinstance(optimizers[0], list):
......@@ -1138,12 +1211,23 @@ class LocalOptGroup(LocalOptimizer):
class OpSub(LocalOptimizer):
"""WRITEME
"""
WRITEME
Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing.
another op that takes the same inputs as what they are replacing.
e.g. OpSub(add, sub) ==>
Parameters
----------
op1, op2
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
Examples
--------
OpSub(add, sub) ==>
add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
"""
# an OpSub does not apply to the nodes it produces
......@@ -1152,10 +1236,6 @@ class OpSub(LocalOptimizer):
retains_inputs = True
def __init__(self, op1, op2, transfer_tags=True):
"""
op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
"""
self.op1 = op1
self.op2 = op2
self.transfer_tags = transfer_tags
......@@ -1181,9 +1261,12 @@ class OpSub(LocalOptimizer):
class OpRemove(LocalOptimizer):
"""WRITEME
"""
WRITEME
Removes all applications of an op by transferring each of its
outputs to the corresponding input.
"""
reentrant = False # no nodes are added at all
......@@ -1214,25 +1297,27 @@ class OpRemove(LocalOptimizer):
class PatternSub(LocalOptimizer):
"""WRITEME
"""
WRITEME
@todo update
Replaces all occurrences of the input pattern by the output pattern:
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
input_pattern ::= dict(pattern = <input_pattern>,
constraint = <constraint>)
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda fgraph, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
sub_pattern ::= input_pattern
sub_pattern ::= string
sub_pattern ::= a Constant instance
sub_pattern ::= int
sub_pattern ::= float
constraint ::= lambda fgraph, expr: additional matching condition
output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
output_pattern ::= string
output_pattern ::= int
output_pattern ::= float
Each string in the input pattern is a variable that will be set to
whatever expression is found in its place. If the same string is
......@@ -1252,45 +1337,51 @@ class PatternSub(LocalOptimizer):
trying to match and returns True or False according to an
arbitrary criterion.
Examples:
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
The constructor creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
Parameters
----------
in_pattern
The input pattern that we want to replace.
out_pattern
The replacement pattern.
allow_multiple_clients : bool
If False, the pattern matching will fail if one of the subpatterns has
more than one client.
skip_identities_fn : TODO
name
Allows overriding this optimizer's name.
pdb : bool
If True, we invoke pdb when the first node in the pattern matches.
tracks : optional
The values that self.tracks() will return. Useful to speed up
optimization sometimes.
get_nodes : optional
If you provide `tracks`, you must provide this parameter. It must be a
function that takes the tracked node and returns a list of nodes on
which we will try this optimizer.
Notes
-----
`tracks` and `get_nodes` can be used to make this optimizer track a less
frequent Op, so this will make this optimizer tried less frequently.
Examples
--------
PatternSub((add, 'x', 'y'), (add, 'y', 'x'))
PatternSub((multiply, 'x', 'x'), (square, 'x'))
PatternSub((subtract, (add, 'x', 'y'), 'y'), 'x')
PatternSub((power, 'x', Constant(double, 2.0)), (square, 'x'))
PatternSub((boggle, {'pattern': 'x',
'constraint': lambda expr: expr.type == scrabble}),
(scrabble, 'x'))
"""
def __init__(self, in_pattern, out_pattern,
allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False,
tracks=(), get_nodes=None):
"""
Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern.
:param in_pattern: the input pattern that we want to replace
:param out_pattern: the replacement pattern
:param allow_multiple_clients: if False, the pattern matching will fail
if one of the subpatterns has more than
one client.
:param skip_identities_fn: TODO
:param name: Allow to override this optimizer name
:param pdb: if True, we invoke pdb when the first node in the
pattern match.
:param tracks: Optional. The values that self.tracks() will
return. Useful to speed up optimization some times.
:param get_nodes: Optional. If you provide `tracks`, you must
provide this parameter. It must be a function that take the
tracked node and return a list of node on which we will try
this optimizer.
`tracks` and `get_nodes` can be used to make this optimizer
track a less frequent Op, so this will make this optimizer
tried less frequently,
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
if isinstance(in_pattern, (list, tuple)):
......@@ -1325,6 +1416,7 @@ class PatternSub(LocalOptimizer):
"""
Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
"""
if get_nodes and self.get_nodes is not None:
for real_node in self.get_nodes(node):
......@@ -1486,12 +1578,40 @@ class Updater:
class NavigatorOptimizer(Optimizer):
"""Abstract class
"""
Abstract class.
Parameters
----------
local_opt
A LocalOptimizer to apply over a FunctionGraph (or None is Ok too).
ignore_newtrees
- True: new subgraphs returned by an optimization are not
candidates for optimization.
- False: new subgraphs returned by an optimization are candidates
for optimization.
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
failure_callback
A function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
@staticmethod
def warn(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: print traceback
"""
Failure_callback for NavigatorOptimizer: print traceback.
"""
if config.on_opt_error != 'ignore':
_logger.error("Optimization failure due to: %s" % str(local_opt))
......@@ -1506,9 +1626,11 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
def warn_inplace(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer
"""
Failure_callback for NavigatorOptimizer.
Ignore InconsistencyErrors, print traceback.
ignore InconsistencyErrors, print traceback
"""
if isinstance(exc, InconsistencyError):
return
......@@ -1516,36 +1638,14 @@ class NavigatorOptimizer(Optimizer):
@staticmethod
def warn_ignore(exc, nav, repl_pairs, local_opt):
"""failure_callback for NavigatorOptimizer: ignore all errors
"""
Failure_callback for NavigatorOptimizer: ignore all errors.
"""
pass
def __init__(self, local_opt, ignore_newtrees='auto',
failure_callback=None):
"""
:param local_opt: a LocalOptimizer to apply over a FunctionGraph
(or None is Ok too).
:param ignore_newtrees:
- True: new subgraphs returned by an optimization is not a
candidate for optimization
- False: new subgraphs returned by an optimization is a candidate
for optimization
- 'auto': let the local_opt set this parameter via its 'reentrant'
attribute.
:param failure_callback:
a function that takes (exception, navigator, [(old, new),
(old,new),...]) and we call it if there's an exception.
If the trouble is from local_opt.transform(), the new variables
will be 'None'.
If the trouble is from validation (the new types don't match for
example) then the new variables will be the ones created by
transform().
If this parameter is None, then exceptions are not caught here
(raised normally).
"""
self.local_opt = local_opt
if ignore_newtrees == 'auto':
self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
......@@ -1558,14 +1658,23 @@ class NavigatorOptimizer(Optimizer):
Install some FunctionGraph listeners to help the navigator deal with
the ignore_trees-related functionality.
:param importer: function that will be called whenever when
optimizations add stuff to the graph.
:param pruner: function to be called when optimizations remove stuff
from graph.
:param chin: "on change input" called whenever an node's inputs change.
:returns: The FunctionGraph plugin that handles the three tasks.
Parameters
----------
importer
Function that will be called whenever optimizations add stuff
to the graph.
pruner
Function to be called when optimizations remove stuff
from the graph.
chin
"on change input" called whenever a node's inputs change.
Returns
-------
object
The FunctionGraph plugin that handles the three tasks.
Keep this around so that you can detach later!
"""
if self.ignore_newtrees:
importer = None
......@@ -1578,18 +1687,25 @@ class NavigatorOptimizer(Optimizer):
return u
def detach_updater(self, fgraph, u):
"""Undo the work of attach_updater.
"""
Undo the work of attach_updater.
Parameters
----------
u
A return-value of attach_updater.
:param u: a return-value of attach_updater
Returns
-------
None
:returns: None.
"""
if u is not None:
fgraph.remove_feature(u)
def process_node(self, fgraph, node, lopt=None):
"""
This function will use `lopt` to `transform` the `node`. The
This function will use `lopt` to `transform` the `node`. The
`transform` method will return either False or a list of Variables
that are intended to replace `node.outputs`.
......@@ -1599,12 +1715,20 @@ class NavigatorOptimizer(Optimizer):
If there are no replacement candidates or the fgraph rejects the
replacements, this function returns False.
:param fgraph: a FunctionGraph
:param node: an Apply instance in `fgraph`
:param lopt: a LocalOptimizer instance that may have a better idea for
Parameters
----------
fgraph
A FunctionGraph.
node
An Apply instance in `fgraph`
lopt
A LocalOptimizer instance that may have a better idea for
how to compute node's outputs.
:rtype: Bool
:returns: True iff the `node`'s outputs were replaced in the `fgraph`.
Returns
-------
bool
True iff the `node`'s outputs were replaced in the `fgraph`.
"""
lopt = lopt or self.local_opt
......@@ -1673,7 +1797,10 @@ class NavigatorOptimizer(Optimizer):
class TopoOptimizer(NavigatorOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
failure_callback=None):
......@@ -1746,7 +1873,10 @@ class TopoOptimizer(NavigatorOptimizer):
class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME"""
"""
WRITEME
"""
def __init__(self, local_opt, ignore_newtrees=False,
failure_callback=None):
......@@ -1790,6 +1920,7 @@ class OpKeyOptimizer(NavigatorOptimizer):
Requires the following features:
- NodeFinder
- ReplaceValidate(Added by default)
"""
super(OpKeyOptimizer, self).add_requirements(fgraph)
fgraph.attach_feature(toolbox.NodeFinder())
......@@ -1815,24 +1946,27 @@ class ChangeTracker:
class EquilibriumOptimizer(NavigatorOptimizer):
"""
Apply optimizations until equilibrium point.
Parameters
----------
optimizers
List or set of local or global optimizations to apply until equilibrium.
max_use_ratio
Each optimizer can be applied at most (size of graph * this number)
times.
ignore_newtrees
See EquilibriumDB ignore_newtrees parameter definition.
"""
def __init__(self,
optimizers,
failure_callback=None,
ignore_newtrees=True,
max_use_ratio=None,
final_optimizers=None):
""" Apply optimizations until equilibrium point.
:param optimizers: list or set of local or global optimizations to
apply until equilibrium.
:param max_use_ratio: each optimizer can be applied at most
(size of graph * this number) times
:param ignore_newtrees: See EquilibriumDB ignore_newtrees
parameter definition
"""
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees=ignore_newtrees,
......@@ -2212,8 +2346,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def _check_chain(r, chain):
"""WRITEME"""
"""
WRITEME
"""
chain = list(reversed(chain))
while chain:
elem = chain.pop()
......@@ -2244,17 +2380,20 @@ def _check_chain(r, chain):
def check_chain(r, *chain):
"""WRITEME"""
"""
WRITEME
"""
if isinstance(r, graph.Apply):
r = r.outputs[0]
return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
def pre_greedy_local_optimizer(list_optimizations, out):
'''
"""
This function traverses the computation graph described by all
``node`` in the graph before the variable out but that are not in the
fgraph. it applies each of the local_optimizations on the traversed graph.
fgraph. It applies each of the local_optimizations on the traversed graph.
Its main use is to apply locally constant folding when generating
the graph of the indices of a subtensor.
......@@ -2262,11 +2401,14 @@ def pre_greedy_local_optimizer(list_optimizations, out):
We should not apply optimizations on node that are in fgraph.
So we don't optimize node that have an attribute fgraph.
:note: This don't do an equilibrium... So if there is optimization
like local_upcast_elemwise_constant_inputs in the list, that
add additional node to the inputs of the node, it can
be needed to call this function multiple time.
'''
Notes
-----
This doesn't do an equilibrium... So if there is an optimization
like local_upcast_elemwise_constant_inputs in the list, that
adds additional nodes to the inputs of the node, it may
be necessary to call this function multiple times.
"""
def local_recursive_function(list_opt, out, optimized_vars, depth):
if not getattr(out, 'owner', None):
return [out], optimized_vars
......
......@@ -36,11 +36,17 @@ class DB(object):
def register(self, name, obj, *tags, **kwargs):
"""
:param name: name of the optimizer.
:param obj: the optimizer to register.
:param tags: tag name that allow to select the optimizer.
:param kwargs: If non empty, should contain
only use_db_name_as_tag=False.
Parameters
----------
name : str
Name of the optimizer.
obj
The optimizer to register.
tags
Tag name that allow to select the optimizer.
kwargs
If non empty, should contain only use_db_name_as_tag=False.
By default, all optimizations registered in EquilibriumDB
are selected when the EquilibriumDB name is used as a
tag. We do not want this behavior for some optimizer like
......@@ -156,14 +162,18 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
class Query(object):
"""
Parameters
----------
position_cutoff : float
Used by SequenceDB to keep only optimizers that are positioned before
the cut_off point.
"""
def __init__(self, include, require=None, exclude=None,
subquery=None, position_cutoff=None):
"""
:type position_cutoff: float
:param position_cutoff: Used by SequenceDB to keep only optimizer that
are positioned before the cut_off point.
"""
self.include = OrderedSet(include)
self.require = require or OrderedSet()
self.exclude = exclude or OrderedSet()
......@@ -206,22 +216,26 @@ class Query(object):
class EquilibriumDB(DB):
"""A set of potential optimizations which should be applied in an
arbitrary order until equilibrium is reached.
"""
A set of potential optimizations which should be applied in an arbitrary
order until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
:param ignore_newtrees: If False, we will apply local opt on new
node introduced during local optimization application. This
could result in less fgraph iterations, but this don't mean it
will be faster globally.
Parameters
----------
ignore_newtrees
If False, we will apply local opt on new node introduced during local
optimization application. This could result in less fgraph iterations,
but this doesn't mean it will be faster globally.
.. note::
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer
suppor both.
Notes
-----
We can put LocalOptimizer and Optimizer, as EquilibriumOptimizer
supports both.
"""
def __init__(self, ignore_newtrees=True):
super(EquilibriumDB, self).__init__()
self.ignore_newtrees = ignore_newtrees
......@@ -253,7 +267,8 @@ class EquilibriumDB(DB):
class SequenceDB(DB):
"""A sequence of potential optimizations.
"""
A sequence of potential optimizations.
Retrieve a sequence of optimizations (a SeqOptimizer) by calling query().
......@@ -265,6 +280,7 @@ class SequenceDB(DB):
other tags) fast_run and fast_compile optimizers are drawn is a SequenceDB.
"""
seq_opt = opt.SeqOptimizer
def __init__(self, failure_callback=opt.SeqOptimizer.warn):
......@@ -278,9 +294,12 @@ class SequenceDB(DB):
def query(self, *tags, **kwtags):
"""
:type position_cutoff: float or int
:param position_cutoff: only optimizations with position less than
the cutoff are returned.
Parameters
----------
position_cutoff : float or int
Only optimizations with position less than the cutoff are returned.
"""
opts = super(SequenceDB, self).query(*tags, **kwtags)
......@@ -326,11 +345,14 @@ class SequenceDB(DB):
class LocalGroupDB(SequenceDB):
"""This generate a local optimizer of type LocalOptGroup instead
"""
Generate a local optimizer of type LocalOptGroup instead
of a global optimizer.
It support the tracks, to only get applied to some Op.
It supports tracks, so that it only gets applied to some Ops.
"""
seq_opt = opt.LocalOptGroup
def __init__(self, failure_callback=opt.SeqOptimizer.warn):
......@@ -342,9 +364,11 @@ class ProxyDB(DB):
"""
Wrap an existing proxy.
This is needed as we can't register the same DB mutiple time in
different position in a SequentialDB
This is needed as we can't register the same DB multiple times in
different positions in a SequentialDB.
"""
def __init__(self, db):
assert isinstance(db, DB), ""
self.db = db
......
......@@ -26,7 +26,10 @@ from theano.compat import cmp
def memodict(f):
""" Memoization decorator for a function taking a single argument """
"""
Memoization decorator for a function taking a single argument.
"""
class memodict(defaultdict):
def __missing__(self, key):
ret = self[key] = f(key)
......@@ -39,7 +42,10 @@ def memodict(f):
def make_depends():
@memodict
def depends(pair):
""" Returns True if a depends on b """
"""
Returns True if a depends on b.
"""
a, b = pair
return (any(bout in a.inputs for bout in b.outputs) or
any(depends((ainp.owner, b)) for ainp in a.inputs
......@@ -48,16 +54,22 @@ def make_depends():
def make_dependence_cmp():
""" Create a comparator to represent the dependence of nodes in a graph """
"""
Create a comparator to represent the dependence of nodes in a graph.
"""
depends = make_depends()
def dependence(a, b):
""" A cmp function for nodes in a graph - does a depend on b?
"""
A cmp function for nodes in a graph - does a depend on b?
Returns
-------
int
Positive number if a depends on b, negative number
if b depends on a, 0 otherwise.
Returns positive number if a depends on b
Returns negative number if b depends on a
Returns 0 otherwise
"""
if depends((a, b)):
return 1
......@@ -69,17 +81,22 @@ def make_dependence_cmp():
def reverse_dict(d):
"""Reverses direction of dependence dict
"""
Reverses direction of dependence dict.
Notes
-----
dict order is not deterministic. As we iterate on the
input dict, it makes the output of this function depend on the
dict order. So this function output order should be considered
as nondeterministic.
Examples
--------
>>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
>>> reverse_dict(d)
{1: ('a',), 2: ('a', 'b'), 3: ('b',)}
:note: dict order are not deterministic. As we iterate on the
input dict, it make the output of this function depend on the
dict order. So this function output order should be considered
as undeterministic.
"""
result = {}
for key in d:
......@@ -89,21 +106,32 @@ def reverse_dict(d):
def _toposort(edges):
""" Topological sort algorithm by Kahn [1] - O(nodes + vertices)
"""
Topological sort algorithm by Kahn [1] - O(nodes + edges).
inputs:
edges - a dict of the form {a: {b, c}} where b and c depend on a
outputs:
L - an ordered list of nodes that satisfy the dependencies of edges
Parameters
----------
edges
A dict of the form {a: {b, c}} where b and c depend on a.
>>> _toposort({1: {2, 3}, 2: (3, )})
[1, 2, 3]
Returns
-------
L : list
An ordered list of nodes that satisfy the dependencies of edges.
Closely follows the wikipedia page [2]
References
----------
[1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
Communications of the ACM
[2] http://en.wikipedia.org/wiki/Toposort#Algorithms
Examples
--------
>>> _toposort({1: {2, 3}, 2: (3, )})
[1, 2, 3]
"""
incoming_edges = reverse_dict(edges)
incoming_edges = dict((k, set(val))
......@@ -125,25 +153,38 @@ def _toposort(edges):
def posort(l, *cmps):
""" Partially ordered sort with multiple comparators
Given a list of comparators order the elements in l so that the comparators
are satisfied as much as possible giving precedence to earlier comparators.
inputs:
l - an iterable of nodes in a graph
cmps - a sequence of comparator functions that describe which nodes
should come before which others
outputs:
a list of nodes which satisfy the comparators as much as possible.
"""
Partially ordered sort with multiple comparators.
Given a list of comparators, orders the elements in l so that the
comparators are satisfied as much as possible giving precedence to
earlier comparators.
Parameters
----------
l
An iterable of nodes in a graph.
cmps
A sequence of comparator functions that describe which nodes should
come before which others.
Returns
-------
list
A list of nodes which satisfy the comparators as much as possible.
Notes
-----
Implemented with _toposort.
Examples
--------
>>> lower_tens = lambda a, b: a/10 - b/10 # prefer lower numbers div 10
>>> prefer_evens = lambda a, b: a%2 - b%2 # prefer even numbers
>>> posort(list(range(20)), lower_tens, prefer_evens)
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
implemented with _toposort """
"""
comes_before = dict((a, set()) for a in l)
comes_after = dict((a, set()) for a in l)
......@@ -158,7 +199,10 @@ def posort(l, *cmps):
comes_before[c].update(comes_before[b])
def check():
""" Tests for cycles in manufactured edges """
"""
Tests for cycles in manufactured edges.
"""
for a in l:
for b in l:
assert not(b in comes_after[a] and a in comes_after[b])
......@@ -176,12 +220,15 @@ def posort(l, *cmps):
def sort_apply_nodes(inputs, outputs, cmps):
""" Order a graph of apply nodes according to a list of comparators
"""
Order a graph of apply nodes according to a list of comparators.
The following example sorts first by dependence of nodes (this is a
topological sort) and then by lexicographical ordering (nodes that start
with 'E' come before nodes that start with 'I' if there is no dependence).
Examples
--------
>>> from theano.gof.graph import sort_apply_nodes, dependence
>>> from theano.tensor import matrix, dot
>>> x = matrix('x')
......@@ -193,22 +240,28 @@ def sort_apply_nodes(inputs, outputs, cmps):
Elemwise{mul,no_inplace}(x, InplaceDimShuffle{x,x}.0),
InplaceDimShuffle{x,x}(TensorConstant{1}),
dot(Elemwise{mul,no_inplace}.0, Elemwise{add,no_inplace}.0)]
"""
"""
return posort(list_of_nodes(inputs, outputs), *cmps)
def sort_schedule_fn(*cmps):
""" Make a schedule function from comparators
"""
Make a schedule function from comparators.
See Also
--------
sort_apply_nodes
See also:
sort_apply_nodes
"""
dependence = make_dependence_cmp()
cmps = (dependence,) + cmps
def schedule(fgraph):
""" Order nodes in a FunctionGraph """
"""
Order nodes in a FunctionGraph.
"""
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule
......
......@@ -11,18 +11,24 @@ from theano.gof import graph
class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
"""
Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
feature."""
feature.
"""
pass
class ReplacementDidntRemovedError(Exception):
"""This exception should be thrown by replace_all_validate_remove
"""
This exception should be thrown by replace_all_validate_remove
when an optimization wanted to remove a Variable or a Node from
the graph, but the replacement it provided didn't do that.
"""
pass
......@@ -34,7 +40,10 @@ class Feature(object):
by various operations on FunctionGraphs. It can be used to enforce
graph properties at all stages of graph optimization.
See :func:`theano.gof.toolbox` for common extensions.
See Also
--------
theano.gof.toolbox : for common extensions.
"""
def on_attach(self, function_graph):
......@@ -51,12 +60,14 @@ class Feature(object):
The feature has great freedom in what it can do with the
function_graph: it may, for example, add methods to it dynamically.
"""
def on_detach(self, function_graph):
"""
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the function_graph.
"""
def on_import(self, function_graph, node, reason):
......@@ -66,12 +77,14 @@ class Feature(object):
Note: on_import is not called when the graph is created. If you
want to detect the first nodes to be implemented to the graph,
you should do this by implementing on_attach.
"""
def on_prune(self, function_graph, node, reason):
"""
Called whenever a node is pruned (removed) from the function_graph,
after it is disconnected from the graph.
"""
def on_change_input(self, function_graph, node, i, r, new_r, reason=None):
......@@ -82,6 +95,7 @@ class Feature(object):
If you raise an exception in this function, the state of the graph
might be broken for all intents and purposes.
"""
def orderings(self, function_graph):
......@@ -92,6 +106,7 @@ class Feature(object):
If you raise an exception in this function, the state of the graph
might be broken for all intents and purposes.
"""
return OrderedDict()
......@@ -166,8 +181,9 @@ class History(Feature):
def revert(self, fgraph, checkpoint):
"""
Reverts the graph to whatever it was at the provided
checkpoint (undoes all replacements). A checkpoint at any
checkpoint (undoes all replacements). A checkpoint at any
given time can be obtained using self.checkpoint().
"""
h = self.history[fgraph]
self.history[fgraph] = None
......@@ -302,8 +318,9 @@ class ReplaceValidate(History, Validator):
def replace_all_validate_remove(self, fgraph, replacements,
remove, reason=None, warn=True):
"""As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. It also print a warning.
"""
As replace_all_validate, revert the replacement if the ops
in the list remove are still in the graph. Also print a warning.
"""
chk = fgraph.replace_all_validate(replacements, reason)
......
"""WRITEME Defines the `Type` class."""
"""
WRITEME
Defines the `Type` class.
"""
from theano.compat import PY3
from six import string_types
......@@ -15,7 +20,8 @@ __docformat__ = "restructuredtext en"
class CLinkerType(CLinkerObject):
"""Interface specification for Types that can be arguments to a `CLinkerOp`.
"""
Interface specification for Types that can be arguments to a `CLinkerOp`.
A CLinkerType instance is mainly responsible for providing the C code that
interfaces python objects with a C `CLinkerOp` implementation.
......@@ -25,84 +31,101 @@ class CLinkerType(CLinkerObject):
"""
def c_is_simple(self):
"""Optional: Return True for small or builtin C types.
"""
Optional: Return True for small or builtin C types.
A hint to tell the compiler that this type is a builtin C type or a
small struct and that its memory footprint is negligible. Simple
small struct and that its memory footprint is negligible. Simple
objects may be passed on the stack.
"""
return False
def c_literal(self, data):
"""Optional: WRITEME
"""
Optional: WRITEME
:Parameters:
- `data`: WRITEME
Parameters
----------
data : WRITEME
WRITEME
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise MethodNotDefined("c_literal", type(self),
self.__class__.__name__)
def c_declare(self, name, sub, check_input=True):
"""Required: Return c code to declare variables that will be
"""
Required: Return c code to declare variables that will be
instantiated by `c_extract`.
Example:
.. code-block: python
return "PyObject ** addr_of_%(name)s;"
:param name: the name of the ``PyObject *`` pointer that will
Parameters
----------
name: str
The name of the ``PyObject *`` pointer that will store
the value for this Type.
sub: dict string -> string
a dictionary of special codes. Most importantly
sub['fail']. See CLinker for more info on `sub` and ``fail``.
:type name: string
Notes
-----
It is important to include the `name` inside of variables which
are declared here, so that name collisions do not occur in the
source file that is generated.
:param sub: a dictionary of special codes. Most importantly
sub['fail']. See CLinker for more info on `sub` and ``fail``.
The variable called ``name`` is not necessarily defined yet
where this code is inserted. This code might be inserted to
create class variables for example, whereas the variable ``name``
might only exist inside certain functions in that class.
:type sub: dict string -> string
TODO: Why should variable declaration fail? Is it even allowed to?
:note: It is important to include the `name` inside of variables which
are declared here, so that name collisions do not occur in the
source file that is generated.
Raises
------
MethodNotDefined
Subclass does not implement this method.
:note: The variable called ``name`` is not necessarily defined yet
where this code is inserted. This code might be inserted to
create class variables for example, whereas the variable ``name``
might only exist inside certain functions in that class.
Examples
--------
.. code-block: python
:todo: Why should variable declaration fail? Is it even allowed to?
return "PyObject ** addr_of_%(name)s;"
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
"""
raise MethodNotDefined()
def c_init(self, name, sub):
"""Required: Return c code to initialize the variables that were declared by
self.c_declare()
"""
Required: Return c code to initialize the variables that were declared
by self.c_declare().
Notes
-----
The variable called ``name`` is not necessarily defined yet
where this code is inserted. This code might be inserted in a
class constructor for example, whereas the variable ``name``
might only exist inside certain functions in that class.
TODO: Why should variable initialization fail? Is it even allowed to?
Example:
Examples
--------
.. code-block: python
return "addr_of_%(name)s = NULL;"
:note: The variable called ``name`` is not necessarily defined yet
where this code is inserted. This code might be inserted in a
class constructor for example, whereas the variable ``name``
might only exist inside certain functions in that class.
:todo: Why should variable initialization fail? Is it even allowed to?
"""
raise MethodNotDefined("c_init", type(self), self.__class__.__name__)
def c_extract(self, name, sub, check_input=True):
"""Required: Return c code to extract a PyObject * instance.
"""
Required: Return c code to extract a PyObject * instance.
The code returned from this function must be templated using
``%(name)s``, representing the name that the caller wants to
......@@ -112,11 +135,25 @@ class CLinkerType(CLinkerObject):
of py_%(name)s. If the data is improper, set an appropriate
exception and insert "%(fail)s".
:todo: Point out that template filling (via sub) is now performed
by this function. --jpt
TODO: Point out that template filling (via sub) is now performed
by this function. --jpt
Parameters
----------
name : str
The name of the ``PyObject *`` pointer that will
store the value for this Type.
sub : dict string -> string
A dictionary of special codes. Most importantly
sub['fail']. See CLinker for more info on `sub` and ``fail``.
Example:
Raises
------
MethodNotDefined
Subclass does not implement this method.
Examples
--------
.. code-block: python
return "if (py_%(name)s == Py_None)" + \\\
......@@ -125,29 +162,17 @@ class CLinkerType(CLinkerObject):
{ PyErr_SetString(PyExc_ValueError, \\\
'was expecting None'); %(fail)s;}"
:param name: the name of the ``PyObject *`` pointer that will
store the value for this Type
:type name: string
:param sub: a dictionary of special codes. Most importantly
sub['fail']. See CLinker for more info on `sub` and ``fail``.
:type sub: dict string -> string
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
"""
raise MethodNotDefined("c_extract", type(self),
self.__class__.__name__)
def c_extract_out(self, name, sub, check_input=True):
"""Optional: C code to extract a PyObject * instance.
"""
Optional: C code to extract a PyObject * instance.
Unlike c_extract, c_extract_out has to accept Py_None,
meaning that the variable should be left uninitialized.
"""
return """
if (py_%(name)s == Py_None)
......@@ -164,7 +189,8 @@ class CLinkerType(CLinkerObject):
c_extract_code=self.c_extract(name, sub, check_input))
def c_cleanup(self, name, sub):
"""Return c code to clean up after `c_extract`.
"""
Return C code to clean up after `c_extract`.
This returns C code that should deallocate whatever `c_extract`
allocated or decrease the reference counts. Do not decrease
......@@ -172,55 +198,64 @@ class CLinkerType(CLinkerObject):
WRITEME
:Parameters:
- `name`: WRITEME
Parameters
----------
name : WRITEME
WRITEME
- `sub`: WRITEME
sub : WRITEME
WRITEME
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise MethodNotDefined()
def c_sync(self, name, sub):
"""Required: Return c code to pack C types back into a PyObject.
"""
Required: Return C code to pack C types back into a PyObject.
The code returned from this function must be templated using
"%(name)s", representing the name that the caller wants to
call this Variable. The returned code may set "py_%(name)s"
call this Variable. The returned code may set "py_%(name)s"
to a PyObject* and that PyObject* will be accessible from
Python via variable.data. Do not forget to adjust reference
counts if "py_%(name)s" is changed from its original value.
:Parameters:
- `name`: WRITEME
Parameters
----------
name : WRITEME
WRITEME
- `sub`: WRITEME
sub : WRITEME
WRITEME
:Exceptions:
- `MethodNotDefined`: Subclass does not implement this method
Raises
------
MethodNotDefined
Subclass does not implement this method.
"""
raise MethodNotDefined("c_sync", type(self), self.__class__.__name__)
def c_code_cache_version(self):
"""Return a tuple of integers indicating the version of this Type.
"""
Return a tuple of integers indicating the version of this Type.
An empty tuple indicates an 'unversioned' Type that will not
be cached between processes.
The cache mechanism may erase cached modules that have been
superceded by newer versions. See `ModuleCache` for details.
superceded by newer versions. See `ModuleCache` for details.
"""
return ()
class PureType(object):
"""Interface specification for variable type instances.
"""
Interface specification for variable type instances.
A :term:`Type` instance is mainly reponsible for two things:
......@@ -228,8 +263,10 @@ class PureType(object):
- filtering a value assigned to a `Variable` so that the value
conforms to restrictions imposed by the type (also known as
casting, this is done by `filter`),
casting, this is done by `filter`).
"""
# the type that will be created by call to make_variable.
Variable = graph.Variable
......@@ -237,7 +274,8 @@ class PureType(object):
Constant = graph.Constant
def filter(self, data, strict=False, allow_downcast=None):
"""Required: Return data or an appropriately wrapped/converted data.
"""
Required: Return data or an appropriately wrapped/converted data.
Subclass implementation should raise a TypeError exception if
the data is not of an acceptable type.
......@@ -250,8 +288,10 @@ class PureType(object):
Type-dependent, but for now it means only Python floats can be
downcasted, and only to floatX scalars.
:Exceptions:
- `MethodNotDefined`: subclass doesn't implement this function.
Raises
------
MethodNotDefined
Subclass doesn't implement this function.
"""
raise MethodNotDefined("filter", type(self), self.__class__.__name__)
......@@ -264,13 +304,15 @@ class PureType(object):
# def filter_inplace(value, storage, strict=False, allow_downcast=None)
def filter_variable(self, other, allow_convert=True):
"""Convert a symbolic variable into this Type, if compatible.
"""
Convert a symbolic variable into this Type, if compatible.
For the moment, the only Types compatible with one another are
TensorType and CudaNdarrayType, provided they have the same
number of dimensions, same broadcasting pattern, and same dtype.
If Types are not compatible, a TypeError should be raised.
"""
if not isinstance(other, graph.Variable):
# The value is not a Variable: we cast it into
......@@ -291,7 +333,8 @@ class PureType(object):
return other
def convert_variable(self, var):
"""Patch variable so that its type will match self, if possible.
"""
Patch variable so that its type will match self, if possible.
If the variable can't be converted, this should return None.
......@@ -305,12 +348,16 @@ class PureType(object):
inverse.
The default is to not convert anything which is always safe.
"""
return None
def is_valid_value(self, a):
"""Required: Return True for any python object `a` that would be a
legal value for a Variable of this Type"""
"""
Required: Return True for any python object `a` that would be a
legal value for a Variable of this Type.
"""
try:
self.filter(a, strict=True)
return True
......@@ -318,15 +365,20 @@ legal value for a Variable of this Type"""
return False
def value_validity_msg(self, a):
"""Optional: return a message explaining the output of
is_valid_value"""
"""
Optional: Return a message explaining the output of
is_valid_value.
"""
return "none"
def make_variable(self, name=None):
"""Return a new `Variable` instance of Type `self`.
"""
Return a new `Variable` instance of Type `self`.
:Parameters:
- `name`: None or str
Parameters
----------
name : None or str
A pretty string for printing and debugging.
"""
......@@ -336,10 +388,12 @@ is_valid_value"""
return self.Constant(type=self, data=value, name=name)
def __call__(self, name=None):
"""Return a new `Variable` instance of Type `self`.
"""
Return a new `Variable` instance of Type `self`.
:Parameters:
- `name`: None or str
Parameters
----------
name : None or str
A pretty string for printing and debugging.
"""
......@@ -350,6 +404,7 @@ is_valid_value"""
Return True if a and b can be considered exactly equal.
a and b are assumed to be valid values of this Type.
"""
return a == b
......@@ -357,29 +412,38 @@ is_valid_value"""
"""
Return True if a and b can be considered approximately equal.
:param a: a potential value for a Variable of this Type.
:param b: a potential value for a Variable of this Type.
:rtype: Bool
This function is used by theano debugging tools to decide
whether two values are equivalent, admitting a certain amount
of numerical instability. For example, for floating-point
of numerical instability. For example, for floating-point
numbers this function should be an approximate comparison.
By default, this does an exact comparison.
Parameters
----------
a
A potential value for a Variable of this Type.
b
A potential value for a Variable of this Type.
Returns
-------
bool
"""
return self.values_eq(a, b)
# def get_shape_info(self, obj):
"""
Optional function. See TensorType().get_shape_info for definition
Optional function. See TensorType().get_shape_info for definition.
"""
# def get_size(self, shape_info):
"""
Optional function. See TensorType().get_size for definition
Optional function. See TensorType().get_size for definition.
"""
_nothing = """
......@@ -387,7 +451,8 @@ _nothing = """
class Type(object2, PureType, CLinkerType):
"""Convenience wrapper combining `PureType` and `CLinkerType`.
"""
Convenience wrapper combining `PureType` and `CLinkerType`.
Theano comes with several subclasses of such as:
......@@ -399,8 +464,8 @@ class Type(object2, PureType, CLinkerType):
But you are encouraged to write your own, as described in WRITEME.
The following following code illustrates the use of a Type
instance, here tensor.fvector:
The following code illustrates the use of a Type instance,
here tensor.fvector:
.. code-block:: python
......@@ -411,7 +476,7 @@ class Type(object2, PureType, CLinkerType):
c = tensor.fvector()
Whenever you create a symbolic variable in theano (technically,
`Variable`) it will contain a reference to a Type instance. That
`Variable`) it will contain a reference to a Type instance. That
reference is typically constant during the lifetime of the
Variable. Many variables can refer to a single Type instance, as
do b and c above. The Type instance defines the kind of value
......@@ -430,10 +495,13 @@ class Type(object2, PureType, CLinkerType):
class SingletonType(Type):
"""Convenient Base class for a Type subclass with no attributes
"""
Convenient Base class for a Type subclass with no attributes.
It saves having to implement __eq__ and __hash__.
It saves having to implement __eq__ and __hash__
"""
__instance = None
def __new__(cls):
......@@ -473,6 +541,7 @@ class Generic(SingletonType):
EXAMPLE of what this means, or when you would use this type.
WRITEME
"""
def filter(self, data, strict=False, allow_downcast=None):
......@@ -523,7 +592,20 @@ class CDataType(Type):
"""
Represents opaque C data to be passed around. The intent is to
ease passing arbitrary data between ops C code.
The constructor builds a type made to represent a C pointer in theano.
Parameters
----------
ctype
The type of the pointer (complete with the `*`).
freefunc
A function to call to free the pointer. This function must have a `void`
return and take a single pointer argument.
"""
import ctypes
if PY3:
_cdata_type = ctypes.py_object.from_address(
......@@ -534,15 +616,6 @@ class CDataType(Type):
del ctypes
def __init__(self, ctype, freefunc=None):
"""
Build a type made to represent a C pointer in theano.
:param ctype: The type of the pointer (complete with the `*`)
:param freefunc: a function to call to free the pointer. This
function must have a `void` return and take a
single pointer argument.
"""
assert isinstance(ctype, string_types)
self.ctype = ctype
if freefunc is not None:
......
"""
If you have two expressions
containing unification variables, these expressions can be "unified"
if there exists an assignment to all unification variables such that
the two expressions are equal. For instance, [5, A, B] and [A, C, 9]
can be unified if A=C=5 and B=9, yielding [5, 5, 9]. [5, [A, B]] and
[A, [1, 2]] cannot be unified because there is no value for A that
satisfies the constraints. That's useful for pattern matching.
If you have two expressions containing unification variables, these expressions
can be "unified" if there exists an assignment to all unification variables
such that the two expressions are equal.
For instance, [5, A, B] and [A, C, 9] can be unified if A=C=5 and B=9,
yielding [5, 5, 9].
[5, [A, B]] and [A, [1, 2]] cannot be unified because there is no value for A
that satisfies the constraints. That's useful for pattern matching.
"""
from __future__ import print_function
from copy import copy
......@@ -26,12 +28,15 @@ class Variable:
Behavior for unifying various types of variables should be added as
overloadings of the 'unify' function.
Note: there are two Variable classes in theano and this is the
more rarely used one.
Notes
-----
There are two Variable classes in theano and this is the more rarely used
one.
This class is used internally by the PatternSub optimization,
and possibly other subroutines that have to perform graph queries.
If that doesn't sound like what you're doing, the Variable class you
want is probably theano.gof.graph.Variable
want is probably theano.gof.graph.Variable.
"""
def __init__(self, name="?"):
self.name = name
......@@ -48,14 +53,18 @@ class Variable:
class FreeVariable(Variable):
"""
This Variable can take any value.
"""
pass
class BoundVariable(Variable):
"""
This Variable is bound to a value accessible via the value field.
"""
def __init__(self, name, value):
self.name = name
self.value = value
......@@ -65,7 +74,9 @@ class OrVariable(Variable):
"""
This Variable could be any value from a finite list of values,
accessible via the options field.
"""
def __init__(self, name, options):
self.name = name
self.options = options
......@@ -75,7 +86,9 @@ class NotVariable(Variable):
"""
This Variable can take any value but a finite amount of forbidden
values, accessible via the not_options field.
"""
def __init__(self, name, not_options):
self.name = name
self.not_options = not_options
......@@ -84,10 +97,12 @@ class NotVariable(Variable):
class VariableInList: # not a subclass of Variable
"""
This special kind of variable is matched against a list and unifies
an inner Variable to an OrVariable of the values in the list. For
example, if we unify VariableInList(FreeVariable('x')) to [1,2,3],
an inner Variable to an OrVariable of the values in the list.
For example, if we unify VariableInList(FreeVariable('x')) to [1,2,3],
the 'x' variable is unified to an OrVariable('?', [1,2,3]).
"""
def __init__(self, variable):
self.variable = variable
......@@ -120,13 +135,17 @@ class Unification:
"""
This class represents a possible unification of a group of variables
with each other or with tangible values.
"""
def __init__(self, inplace=False):
"""
Parameters
----------
inplace : bool
If inplace is False, the merge method will return a new Unification
that is independent from the previous one (which allows backtracking).
"""
"""
def __init__(self, inplace=False):
self.unif = {}
self.inplace = inplace
......@@ -134,6 +153,7 @@ class Unification:
"""
Links all the specified vars to a Variable that represents their
unification.
"""
if self.inplace:
U = self
......@@ -163,6 +183,7 @@ class Unification:
"""
For a variable v, returns a Variable that represents the tightest
set of possible values it can take.
"""
return self.unif.get(v, (v, None))[0]
......@@ -172,23 +193,25 @@ class Unification:
def unify_walk(a, b, U):
"""
unify_walk(a, b, U) returns an Unification where a and b are unified, given the
unification that already exists in the Unification U. If the unification fails,
it returns False.
unify_walk(a, b, U) returns an Unification where a and b are unified,
given the unification that already exists in the Unification U. If the
unification fails, it returns False.
There are two ways to expand the functionality of unify_walk. The first way is:
There are two ways to expand the functionality of unify_walk. The first way
is:
@comm_guard(type_of_a, type_of_b)
def unify_walk(a, b, U):
...
A function defined as such will be executed whenever the types of a and b
match the declaration. Note that comm_guard automatically guarantees that
your function is commutative: it will try to match the types of a, b or b, a.
It is recommended to define unify_walk in that fashion for new types of Variable
because different types of Variable interact a lot with each other, e.g.
when unifying an OrVariable with a NotVariable, etc. You can return the
special marker FALL_THROUGH to indicate that you want to relay execution
to the next match of the type signature. The definitions of unify_walk are tried
in the reverse order of their declaration.
your function is commutative: it will try to match the types of a, b or
b, a.
It is recommended to define unify_walk in that fashion for new types of
Variable because different types of Variable interact a lot with each other,
e.g. when unifying an OrVariable with a NotVariable, etc. You can return
the special marker FALL_THROUGH to indicate that you want to relay execution
to the next match of the type signature. The definitions of unify_walk are
tried in the reverse order of their declaration.
Another way is to override __unify_walk__ in an user-defined class.
......@@ -209,7 +232,8 @@ def unify_walk(a, b, U):
@comm_guard(FreeVariable, ANY_TYPE)
def unify_walk(fv, o, U):
"""
FreeV is unified to BoundVariable(other_object)
FreeV is unified to BoundVariable(other_object).
"""
v = BoundVariable("?", o)
return U.merge(v, fv)
......@@ -218,7 +242,8 @@ def unify_walk(fv, o, U):
@comm_guard(BoundVariable, ANY_TYPE)
def unify_walk(bv, o, U):
"""
The unification succeed iff BV.value == other_object
The unification succeed iff BV.value == other_object.
"""
if bv.value == o:
return U
......@@ -229,7 +254,8 @@ def unify_walk(bv, o, U):
@comm_guard(OrVariable, ANY_TYPE)
def unify_walk(ov, o, U):
"""
The unification succeeds iff other_object in OrV.options
The unification succeeds iff other_object in OrV.options.
"""
if o in ov.options:
v = BoundVariable("?", o)
......@@ -241,7 +267,8 @@ def unify_walk(ov, o, U):
@comm_guard(NotVariable, ANY_TYPE)
def unify_walk(nv, o, U):
"""
The unification succeeds iff other_object not in NV.not_options
The unification succeeds iff other_object not in NV.not_options.
"""
if o in nv.not_options:
return False
......@@ -254,6 +281,7 @@ def unify_walk(nv, o, U):
def unify_walk(fv, v, U):
"""
Both variables are unified.
"""
v = U[v]
return U.merge(v, fv)
......@@ -262,7 +290,8 @@ def unify_walk(fv, v, U):
@comm_guard(BoundVariable, Variable)
def unify_walk(bv, v, U):
"""
V is unified to BV.value
V is unified to BV.value.
"""
return unify_walk(v, bv.value, U)
......@@ -271,6 +300,7 @@ def unify_walk(bv, v, U):
def unify_walk(a, b, U):
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
"""
opt = intersection(a.options, b.options)
if not opt:
......@@ -286,6 +316,7 @@ def unify_walk(a, b, U):
def unify_walk(a, b, U):
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
"""
opt = union(a.not_options, b.not_options)
v = NotVariable("?", opt)
......@@ -296,6 +327,7 @@ def unify_walk(a, b, U):
def unify_walk(o, n, U):
"""
OrV(list1) == NV(list2) == OrV(list1 \ list2)
"""
opt = [x for x in o.options if x not in n.not_options]
if not opt:
......@@ -311,6 +343,7 @@ def unify_walk(o, n, U):
def unify_walk(vil, l, U):
"""
Unifies VIL's inner Variable to OrV(list).
"""
v = vil.variable
ov = OrVariable("?", l)
......@@ -321,6 +354,7 @@ def unify_walk(vil, l, U):
def unify_walk(l1, l2, U):
"""
Tries to unify each corresponding pair of elements from l1 and l2.
"""
if len(l1) != len(l2):
return False
......@@ -335,6 +369,7 @@ def unify_walk(l1, l2, U):
def unify_walk(d1, d2, U):
"""
Tries to unify values of corresponding keys.
"""
for (k1, v1) in iteritems(d1):
if k1 in d2:
......@@ -349,6 +384,7 @@ def unify_walk(a, b, U):
"""
Checks for the existence of the __unify_walk__ method for one of
the objects.
"""
if (not isinstance(a, Variable) and
not isinstance(b, Variable) and
......@@ -364,6 +400,7 @@ def unify_walk(v, o, U):
This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification,
falls through.
"""
best_v = U[v]
if v is not best_v:
......@@ -447,6 +484,7 @@ def unify_merge(v, o, U):
This simply checks if the Var has an unification in U and uses it
instead of the Var. If the Var is already its tighest unification,
falls through.
"""
best_v = U[v]
if v is not best_v:
......
......@@ -11,13 +11,14 @@ from theano.compat import OrderedDict, PY3
def simple_extract_stack(f=None, limit=None):
"""This is traceback.extract_stack from python 2.7 with this
change:
"""
This is traceback.extract_stack from python 2.7 with this change:
- Comment the update of the cache
- Comment the update of the cache.
This is because this update cause an call to os.stat to get the
line content. This cause too much long on cluster.
"""
if f is None:
try:
......@@ -54,14 +55,22 @@ if sys.version_info[:2] > (3, 4):
def add_tag_trace(thing, user_line=1):
"""Add tag.trace to an node or variable.
"""
Add tag.trace to an node or variable.
The argument is returned after being affected (inplace).
:param thing: the object where we add .tag.trace
:param user_line: The max number of user line to keep.
:note: we alse use config.traceback.limit for the maximum number
of stack level we look.
Parameters
----------
thing
The object where we add .tag.trace.
user_line
The max number of user line to keep.
Notes
-----
We alse use config.traceback.limit for the maximum number of stack level
we look.
"""
limit = config.traceback.limit
......@@ -117,6 +126,7 @@ class MethodNotDefined(Exception):
When the user sees such an error, it is because an important interface
function has been left out of an implementation class.
"""
......@@ -159,8 +169,10 @@ class D:
def memoize(f):
"""Cache the return value for each tuple of arguments
(which must be hashable) """
"""
Cache the return value for each tuple of arguments (which must be hashable).
"""
cache = {}
def rval(*args, **kwargs):
......@@ -177,15 +189,16 @@ def memoize(f):
def deprecated(filename, msg=''):
"""Decorator which will print a warning message on the first call.
"""
Decorator which will print a warning message on the first call.
Use it like this::
Use it like this:
@deprecated('myfile', 'do something different...')
def fn_name(...)
...
And it will print::
And it will print:
WARNING myfile.fn_name deprecated. do something different...
......@@ -209,6 +222,7 @@ def uniq(seq):
Do not use set, this must always return the same value at the same index.
If we just exchange other values, but keep the same pattern of duplication,
we must keep the same order.
"""
# TODO: consider building a set out of seq so that the if condition
# is constant time -JB
......@@ -217,7 +231,8 @@ def uniq(seq):
def difference(seq1, seq2):
"""
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``.
"""
try:
# try to use O(const * len(seq1)) algo
......@@ -252,6 +267,7 @@ def toposort(prereqs_d):
prereqs_d[x] contains all the elements that must come before x
in the ordering.
"""
# all1 = set(prereqs_d.keys())
......@@ -390,6 +406,7 @@ def type_guard(type1):
def flatten(a):
"""
Recursively flatten tuple, list and set in a list.
"""
if isinstance(a, (tuple, list, set)):
l = []
......@@ -412,9 +429,12 @@ def hist(coll):
def give_variables_names(variables):
""" Gives unique names to an iterable of variables. Modifies input.
"""
Gives unique names to an iterable of variables. Modifies input.
This function is idempotent."""
This function is idempotent.
"""
names = [var.name for var in variables]
h = hist(names)
......@@ -431,13 +451,17 @@ def give_variables_names(variables):
def remove(predicate, coll):
""" Return those items of collection for which predicate(item) is true.
"""
Return those items of collection for which predicate(item) is true.
Examples
--------
>>> from itertoolz import remove
>>> def even(x):
... return x % 2 == 0
>>> remove(even, [1, 2, 3, 4])
[1, 3]
"""
return [x for x in coll if not predicate(x)]
......@@ -466,12 +490,16 @@ else:
def hash_from_file(file_path):
"""Return the MD5 hash of a file."""
"""
Return the MD5 hash of a file.
"""
return hash_from_code(open(file_path, 'rb').read())
def hash_from_dict(d):
"""Work around the fact that dict are not hashable in python
"""
Work around the fact that dict are not hashable in python.
This request that all object have a sorted order that depend only
on the key of the object. We support only integer/float/string keys.
......@@ -479,8 +507,10 @@ def hash_from_dict(d):
Also, we transform values that are list into tuple as list are not
hashable.
:note: special case for OrderedDict, it use the order of the dict,
so the key don't need to be sortable.
Notes
-----
Special case for OrderedDict, it use the order of the dict,
so the key don't need to be sortable.
"""
if isinstance(d, OrderedDict):
......
"""
VMs that run Theano graph computations.
A VM is not actually different from a Linker, we just decided
VM was a better name at some point.
"""
from . import link
from collections import defaultdict
......@@ -142,7 +144,6 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re,
class VM(object):
"""
A VM object's __call__ method evaluates a Theano program.
......@@ -155,33 +156,35 @@ class VM(object):
advantage of lazy computation, though they still produce the correct
output for lazy nodes.
Attributes:
call_counts - list of integers, one for each thunk. call_count[i] is the
number of times thunks[i] was called in the course of computations
performed by call_with_timers().
call_times - list of floats, one for each thunk. call_times[i] is
the amount of runtime spent on thunks[i] in the course of
computations performed by call_with_timers().
need_update_inputs - bool. True indicates that Function.__call__
must implement the feedback from output storage to input
storage. False means it *must not* repeat that feedback.
Parameters
----------
nodes
A list of nodes in toposort order.
thunks
A list of thunks to execute those nodes, in toposort order.
pre_call_clear
A list of containers to empty at the beginning of each call.
Attributes
----------
call_counts
List of integers, one for each thunk. call_count[i] is the number of
times thunks[i] was called in the course of computations performed by
call_with_timers().
call_times
List of floats, one for each thunk. call_times[i] is the amount of
runtime spent on thunks[i] in the course of computations performed by
call_with_timers().
need_update_inputs : bool
True indicates that Function.__call__ must implement the feedback from
output storage to input storage. False means it *must not* repeat that
feedback.
"""
def __init__(self, nodes, thunks, pre_call_clear):
"""
Allocate a virtual machine.
nodes - a list of nodes in toposort order
thunks - a list of thunks to execute those nodes, in toposort order
pre_call_clear - a list of containers to empty at the beginning of each
call.
"""
if len(nodes) != len(thunks):
raise ValueError()
self.nodes = nodes
......@@ -202,6 +205,7 @@ class VM(object):
Postcondition - all output variables have been computed. VMs vary in
what exactly this means and how it is done.
"""
raise NotImplementedError('override me')
......@@ -212,6 +216,7 @@ class VM(object):
Free internal variables and outputs. Essentially, free as much memory
as possible without intefering with the ability to evaluate subsequent
calls.
"""
raise NotImplementedError('override me')
......@@ -247,10 +252,10 @@ class VM(object):
class Loop(VM):
"""
Unconditional start-to-finish program execution in Python.
No garbage collection is allowed on intermediate results.
"""
# Some other part of Theano query that information
allow_gc = False
......@@ -280,10 +285,10 @@ class Loop(VM):
class LoopGC(VM):
"""
Unconditional start-to-finish program execution in Python.
Garbage collection is possible on intermediate results.
"""
def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear):
......@@ -327,7 +332,6 @@ class LoopGC(VM):
class Stack(VM):
"""
Finish-to-start evalution order of thunks.
......@@ -399,9 +403,11 @@ class Stack(VM):
raise ValueError("Must set dependencies when using GC")
def run_thunk_of_node(self, node):
"""Run the thunk corresponding to Apply instance `node`
"""
Run the thunk corresponding to Apply instance `node`.
Calls self.callback if it is defined.
"""
idx = self.node_idx[node]
t0 = time.time()
......@@ -683,34 +689,36 @@ except (OSError, theano.gof.cmodule.MissingGXX) as e:
class VM_Linker(link.LocalLinker):
"""
Class that satisfies the Linker interface by acting as a VM factory.
Parameters
----------
allow_gc
Force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on
intermediate values during computation of a function.
If None use as default the value of the Theano flag allow_gc.
use_cloop
Use the C-based virtual machine if possible
callback
A callable object to call after each call to a thunk within
the virtual machine. It will be called with four arguments called
'node', 'thunk', 'storage_map', and 'compute_map'.
lazy
Useful only when use_cloop is False. When lazy is None, use the
theano flag vm.lazy value. Then if we have a None (default) we auto
detect if lazy evaluation is needed and use the apropriate
version. If lazy is True or False, we force the version used
between Loop/LoopGC and Stack.
c_thunks
If None or True, don't change the default. If False,
don't compile c code for the thunks.
"""
def __init__(self, allow_gc=None, use_cloop=False, callback=None,
lazy=None, schedule=None, c_thunks=None):
"""
allow_gc - force the virtual machine to clean up unnecessary
references, in order to allow garbage collection on
intermediate values during computation of a function.
If None use as default the value of the Theano flag allow_gc.
use_cloop - use the C-based virtual machine if possible
callback - a callable object to call after each call to a thunk within
the virtual machine. It will be called with four arguments called
'node', 'thunk', 'storage_map', and 'compute_map'.
lazy - Useful only when use_cloop is False. When lazy is None, use the
theano flag vm.lazy value. Then if we have a None (default) we auto
detect if lazy evaluation is needed and use the apropriate
version. If lazy is True or False, we force the version used
between Loop/LoopGC and Stack.
c_thunks - If None or True, don't change the default. If False,
don't compile c code for the thunks.
"""
# Note: if more parameters are added to __init__, make sure to forward
# them in the "type(self)(...)" call in the "accept" method below.
if allow_gc is None:
......@@ -727,13 +735,20 @@ class VM_Linker(link.LocalLinker):
def accept(self, fgraph, no_recycling=None):
"""
:param fgraph: a PerformLinker can have accepted one FunctionGraph
instance at a time.
:param no_recycling: WRITEME
Parameters
----------
fgraph
A PerformLinker can have accepted one FunctionGraph instance
at a time.
no_recycling
WRITEME
Returns
-------
Self if fgraph is the first FunctionGraph that has ever been
associated to self, else, a new VM_Linker associated to fgraph.
:returns: self if fgraph is the first FunctionGraph that has ever been
associated to self, else, a new VM_Linker associated to fgraph.
"""
if (config.profile and
hasattr(theano, 'sandbox') and
......@@ -779,27 +794,24 @@ class VM_Linker(link.LocalLinker):
Returns dict: variable K -> list of variables [v1, v2, v3, ...]
for each K in variables.
The variables v1, v2, ... are the full set of variables that depend
directly on K. When we know that none of them will need to be
computed, we know that:
* K will not need to be computed
* if K is already computed, it can be released for garbage collection
* K will not need to be computed.
* If K is already computed, it can be released for garbage collection.
Parameters
----------
variables - iterable over the variables used in a graph computation.
variables
Iterable over the variables used in a graph computation.
Notes
-----
It doesn't take care of the view_map/destroy_map. So it means it relies
on Python gc no to free the object real storage.
N.B. gc means garbage collection
Note
----
It don't take care of the view_map/destroy_map. So
it mean it rely on Python gc to don't free the object real
storage.
"""
dependencies = {}
for k in variables:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论