提交 1ab81f84 authored 作者: James Bergstra's avatar James Bergstra

commenting and minor refactors to destroyhandler, function_module, env

上级 60e767c2
"""Convenient driver of graph construction, optimization, and linking."""
"""Driver of graph construction, optimization, and linking.
"""
import copy_reg
import cPickle
......@@ -102,6 +104,15 @@ DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate
class Function(object):
"""
Type of the functions returned by theano.function or theano.FunctionMaker.create.
`Function` is the callable object that does computation. It has the storage of inputs and
outputs, performs the packing and unpacking of inputs and return values. It implements the
square-bracket indexing so that you can look up the value of a symbolic node.
When a function is copied, this instance is duplicated. Contrast with self.maker
(instance of `FunctionMaker`) that is shared between copies.
"""
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker):
......@@ -394,6 +405,13 @@ class SanityCheckFunction(Function):
NODEFAULT = ['NODEFAULT']
class FunctionMaker(object):
"""`FunctionMaker` is the class to `create` `Function` instances.
This class has the env, the optimizer, and the linker. When copying a `Function`, there is
no need to duplicate the `FunctionMaker` instance. Deepcopy still copies both, which can
result in re-compilation.
"""
@staticmethod
def wrap_in(input):
......@@ -432,18 +450,20 @@ class FunctionMaker(object):
else:
raise TypeError("Unknown output type: %s (%s)", type(output), output)
def __init__(self, inputs, outputs, mode = 'FAST_RUN', accept_inplace = False, function_builder = Function):
def __init__(self, inputs, outputs,
mode = 'FAST_COMPILE', accept_inplace = False, function_builder = Function):
"""
Create a FunctionMaker for the specified inputs, outputs and mode.
@param inputs: a list of SymbolicInput instances
@param outputs: a list of SymbolicOutput instances
outputs may also be a single Result (not a list), in which
case the functions produced by FunctionMaker will return
their output value directly
@param mode: a Mode instance telling FunctionMaker how to optimize and link
@param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs
:type inputs: a list of SymbolicInput instances
:type outputs: a list of SymbolicOutput instances
outputs may also be a single Result (not a list), in which
case the functions produced by FunctionMaker will return
their output value directly
:param mode: a Mode instance telling FunctionMaker how to optimize and link
:param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs
"""
......@@ -648,19 +668,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
optimization phase (default is False)
Every element of the input list will be upgraded to an `In` instance if necessary,
using the following rules:
- a `Result` instance r will be upgraded like `In`(r)
- a tuple (name, r) will be `In`(r, name=name)
- a tuple (r, val) will be `In`(r, value=value, autoname=True)
- a tuple ((r,up), val) will be `In`(r, value=value, update=up, autoname=True)
- a tuple (name, r, val) will be `In`(r, name=name, value=value)
- a tuple (name, (r,up), val) will be `In`(r, name=name, value=val, update=up, autoname=True)
using the rules implemented by the `convert_function_input` function.
Similarly, every element of the output list will be upgraded to an
`Out` instance if necessary:
......@@ -681,54 +689,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
"""
def wrap_in(input):
if isinstance(input, (SymbolicInput, SymbolicInputKit)):
return input
elif isinstance(input, gof.Result):
return In(input)
elif isinstance(input, (list, tuple)):
orig = input
if not input:
raise TypeError("Nonsensical input specification: %s" % input)
if isinstance(input[0], str):
name = input[0]
input = input[1:]
else:
name = None
if isinstance(input[0], (list, tuple)):
if len(input[0]) != 2 or len(input) != 2:
raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
(result, update), value = input
elif isinstance(input[0], gof.Result):
if len(input) == 1:
result, update, value = input[0], None, None
elif len(input) == 2:
(result, value), update = input, None
else:
raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
elif isinstance(input[0], (SymbolicInput, SymbolicInputKit)):
if len(input) == 1:
return input[0]
elif len(input) == 2:
input, value = input
if name is not None: input.name = name
input.value = value
return input
else:
raise TypeError("The input specification is not valid: %s" % input)
if not isinstance(result, gof.Result):
raise TypeError("Unknown input type: %s, expected Result instance" % type(result), result)
if update is not None and not isinstance(update, gof.Result):
raise TypeError("Unknown update type: %s, expected Result instance" % type(update), update)
if value is not None and isinstance(value, (gof.Result, SymbolicInput)):
raise TypeError("The value for input %s should not be a Result or SymbolicInput instance (got: %s)" % (result, value))
return In(result, name=name, value=value, update=update)
else:
raise TypeError("Unknown input type: %s, expected Result instance" % type(input), input)
inputs = map(wrap_in, inputs)
inputs = map(convert_function_input, inputs)
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
defaults = [getattr(input, 'value', None) for input in inputs]
......@@ -750,9 +711,76 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder)
fn = maker1.create(defaults)
else:
fn = FunctionMaker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults)
Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults)
return fn
def convert_function_input(input):
    """
    Upgrade an input shortcut to an `In` instance.

    The rules for upgrading are as follows:

    - a `Result` instance r will be upgraded like `In`(r)
    - a tuple (name, r) will be `In`(r, name=name)
    - a tuple (r, val) will be `In`(r, value=value, autoname=True)
    - a tuple ((r,up), val) will be `In`(r, value=value, update=up, autoname=True)
    - a tuple (name, r, val) will be `In`(r, name=name, value=value)
    - a tuple (name, (r,up), val) will be `In`(r, name=name, value=val, update=up, autoname=True)

    :raises TypeError: if the specification matches none of the forms above.
    """
    # Already-wrapped inputs pass through untouched.
    if isinstance(input, (SymbolicInput, SymbolicInputKit)):
        return input
    elif isinstance(input, gof.Result):
        return In(input)
    elif isinstance(input, (list, tuple)):
        orig = input
        if not input:
            raise TypeError("Nonsensical input specification: %s" % input)
        # An optional leading string names the input.
        if isinstance(input[0], str):
            name = input[0]
            input = input[1:]
        else:
            name = None
        if isinstance(input[0], (list, tuple)):
            # ((result, update), value) form
            if len(input[0]) != 2 or len(input) != 2:
                raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
            (result, update), value = input
        elif isinstance(input[0], gof.Result):
            # (result,) or (result, value) form
            if len(input) == 1:
                result, update, value = input[0], None, None
            elif len(input) == 2:
                (result, value), update = input, None
            else:
                raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
        elif isinstance(input[0], (SymbolicInput, SymbolicInputKit)):
            # (wrapped,) or (wrapped, value) form: mutate the wrapper in place.
            if len(input) == 1:
                return input[0]
            elif len(input) == 2:
                input, value = input
                if name is not None:
                    input.name = name
                input.value = value
                return input
            else:
                raise TypeError("The input specification is not valid: %s" % input)
        else:
            # BUGFIX: previously this case fell through with `result` unbound,
            # so the isinstance check below raised UnboundLocalError instead of
            # a helpful TypeError describing the bad specification.
            raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
        # Validate the pieces extracted above before wrapping them in an In.
        if not isinstance(result, gof.Result):
            raise TypeError("Unknown input type: %s, expected Result instance" % type(result), result)
        if update is not None and not isinstance(update, gof.Result):
            raise TypeError("Unknown update type: %s, expected Result instance" % type(update), update)
        if value is not None and isinstance(value, (gof.Result, SymbolicInput)):
            raise TypeError("The value for input %s should not be a Result or SymbolicInput instance (got: %s)" % (result, value))
        return In(result, name=name, value=value, update=update)
    else:
        raise TypeError("Unknown input type: %s, expected Result instance" % type(input), input)
......@@ -10,14 +10,15 @@ class ProtocolError(Exception):
"""WRITEME"""
pass
class DestroyHandler(object):
    """Dispatcher feature that delegates per-env destroy bookkeeping to one
    `DestroyHandlerHelper2` instance per attached env.
    """
    def __init__(self, do_imports_on_attach=True):
        # Maps each attached env to its dedicated DestroyHandlerHelper2.
        self.map = {}
        # Forwarded to each helper: whether attaching should run the import
        # pass over the env's pre-existing nodes.
        self.do_imports_on_attach = do_imports_on_attach

    def on_attach(self, env):
        # Reuse the helper already bound to this env, or create a new one,
        # then attach it.
        dh = self.map.setdefault(env, DestroyHandlerHelper2(do_imports_on_attach=self.do_imports_on_attach))
        dh.on_attach(env)
def on_detach(self, env):
......@@ -69,8 +70,9 @@ def get_impact(root, view_o):
class DestroyHandlerHelper2(toolbox.Bookkeeper):
"""WRITEME"""
def __init__(self, do_imports_on_attach=True):
    """Create a helper that is not yet attached to any env.

    :param do_imports_on_attach: if True, on_attach runs the Bookkeeper
        import pass over the env's existing nodes; if False, that pass
        is skipped (callers must import nodes themselves).
    """
    # No env until on_attach() is called.
    self.env = None
    self.do_imports_on_attach = do_imports_on_attach
def on_attach(self, env):
#boilerplate from old implementation
......@@ -99,7 +101,8 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
self.stale_droot = True
self.debug_all_apps = set()
toolbox.Bookkeeper.on_attach(self, env)
if self.do_imports_on_attach:
toolbox.Bookkeeper.on_attach(self, env)
def refresh_droot_impact(self):
if self.stale_droot:
......@@ -153,6 +156,7 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
if app in self.debug_all_apps: raise ProtocolError("double import")
self.debug_all_apps.add(app)
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}):
......
......@@ -29,16 +29,60 @@ class Env(utils.object2):
It can also be "extended" using env.extend(some_object). See the
toolbox and ext modules for common extensions.
Features added with the`extend` function can handle the following events:
- feature.on_attach(env)
Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynamically.
- feature.on_detach(env)
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the env.
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(env, node):
WRITEME
- feature.on_setup_result(env, result):
WRITEME
"""
### Special ###
# TODO: document which things that features can do to the env
def __init__(self, inputs, outputs):
def __init__(self, inputs, outputs, features=[]):
"""
Create an Env which operates on the subgraph bound by the inputs and outputs
sets.
WRITEME
This class keeps a pointer to the inputs and outputs, and also modifies them.
#TODO: document what variables are[not] set in the env when a feature is added via the
constructor. How constructed is the env?
"""
self._features = []
......@@ -50,6 +94,11 @@ class Env(utils.object2):
self.results = set()
self.inputs = list(inputs)
self.outputs = outputs
for f in features:
self.extend(f)
for input in self.inputs:
if input.owner is not None:
raise ValueError("One of the provided inputs is the output of an already existing node. " \
......@@ -58,7 +107,6 @@ class Env(utils.object2):
self.results.add(input)
self.__import_r__(outputs)
self.outputs = outputs
for i, output in enumerate(outputs):
output.clients.append(('output', i))
......@@ -74,6 +122,7 @@ class Env(utils.object2):
raise Exception("%s is already owned by another env" % r)
r.env = self
r.clients = []
#self.execute_callbacks('on_setup_result', r)
def __setup_node__(self, node):
# sets up node so it belongs to this env
......@@ -81,6 +130,7 @@ class Env(utils.object2):
raise Exception("%s is already owned by another env" % node)
node.env = self
node.deps = {}
#self.execute_callbacks('on_setup_node', node)
def disown(self):
""" WRITEME
......@@ -171,6 +221,7 @@ class Env(utils.object2):
raise TypeError("An input of the graph was not provided and not given a value", r)
for node in new_nodes:
assert node not in self.nodes
self.__setup_node__(node)
self.nodes.add(node)
for output in node.outputs:
......@@ -284,29 +335,6 @@ class Env(utils.object2):
Adds a feature to this env. The feature may define one
or more of the following methods:
- feature.on_attach(env)
Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynamically.
- feature.on_detach(env)
Called by remove_feature(feature).
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
"""
if feature in self._features:
return # the feature is already present
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论