提交 1ab81f84 authored 作者: James Bergstra's avatar James Bergstra

commenting and minor refactors to destroyhandler, function_module, env

上级 60e767c2
"""Convenient driver of graph construction, optimization, and linking.""" """Driver of graph construction, optimization, and linking.
"""
import copy_reg import copy_reg
import cPickle import cPickle
...@@ -102,6 +104,15 @@ DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate ...@@ -102,6 +104,15 @@ DUPLICATE = ['DUPLICATE'] # unique id object used as a placeholder for duplicate
class Function(object): class Function(object):
""" """
Type of the functions returned by theano.function or theano.FunctionMaker.create. Type of the functions returned by theano.function or theano.FunctionMaker.create.
`Function` is the callable object that does computation. It has the storage of inputs and
outputs, performs the packing and unpacking of inputs and return values. It implements the
square-bracket indexing so that you can look up the value of a symbolic node.
When a function is copied, this instance is duplicated. Contrast with self.maker
(instance of `FunctionMaker`) that is shared between copies.
""" """
def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker): def __init__(self, fn, input_storage, output_storage, indices, outputs, defaults, unpack_single, maker):
...@@ -394,6 +405,13 @@ class SanityCheckFunction(Function): ...@@ -394,6 +405,13 @@ class SanityCheckFunction(Function):
NODEFAULT = ['NODEFAULT'] NODEFAULT = ['NODEFAULT']
class FunctionMaker(object): class FunctionMaker(object):
"""`FunctionMaker` is the class to `create` `Function` instances.
This class has the env, the optimizer, and the linker. When copying a `Function`, there is
no need to duplicate the `FunctionMaker` instance. Deepcopy still copies both, which can
result in re-compilation.
"""
@staticmethod @staticmethod
def wrap_in(input): def wrap_in(input):
...@@ -432,18 +450,20 @@ class FunctionMaker(object): ...@@ -432,18 +450,20 @@ class FunctionMaker(object):
else: else:
raise TypeError("Unknown output type: %s (%s)", type(output), output) raise TypeError("Unknown output type: %s (%s)", type(output), output)
def __init__(self, inputs, outputs, mode = 'FAST_RUN', accept_inplace = False, function_builder = Function): def __init__(self, inputs, outputs,
mode = 'FAST_COMPILE', accept_inplace = False, function_builder = Function):
""" """
Create a FunctionMaker for the specified inputs, outputs and mode. :type inputs: a list of SymbolicInput instances
@param inputs: a list of SymbolicInput instances :type outputs: a list of SymbolicOutput instances
@param outputs: a list of SymbolicOutput instances outputs may also be a single Result (not a list), in which
outputs may also be a single Result (not a list), in which case the functions produced by FunctionMaker will return
case the functions produced by FunctionMaker will return their output value directly
their output value directly
@param mode: a Mode instance telling FunctionMaker how to optimize and link :param mode: a Mode instance telling FunctionMaker how to optimize and link
@param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs :param accept_inplace: True iff it is acceptable to have inplace operations
in the graph from the inputs to the outputs
""" """
...@@ -648,19 +668,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): ...@@ -648,19 +668,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
optimization phase (default is False) optimization phase (default is False)
Every element of the input list will be upgraded to an `In` instance if necessary, Every element of the input list will be upgraded to an `In` instance if necessary,
using the following rules: using the rules implemented by the `convert_function_input` function.
- a `Result` instance r will be upgraded like `In`(r)
- a tuple (name, r) will be `In`(r, name=name)
- a tuple (r, val) will be `In`(r, value=value, autoname=True)
- a tuple ((r,up), val) will be `In`(r, value=value, update=up, autoname=True)
- a tuple (name, r, val) will be `In`(r, name=name, value=value)
- a tuple (name, (r,up), val) will be `In`(r, name=name, value=val, update=up, autoname=True)
Similarly, every element of the output list will be upgraded to an Similarly, every element of the output list will be upgraded to an
`Out` instance if necessary: `Out` instance if necessary:
...@@ -681,54 +689,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): ...@@ -681,54 +689,7 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
""" """
def wrap_in(input): inputs = map(convert_function_input, inputs)
if isinstance(input, (SymbolicInput, SymbolicInputKit)):
return input
elif isinstance(input, gof.Result):
return In(input)
elif isinstance(input, (list, tuple)):
orig = input
if not input:
raise TypeError("Nonsensical input specification: %s" % input)
if isinstance(input[0], str):
name = input[0]
input = input[1:]
else:
name = None
if isinstance(input[0], (list, tuple)):
if len(input[0]) != 2 or len(input) != 2:
raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
(result, update), value = input
elif isinstance(input[0], gof.Result):
if len(input) == 1:
result, update, value = input[0], None, None
elif len(input) == 2:
(result, value), update = input, None
else:
raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % orig)
elif isinstance(input[0], (SymbolicInput, SymbolicInputKit)):
if len(input) == 1:
return input[0]
elif len(input) == 2:
input, value = input
if name is not None: input.name = name
input.value = value
return input
else:
raise TypeError("The input specification is not valid: %s" % input)
if not isinstance(result, gof.Result):
raise TypeError("Unknown input type: %s, expected Result instance" % type(result), result)
if update is not None and not isinstance(update, gof.Result):
raise TypeError("Unknown update type: %s, expected Result instance" % type(update), update)
if value is not None and isinstance(value, (gof.Result, SymbolicInput)):
raise TypeError("The value for input %s should not be a Result or SymbolicInput instance (got: %s)" % (result, value))
return In(result, name=name, value=value, update=update)
else:
raise TypeError("Unknown input type: %s, expected Result instance" % type(input), input)
inputs = map(wrap_in, inputs)
outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs) outputs = map(FunctionMaker.wrap_out, outputs) if isinstance(outputs, (list, tuple)) else FunctionMaker.wrap_out(outputs)
defaults = [getattr(input, 'value', None) for input in inputs] defaults = [getattr(input, 'value', None) for input in inputs]
...@@ -750,9 +711,76 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False): ...@@ -750,9 +711,76 @@ def function(inputs, outputs, mode='FAST_RUN', accept_inplace = False):
maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder) maker1 = FunctionMaker(inputs, outputs, mode[0], accept_inplace = accept_inplace, function_builder = builder)
fn = maker1.create(defaults) fn = maker1.create(defaults)
else: else:
fn = FunctionMaker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults) Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs, outputs, mode, accept_inplace = accept_inplace).create(defaults)
return fn return fn
def convert_function_input(input):
    """
    Upgrade an input shortcut to an `In` instance.

    The rules for upgrading are as follows:

    - a `SymbolicInput` or `SymbolicInputKit` instance is returned unchanged
    - a `Result` instance r will be upgraded like `In`(r)
    - a tuple (name, r) will be `In`(r, name=name)
    - a tuple (r, val) will be `In`(r, value=val)
    - a tuple ((r,up), val) will be `In`(r, value=val, update=up)
    - a tuple (name, r, val) will be `In`(r, name=name, value=val)
    - a tuple (name, (r,up), val) will be `In`(r, name=name, value=val, update=up)
    - a tuple (i,) or (name, i) where i is a `SymbolicInput`/`SymbolicInputKit`
      returns i itself (the name is not applied in this form)
    - a tuple (i, val) or (name, i, val) sets i.value (and i.name when a name
      is given) on the existing instance *in place* and returns it

    Lists are accepted wherever tuples are.

    NOTE(review): earlier documentation advertised `autoname=True` for several
    of these forms, but this function never passes `autoname` to `In`; confirm
    which of the two is intended.

    :param input: the shortcut specification to upgrade (see rules above)

    :returns: an `In` instance, or the `SymbolicInput`/`SymbolicInputKit` that
        was supplied

    :raises TypeError: if `input` does not match any of the forms above, or if
        the result, update or value has an unacceptable type
    """
    # Fully-specified inputs need no upgrading.
    if isinstance(input, (SymbolicInput, SymbolicInputKit)):
        return input
    if isinstance(input, gof.Result):
        return In(input)
    if not isinstance(input, (list, tuple)):
        raise TypeError("Unknown input type: %s, expected Result instance" % type(input), input)

    # Keep the user's original specification for error messages.  Every
    # message below formats its operand wrapped in a 1-tuple: applying `%` to a
    # bare tuple would spread its items over the format string and fail with an
    # unrelated formatting error instead of the intended message.
    orig = input
    if not input:
        raise TypeError("Nonsensical input specification: %s" % (orig,))

    # An optional leading string is the input's name.
    if isinstance(input[0], str):
        name = input[0]
        input = input[1:]
        if not input:
            # Only a name was given; there is nothing to build an input from.
            raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % (orig,))
    else:
        name = None

    head = input[0]
    if isinstance(head, (list, tuple)):
        # ((r, up), val)
        if len(head) != 2 or len(input) != 2:
            raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % (orig,))
        (result, update), value = input
    elif isinstance(head, gof.Result):
        # (r,) or (r, val)
        update = None
        if len(input) == 1:
            result, value = head, None
        elif len(input) == 2:
            result, value = input
        else:
            raise TypeError("Invalid input syntax: %s (check documentation or use an In instance)" % (orig,))
    elif isinstance(head, (SymbolicInput, SymbolicInputKit)):
        if len(input) == 1:
            # NOTE(review): a name supplied in this form is silently dropped.
            return head
        if len(input) == 2:
            symbolic, value = input
            # The existing instance is modified in place, not copied.
            if name is not None:
                symbolic.name = name
            symbolic.value = value
            return symbolic
        # Too many elements: reject explicitly rather than falling through
        # to the checks below with `result` undefined.
        raise TypeError("The input specification is not valid: %s" % (orig,))
    else:
        # First element is none of the recognised kinds.
        raise TypeError("The input specification is not valid: %s" % (orig,))

    if not isinstance(result, gof.Result):
        raise TypeError("Unknown input type: %s, expected Result instance" % type(result), result)
    if update is not None and not isinstance(update, gof.Result):
        raise TypeError("Unknown update type: %s, expected Result instance" % type(update), update)
    # A default value must be concrete data, not another symbolic object.
    if value is not None and isinstance(value, (gof.Result, SymbolicInput)):
        raise TypeError("The value for input %s should not be a Result or SymbolicInput instance (got: %s)" % (result, value))
    return In(result, name=name, value=value, update=update)
