提交 ae25334d authored 作者: Olivier Breuleux's avatar Olivier Breuleux

checked and improved doc on env.py

上级 534a1953
......@@ -199,7 +199,39 @@ class _test_all(unittest.TestCase):
assert not g.consistent()
g.replace(tv, Sigmoid(x))
assert g.consistent()
def test_11(self):
x, y, z = inputs()
e1 = TransposeView(TransposeView(x))
e2 = TransposeView(TransposeView(e1))
e3 = AddInPlace(e2, y)
e4 = AddInPlace(e1, z)
g = env([x,y,z], [e3, e4], False)
assert not g.consistent()
g.replace(e2, TransposeView(x), False)
assert not g.consistent()
def test_12(self):
x, y, z = inputs()
e0 = AddInPlace(x, y)
e = Dot(Sigmoid(e0), TransposeView(x))
g = env([x,y,z], [e], False)
assert not g.consistent()
new_e0 = Add(x, y)
g.replace(e0, new_e0, False)
assert g.consistent()
g.replace(new_e0, AddInPlace(x, y), False)
assert not g.consistent()
def test_13(self):
x, y, z = inputs()
e0 = TransposeView(x)
e = Dot(Sigmoid(AddInPlace(x, y)), e0)
g = env([x,y,z], [e], False)
assert not g.consistent()
new_e0 = Add(e0, y)
g.replace(e0, new_e0, False)
assert g.consistent()
if __name__ == '__main__':
......
......@@ -53,11 +53,12 @@ class Env(graph.Graph):
L{Tool}.
Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are both
results that are the source nodes of computation. Those results that are
named as inputs will be assumed to contain fresh. In other words, the
backward search from outputs will stop at any node that has been explicitly
named as an input.
In the context of a computation graph, the inputs and orphans are
both results that are the source nodes of computation. Those
results that are named as inputs will be assumed to contain fresh data.
In other words, the backward search from outputs will stop at any
node that has been explicitly named as an input.
"""
### Special ###
......@@ -90,6 +91,8 @@ class Env(graph.Graph):
# Set of all the results that are not an output of an op in the subgraph but
# are an input of an op in the subgraph.
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
# We initialize them to the set of outputs; if an output depends on an input,
# it will be removed from the set of orphans.
self._orphans = set(outputs)
# Maps results to ops that use them:
......@@ -111,6 +114,7 @@ class Env(graph.Graph):
### Public interface ###
def add_output(self, output):
"Add an output to the Env."
self.outputs.add(output)
self.orphans.add(output)
self.__import_r__([output])
......@@ -138,10 +142,17 @@ class Env(graph.Graph):
return True
def satisfy(self, x):
"Adds the features required by x unless they are already present."
for feature_class in require_set(x):
self.add_feature(feature_class)
def add_feature(self, feature_class, do_import = True):
"""
Adds an instance of the feature_class to this env's supported
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Ops
already in the env.
"""
if feature_class in self._features:
return # the feature is already present
else:
......@@ -210,15 +221,17 @@ class Env(graph.Graph):
return op in self._ops
def orphans(self):
"""All results not within the subgraph bound by env.inputs and env.outputs, not in
env.inputs but required by some op."""
"""
All results not within the subgraph bound by env.inputs and
env.outputs, not in env.inputs but required by some op.
"""
return self._orphans
def replace(self, r, new_r, consistency_check = True):
"""
This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead.
This may raise a GofTypeError if the new result violates type
This may raise an error if the new result violates type
constraints for one of the target ops. In that case, no
changes are made.
......@@ -228,8 +241,8 @@ class Env(graph.Graph):
graph the way it was before the call to replace.
If consistency_check is False, the replacement will succeed
even if there is an inconsistency. A GofTypeError will still
be raised if there are type mismatches.
even if there is an inconsistency, unless the replacement
violates hard constraints on the types involved.
"""
self.__import_r_satisfy__([new_r])
......@@ -277,9 +290,10 @@ class Env(graph.Graph):
def replace_all(self, d):
"""
For (r, new_r) in d.items(), replaces r with new_r. Checks for consistency at the
end and raises an InconsistencyError if the graph is not consistent. If an error is
raised, the graph is restored to what it was before.
For (r, new_r) in d.items(), replaces r with new_r. Checks for
consistency at the end and raises an InconsistencyError if the
graph is not consistent. If an error is raised, the graph is
restored to what it was before.
"""
chk = self.checkpoint()
try:
......@@ -295,19 +309,29 @@ class Env(graph.Graph):
raise
def results(self):
"All results within the subgraph bound by env.inputs and env.outputs and including them"
"""
All results within the subgraph bound by env.inputs and
env.outputs and including them
"""
return self._results
def revert(self, checkpoint):
"""
Reverts the graph to whatever it was at the provided checkpoint (undoes all replacements).
A checkpoint at any given time can be obtained using self.checkpoint().
Reverts the graph to whatever it was at the provided
checkpoint (undoes all replacements). A checkpoint at any
given time can be obtained using self.checkpoint().
"""
while len(self.history) > checkpoint:
f = self.history.pop()
f()
def supplemental_orderings(self):
"""
Returns a dictionary of {op: set(prerequisites)} that must
be satisfied in addition to the order defined by the structure
of the graph (returns orderings that are not related to input/output
relationships).
"""
ords = {}
for ordering in self._orderings.values():
for op, prereqs in ordering.orderings().items():
......@@ -316,14 +340,18 @@ class Env(graph.Graph):
def toposort(self):
"""
Returns a list of ops in the order that they must be executed in order to preserve
the semantics of the graph and respect the constraints put forward by the listeners.
Returns a list of ops in the order that they must be executed
in order to preserve the semantics of the graph and respect
the constraints put forward by the listeners.
"""
ords = self.supplemental_orderings()
order = graph.io_toposort(self.inputs, self.outputs, ords)
return order
def validate(self):
"""
Raises an error if the graph is inconsistent.
"""
for constraint in self._constraints.values():
constraint.validate()
return True
......@@ -332,9 +360,21 @@ class Env(graph.Graph):
### Private interface ###
def __add_clients__(self, r, all):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
Updates the list of clients of r with all.
"""
self._clients.setdefault(r, set()).update(all)
def __remove_clients__(self, r, all):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
Removes all from the clients list of r.
"""
if not all:
return
self._clients[r].difference_update(all)
......@@ -344,11 +384,12 @@ class Env(graph.Graph):
self._orphans.remove(r)
def __import_r_satisfy__(self, results):
# Satisfies the owners of the results.
for op in graph.ops(self.results(), results):
self.satisfy(op)
def __import_r__(self, results):
# Imports the owners of the results
for result in results:
owner = result.owner
if owner:
......@@ -385,6 +426,7 @@ class Env(graph.Graph):
__import__.E_output = 'op output in Env.inputs'
def __prune_r__(self, results):
# Prunes the owners of the results.
for result in set(results):
if result in self.inputs:
continue
......@@ -393,6 +435,10 @@ class Env(graph.Graph):
self.__prune__(owner)
def __prune__(self, op):
# If op's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in op.outputs:
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
......
......@@ -14,21 +14,69 @@ __all__ = ['Destroyer',
class DestroyHandler(Listener, Constraint, Orderings, Tool):
"""
This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over
some of their inputs. It does so by tracking dependencies between
data at different stages of the graph and ensuring that
destructive operations are performed after the destroyed data and
all of its views have been processed.
Examples:
* (x += 1) + (x += 1) -> fails because the first += makes the second
invalid
* x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
* (a += b) + (c += a) -> succeeds but we have to do c += a first
* (a += b) + (b += c) + (c += a) -> fails because there's a cyclical
dependency (no possible ordering)
This feature allows some optimizations (eg sub += for +) to be applied
safely.
"""
def __init__(self, env):
# For an Op that has a view_map, {output : input it is a view of}
self.parent = {}
# Reverse mapping of parent: {input : outputs that are a view of it}
self.children = {}
# {foundation : {op that destroys it : path }}
# where foundation is a result such that (not self.parent[result])
# and path is a sequence of results such that:
# * path[0] == foundation
# * self.parent[path[i]] == path[i-1]
# * path[-1] == output of the Op that is the Destroyer
self.destroyers = {}
# Cache for the paths
self.paths = {}
### if any of dups, cycles or illegal is not empty, the env is inconsistent
# Set of results that are destroyed more than once.
self.dups = set()
# Set of sequences of results that represent a dependency cycle, i.e.
# [a, ... b, ... c, ... a] if our graph is ((a += b) + (b += c) + (c += a))
self.cycles = set()
# Set of results that have one Op that destroys them but have been marked
# indestructible by the user.
self.illegal = set()
self.env = env
self.seen = set()
# Initialize the children of the inputs and orphans.
for input in env.orphans().union(env.inputs):
self.children[input] = set()
def publish(self):
"""
Publishes the following on the env:
* destroyers(r) -> returns all Ops that destroy the result r
* destroy_handler -> self
"""
def __destroyers(r):
ret = self.destroyers.get(r, {})
ret = ret.keys()
......@@ -37,6 +85,13 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.env.destroy_handler = self
def __path__(self, r):
"""
Returns a path from r to the result that it is ultimately
a view of, i.e. path such that:
* path[-1] == r
* path[i] == parent[path[i+1]]
* parent[path[0]] == None
"""
path = self.paths.get(r, None)
if path:
return path
......@@ -51,6 +106,10 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return rval
def __views__(self, r):
"""
Returns the set of results (including r itself) such that all the
results in the set are views of r, directly or indirectly.
"""
children = self.children[r]
if not children:
return set([r])
......@@ -61,6 +120,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return rval
def __users__(self, r):
"""
Returns the outputs of all the ops that use r or a view
of r. In other words, for all ops that have an input that
is r or a view of r, adds their outputs to the set that
is returned.
"""
views = self.__views__(r)
rval = set()
for view in views:
......@@ -70,23 +135,40 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return rval
def __pre__(self, op):
"""
Returns all results that must be computed prior to computing
this op.
"""
rval = set()
if op is None:
return rval
keep_going = False
for input in op.inputs:
# Get the basic result the input is a view of.
foundation = self.__path__(input)[0]
destroyers = self.destroyers.get(foundation, set())
if destroyers:
keep_going = True
# Is this op destroying the foundation? If yes,
# all users of the foundation must be computed before
# we overwrite its contents.
if op in destroyers:
users = self.__users__(foundation)
rval.update(users)
rval.update(op.inputs)
rval.difference_update(op.outputs)
rval.update(op.inputs) # obviously
rval.difference_update(op.outputs) # this op's outputs will always be in the users
return rval
def __detect_cycles_helper__(self, r, seq):
"""
Does a depth-first search to find cycles in the graph of
computation given a directed connection from an op to
its __pre__ set.
* seq -> sequence of nodes visited up to now
* r -> current node
If r is found in seq, we have a cycle and it is added to
the set of cycles.
"""
if r in seq:
self.cycles.add(tuple(seq[seq.index(r):]))
return
......@@ -95,6 +177,13 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.__detect_cycles_helper__(r2, seq + [r])
def __detect_cycles__(self, start, just_remove=False):
"""
Tries to find a cycle containing any of the users of
start. Prior to doing so, we remove all existing cycles
containing a user of start from the cycles set. If
just_remove is True, we return immediately after removing the
cycles.
"""
users = self.__users__(start)
users.add(start)
for user in users:
......@@ -107,6 +196,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.__detect_cycles_helper__(user, [])
def get_maps(self, op):
"""
Returns vmap, dmap where:
* vmap -> {output : [inputs output is a view of]}
* dmap -> {output : [inputs that are destroyed by the Op
(and presumably returned as that output)]}
"""
try: vmap = op.view_map()
except AttributeError, AbstractFunctionError: vmap = {}
try: dmap = op.destroy_map()
......@@ -114,6 +209,10 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return vmap, dmap
def on_import(self, op):
"""
Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env.
"""
self.seen.add(op)
view_map, destroy_map = self.get_maps(op)
......@@ -121,7 +220,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
for i, output in enumerate(op.outputs):
views = view_map.get(output, None)
destroyed = destroy_map.get(output, None)
if destroyed:
for input in destroyed:
path = self.__path__(input)
......@@ -129,6 +228,8 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
elif views:
if len(views) > 1:
# This is a limitation of DestroyHandler
# TODO: lift it (requires changes everywhere)
raise Exception("Output is a view of too many inputs.")
self.parent[output] = views[0]
for input in views:
......@@ -137,13 +238,26 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.children[output] = set()
for output in op.outputs:
self.__detect_cycles__(output)
# output has no users and is not in any cycle because it
# is new. We must however check for cycles from the output
# eg if we are importing F in F(a += b, a) we will obtain
# the following cycle: [F.out, +=.out, F.out] because __pre__
of +=.out, since it is destructive, must contain all the
# users of a including F.out. A cycle not involving F.out
# cannot occur.
self.__detect_cycles_helper__(output, [])
def on_prune(self, op):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op from the env.
"""
view_map, destroy_map = self.get_maps(op)
if destroy_map:
# Clean up self.destroyers considering that this op is gone.
destroyers = []
for i, input in enumerate(op.inputs):
destroyers.append(self.destroyers.get(self.__path__(input)[0], {}))
......@@ -153,6 +267,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.__remove_destroyer__(path)
if view_map:
# Clean the children of the inputs if this Op was a view of any of them.
for i, input in enumerate(op.inputs):
self.children[input].difference_update(op.outputs)
......@@ -161,8 +276,13 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
del self.paths[output]
except:
pass
# True means that we are just removing cycles pertaining to this output
# including cycles involving the users of the output (since there should
# be no more users after the op is pruned).
# No new cycles can be added by removing a node.
self.__detect_cycles__(output, True)
# Clean up parents and children
for i, output in enumerate(op.outputs):
try:
self.parent[output]
......@@ -175,6 +295,10 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
def __add_destroyer__(self, path):
"""
Processes the information that path[0] is destroyed by path[-1].owner.
"""
foundation = path[0]
target = path[-1]
......@@ -186,11 +310,16 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
if len(destroyers) > 1:
self.dups.add(foundation)
# results marked 'indestructible' must not be destroyed.
if getattr(foundation, 'indestructible', False):
self.illegal.add(foundation)
def __remove_destroyer__(self, path):
"""
Processes the information that path[0] is no longer destroyed by path[-1].owner.
"""
foundation = path[0]
target = path[-1]
op = target.owner
......@@ -207,25 +336,39 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
def on_rewire(self, clients, r_1, r_2):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is
now r_2.
"""
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
# All the affected results one level below the replacement.
prev = set()
for op, i in clients:
prev.update(op.outputs)
# Here we look at what destroys r_1, directly or indirectly. Since we
# replace r_1, we must adjust the destroyers. Each destroyer has a path,
# as described in __path__ and __add_destroyer__. Here is the logic to
# adjust a path that contains r_1 at index idx and r_prev at index idx+1.
# * idx == len(path)-1: do nothing
# * r_prev not in prev: do nothing
# * else: concatenate path_2 to the part of the path before r_1.
foundation = path_1[0]
destroyers = self.destroyers.get(foundation, {}).items()
for op, path in destroyers:
if r_1 in path:
idx = path.index(r_1)
self.__remove_destroyer__(path)
if not (idx > 0 and path[idx - 1] in prev):
if idx == len(path)-1 or path[idx+1] not in prev:
continue
index = path.index(r_1)
new_path = path_2 + path[index+1:]
self.__add_destroyer__(new_path)
self.__remove_destroyer__(path)
self.__add_destroyer__(path_2 + path[idx+1:])
# Clean up parents and children
for op, i in clients:
view_map, _ = self.get_maps(op)
for output, inputs in view_map.items():
......@@ -245,10 +388,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
except:
pass
self.__detect_cycles__(r_1)
# Recompute the cycles from both r_1 and r_2.
self.__detect_cycles__(r_1) # we should really just remove the cycles that have r_1 and a result in prev just before
self.__detect_cycles__(r_2)
def validate(self):
"""
Raises an InconsistencyError on any of the following conditions:
* Some results are destroyed by more than one Op
* There is a cycle of preconditions
* An Op attempts to destroy an indestructible result.
"""
if self.dups:
raise InconsistencyError("The following values are destroyed more than once: %s" % self.dups)
elif self.cycles:
......@@ -259,6 +409,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return True
def orderings(self):
"""
Returns a dict of {op : set(ops that must be computed before it)} according
to DestroyHandler.
In particular, all the users of a destroyed result have priority over the
op that destroys the result.
"""
ords = {}
for foundation, destroyers in self.destroyers.items():
for op in destroyers.keys():
......@@ -267,11 +423,23 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
class Destroyer:
"""
Base class for Ops that destroy one or more of their inputs in an
inplace operation, use them as temporary storage, puts garbage in
them or anything else that invalidates the contents for use by other
Ops.
"""
def destroyed_inputs(self):
raise AbstractFunctionError()
def destroy_map(self):
"""
Returns the map {output: [list of destroyed inputs]}
While it typically means that the storage of the output is
shared with each of the destroyed inputs, it does not necessarily
have to be the case.
"""
# compatibility
return {self.out: self.destroyed_inputs()}
......@@ -280,27 +448,41 @@ class Destroyer:
class Viewer:
"""
Base class for Ops that return one or more views over one or more inputs,
which means that the inputs and outputs share their storage. Unless it also
extends Destroyer, this Op does not modify the storage in any way and thus
the input is safe for use by other Ops even after executing this one.
"""
def view_map(self):
"""
Returns the map {output: [list of viewed inputs]}
It means that the output shares storage with each of the inputs
in the list.
Note: support for more than one viewed input is minimal, but
this might improve in the future.
"""
raise AbstractFunctionError()
def view_roots(self, output):
def helper(r):
"""Return the leaves of a search through consecutive view_map()s"""
owner = r.owner
if owner is not None:
try:
view_map = owner.view_map()
except AttributeError, AbstractFunctionError:
return []
if r in view_map:
answer = []
for r2 in view_map[r]:
answer.extend(helper(r2))
return answer
else:
return [r]
def view_roots(self, r):
"""
Utility function that returns the leaves of a search through
consecutive view_map()s.
"""
owner = r.owner
if owner is not None:
try:
view_map = owner.view_map()
except AttributeError, AbstractFunctionError:
return []
if r in view_map:
answer = []
for r2 in view_map[r]:
answer.extend(helper(r2))
return answer
else:
return [r]
return helper(output)
else:
return [r]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论