...@@ -10,14 +10,15 @@ class ProtocolError(Exception): ...@@ -10,14 +10,15 @@ class ProtocolError(Exception):
"""WRITEME""" """WRITEME"""
pass pass
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(object):
"""WRITEME""" """WRITEME"""
def __init__(self): def __init__(self, do_imports_on_attach=True):
self.map = {} self.map = {}
self.do_imports_on_attach=do_imports_on_attach
def on_attach(self, env): def on_attach(self, env):
dh = self.map.setdefault(env, DestroyHandlerHelper2()) dh = self.map.setdefault(env, DestroyHandlerHelper2(do_imports_on_attach=self.do_imports_on_attach))
dh.on_attach(env) dh.on_attach(env)
def on_detach(self, env): def on_detach(self, env):
...@@ -69,8 +70,9 @@ def get_impact(root, view_o): ...@@ -69,8 +70,9 @@ def get_impact(root, view_o):
class DestroyHandlerHelper2(toolbox.Bookkeeper): class DestroyHandlerHelper2(toolbox.Bookkeeper):
"""WRITEME""" """WRITEME"""
def __init__(self): def __init__(self, do_imports_on_attach=True):
self.env = None self.env = None
self.do_imports_on_attach = do_imports_on_attach
def on_attach(self, env): def on_attach(self, env):
#boilerplate from old implementation #boilerplate from old implementation
...@@ -99,7 +101,8 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper): ...@@ -99,7 +101,8 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
self.stale_droot = True self.stale_droot = True
self.debug_all_apps = set() self.debug_all_apps = set()
toolbox.Bookkeeper.on_attach(self, env) if self.do_imports_on_attach:
toolbox.Bookkeeper.on_attach(self, env)
def refresh_droot_impact(self): def refresh_droot_impact(self):
if self.stale_droot: if self.stale_droot:
...@@ -153,6 +156,7 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper): ...@@ -153,6 +156,7 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
if app in self.debug_all_apps: raise ProtocolError("double import") if app in self.debug_all_apps: raise ProtocolError("double import")
self.debug_all_apps.add(app) self.debug_all_apps.add(app)
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list # If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', {}):
......
...@@ -29,16 +29,60 @@ class Env(utils.object2): ...@@ -29,16 +29,60 @@ class Env(utils.object2):
It can also be "extended" using env.extend(some_object). See the It can also be "extended" using env.extend(some_object). See the
toolbox and ext modules for common extensions. toolbox and ext modules for common extensions.
Features added with the `extend` function can handle the following events:
- feature.on_attach(env)
Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynamically.
- feature.on_detach(env)
Called by remove_feature(feature). Should remove any dynamically-added
functionality that it installed into the env.
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
- feature.on_setup_node(env, node):
WRITEME
- feature.on_setup_result(env, result):
WRITEME
""" """
### Special ### ### Special ###
# TODO: document which things that features can do to the env
def __init__(self, inputs, outputs): def __init__(self, inputs, outputs, features=[]):
""" """
Create an Env which operates on the subgraph bound by the inputs and outputs Create an Env which operates on the subgraph bound by the inputs and outputs
sets. sets.
WRITEME This class keeps a pointer to the inputs and outputs, and also modifies them.
#TODO: document what variables are[not] set in the env when a feature is added via the
constructor. How constructed is the env?
""" """
self._features = [] self._features = []
...@@ -50,6 +94,11 @@ class Env(utils.object2): ...@@ -50,6 +94,11 @@ class Env(utils.object2):
self.results = set() self.results = set()
self.inputs = list(inputs) self.inputs = list(inputs)
self.outputs = outputs
for f in features:
self.extend(f)
for input in self.inputs: for input in self.inputs:
if input.owner is not None: if input.owner is not None:
raise ValueError("One of the provided inputs is the output of an already existing node. " \ raise ValueError("One of the provided inputs is the output of an already existing node. " \
...@@ -58,7 +107,6 @@ class Env(utils.object2): ...@@ -58,7 +107,6 @@ class Env(utils.object2):
self.results.add(input) self.results.add(input)
self.__import_r__(outputs) self.__import_r__(outputs)
self.outputs = outputs
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
output.clients.append(('output', i)) output.clients.append(('output', i))
...@@ -74,6 +122,7 @@ class Env(utils.object2): ...@@ -74,6 +122,7 @@ class Env(utils.object2):
raise Exception("%s is already owned by another env" % r) raise Exception("%s is already owned by another env" % r)
r.env = self r.env = self
r.clients = [] r.clients = []
#self.execute_callbacks('on_setup_result', r)
def __setup_node__(self, node): def __setup_node__(self, node):
# sets up node so it belongs to this env # sets up node so it belongs to this env
...@@ -81,6 +130,7 @@ class Env(utils.object2): ...@@ -81,6 +130,7 @@ class Env(utils.object2):
raise Exception("%s is already owned by another env" % node) raise Exception("%s is already owned by another env" % node)
node.env = self node.env = self
node.deps = {} node.deps = {}
#self.execute_callbacks('on_setup_node', node)
def disown(self): def disown(self):
""" WRITEME """ WRITEME
...@@ -171,6 +221,7 @@ class Env(utils.object2): ...@@ -171,6 +221,7 @@ class Env(utils.object2):
raise TypeError("An input of the graph was not provided and not given a value", r) raise TypeError("An input of the graph was not provided and not given a value", r)
for node in new_nodes: for node in new_nodes:
assert node not in self.nodes
self.__setup_node__(node) self.__setup_node__(node)
self.nodes.add(node) self.nodes.add(node)
for output in node.outputs: for output in node.outputs:
...@@ -284,29 +335,6 @@ class Env(utils.object2): ...@@ -284,29 +335,6 @@ class Env(utils.object2):
Adds a feature to this env. The feature may define one Adds a feature to this env. The feature may define one
or more of the following methods: or more of the following methods:
- feature.on_attach(env)
Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynicamically.
- feature.on_detach(env)
Called by remove_feature(feature).
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
""" """
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论