提交 14090d83 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documentation and cleanup

上级 abd860d7
...@@ -4,11 +4,15 @@ import unittest ...@@ -4,11 +4,15 @@ import unittest
from link import PerformLinker, Profiler from link import PerformLinker, Profiler
from cc import * from cc import *
from type import Type from type import Type
from graph import Result, as_result, Apply, Constant from graph import Result, Apply, Constant
from op import Op from op import Op
import env import env
import toolbox import toolbox
def as_result(x):
assert isinstance(x, Result)
return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
......
...@@ -3,18 +3,20 @@ import unittest ...@@ -3,18 +3,20 @@ import unittest
from type import Type from type import Type
import graph import graph
from graph import Result, as_result, Apply from graph import Result, Apply
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import PatternOptimizer, OpSubOptimizer
from ext import * from ext import *
from env import Env, InconsistencyError from env import Env, InconsistencyError
#from toolbox import EquivTool
from toolbox import ReplaceValidate from toolbox import ReplaceValidate
from copy import copy from copy import copy
#from _test_result import MyResult
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
...@@ -15,6 +15,10 @@ else: ...@@ -15,6 +15,10 @@ else:
realtestcase = unittest.TestCase realtestcase = unittest.TestCase
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import unittest import unittest
import graph import graph
from graph import Result, as_result, Apply, Constant from graph import Result, Apply, Constant
from type import Type from type import Type
from op import Op from op import Op
import env import env
...@@ -14,6 +14,10 @@ from link import * ...@@ -14,6 +14,10 @@ from link import *
#from _test_result import Double #from _test_result import Double
def as_result(x):
assert isinstance(x, Result)
return x
class TDouble(Type): class TDouble(Type):
def filter(self, data): def filter(self, data):
return float(data) return float(data)
......
...@@ -3,9 +3,11 @@ import unittest ...@@ -3,9 +3,11 @@ import unittest
from copy import copy from copy import copy
from op import * from op import *
from type import Type, Generic from type import Type, Generic
from graph import Apply, as_result from graph import Apply, Result
#from result import Result def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
import unittest import unittest
from graph import Result, as_result, Apply, Constant from type import Type
from graph import Result, Apply, Constant
from op import Op from op import Op
from opt import * from opt import *
from env import Env from env import Env
from toolbox import * from toolbox import *
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
......
import unittest import unittest
from graph import Result, as_result, Apply from graph import Result, Apply
from type import Type from type import Type
from op import Op from op import Op
#from opt import PatternOptimizer, OpSubOptimizer
from env import Env, InconsistencyError from env import Env, InconsistencyError
from toolbox import * from toolbox import *
def as_result(x):
assert isinstance(x, Result)
return x
class MyType(Type): class MyType(Type):
def __init__(self, name): def __init__(self, name):
......
差异被折叠。
from copy import copy from copy import copy
import graph import graph
##from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils import utils
from utils import AbstractFunctionError
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception is raised by Env whenever one of the listeners marks This exception should be thrown by listeners to Env when the
the graph as inconsistent. graph's state is invalid.
""" """
pass pass
class Env(object): #(graph.Graph): class Env(utils.object2):
""" """
An Env represents a subgraph bound by a set of input results and a An Env represents a subgraph bound by a set of input results and a
set of output results. An op is in the subgraph iff it depends on set of output results. The inputs list should contain all the inputs
the value of some of the Env's inputs _and_ some of the Env's on which the outputs depend. Results of type Value or Constant are
outputs depend on it. A result is in the subgraph iff it is an not counted as inputs.
input or an output of an op that is in the subgraph.
The Env supports the replace operation which allows to replace a The Env supports the replace operation which allows to replace a
result in the subgraph by another, e.g. replace (x + x).out by (2 result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in theano. * x).out. This is the basis for optimization in theano.
It can also be "extended" using env.extend(some_object). See the
toolbox and ext modules for common extensions.
""" """
### Special ### ### Special ###
...@@ -65,12 +64,14 @@ class Env(object): #(graph.Graph): ...@@ -65,12 +64,14 @@ class Env(object): #(graph.Graph):
### Setup a Result ### ### Setup a Result ###
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this env
if hasattr(r, 'env') and r.env is not None and r.env is not self: if hasattr(r, 'env') and r.env is not None and r.env is not self:
raise Exception("%s is already owned by another env" % r) raise Exception("%s is already owned by another env" % r)
r.env = self r.env = self
r.clients = [] r.clients = []
def __setup_node__(self, node): def __setup_node__(self, node):
# sets up node so it belongs to this env
if hasattr(node, 'env') and node.env is not self: if hasattr(node, 'env') and node.env is not self:
raise Exception("%s is already owned by another env" % node) raise Exception("%s is already owned by another env" % node)
node.env = self node.env = self
...@@ -80,28 +81,27 @@ class Env(object): #(graph.Graph): ...@@ -80,28 +81,27 @@ class Env(object): #(graph.Graph):
### clients ### ### clients ###
def clients(self, r): def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r." "Set of all the (node, i) pairs such that node.inputs[i] is r."
return r.clients return r.clients
def __add_clients__(self, r, all): def __add_clients__(self, r, new_clients):
""" """
r -> result r -> result
all -> list of (op, i) pairs representing who r is an input of. new_clients -> list of (node, i) pairs such that node.inputs[i] is r.
Updates the list of clients of r with all. Updates the list of clients of r with new_clients.
""" """
r.clients += all r.clients += new_clients
def __remove_clients__(self, r, all, prune = True): def __remove_clients__(self, r, clients_to_remove, prune = True):
""" """
r -> result r -> result
all -> list of (op, i) pairs representing who r is an input of. clients_to_remove -> list of (op, i) pairs such that node.inputs[i] is not r anymore.
Removes all from the clients list of r. Removes all from the clients list of r.
""" """
for entry in all: for entry in clients_to_remove:
r.clients.remove(entry) r.clients.remove(entry)
# remove from orphans?
if not r.clients: if not r.clients:
if prune: if prune:
self.__prune_r__([r]) self.__prune_r__([r])
...@@ -188,6 +188,15 @@ class Env(object): #(graph.Graph): ...@@ -188,6 +188,15 @@ class Env(object): #(graph.Graph):
### change input ### ### change input ###
def change_input(self, node, i, new_r): def change_input(self, node, i, new_r):
"""
Changes node.inputs[i] to new_r.
new_r.type == old_r.type must be True, where old_r is the
current value of node.inputs[i] which we want to replace.
For each feature that has a 'on_change_input' method, calls:
feature.on_change_input(env, node, i, old_r, new_r)
"""
if node == 'output': if node == 'output':
r = self.outputs[i] r = self.outputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
...@@ -214,10 +223,7 @@ class Env(object): #(graph.Graph): ...@@ -214,10 +223,7 @@ class Env(object): #(graph.Graph):
def replace(self, r, new_r): def replace(self, r, new_r):
""" """
This is the main interface to manipulate the subgraph in Env. This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead. For every node that uses r as input, makes it use new_r instead.
This may raise an error if the new result violates type
constraints for one of the target nodes. In that case, no
changes are made.
""" """
if r.env is not self: if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r) raise Exception("Cannot replace %s because it does not belong to this Env" % r)
...@@ -238,11 +244,32 @@ class Env(object): #(graph.Graph): ...@@ -238,11 +244,32 @@ class Env(object): #(graph.Graph):
def extend(self, feature): def extend(self, feature):
""" """
@todo out of date Adds a feature to this env. The feature may define one
Adds an instance of the feature_class to this env's supported or more of the following methods:
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Nodes - feature.on_attach(env)
already in the env. Called by extend. The feature has great freedom in what
it can do with the env: it may, for example, add methods
to it dynicamically.
- feature.on_detach(env)
Called by remove_feature(feature).
- feature.on_import(env, node)*
Called whenever a node is imported into env, which is
just before the node is actually connected to the graph.
- feature.on_prune(env, node)*
Called whenever a node is pruned (removed) from the env,
after it is disconnected from the graph.
- feature.on_change_input(env, node, i, r, new_r)*
Called whenever node.inputs[i] is changed from r to new_r.
At the moment the callback is done, the change has already
taken place.
- feature.orderings(env)
Called by toposort. It should return a dictionary of
{node: predecessors} where predecessors is a list of
nodes that should be computed before the key node.
* If you raise an exception in the functions marked with an
asterisk, the state of the graph might be inconsistent.
""" """
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
...@@ -256,6 +283,11 @@ class Env(object): #(graph.Graph): ...@@ -256,6 +283,11 @@ class Env(object): #(graph.Graph):
raise raise
def remove_feature(self, feature): def remove_feature(self, feature):
"""
Removes the feature from the graph.
Calls feature.on_detach(env) if an on_detach method is defined.
"""
try: try:
self._features.remove(feature) self._features.remove(feature)
except: except:
...@@ -268,6 +300,11 @@ class Env(object): #(graph.Graph): ...@@ -268,6 +300,11 @@ class Env(object): #(graph.Graph):
### callback utils ### ### callback utils ###
def execute_callbacks(self, name, *args): def execute_callbacks(self, name, *args):
"""
Calls
getattr(feature, name)(*args)
for each feature which has a method called after name.
"""
for feature in self._features: for feature in self._features:
try: try:
fn = getattr(feature, name) fn = getattr(feature, name)
...@@ -276,6 +313,11 @@ class Env(object): #(graph.Graph): ...@@ -276,6 +313,11 @@ class Env(object): #(graph.Graph):
fn(self, *args) fn(self, *args)
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
"""
Returns a dictionary d such that:
d[feature] == getattr(feature, name)(*args)
For each feature which has a method called after name.
"""
d = {} d = {}
for feature in self._features: for feature in self._features:
try: try:
...@@ -289,6 +331,17 @@ class Env(object): #(graph.Graph): ...@@ -289,6 +331,17 @@ class Env(object): #(graph.Graph):
### misc ### ### misc ###
def toposort(self): def toposort(self):
"""
Returns an ordering of the graph's Apply nodes such that:
- All the nodes of the inputs of a node are before that node.
- Satisfies the orderings provided by each feature that has
an 'orderings' method.
If a feature has an 'orderings' method, it will be called with
this env as sole argument. It should return a dictionary of
{node: predecessors} where predecessors is a list of nodes
that should be computed before the key node.
"""
env = self env = self
ords = {} ords = {}
for feature in env._features: for feature in env._features:
...@@ -314,10 +367,10 @@ class Env(object): #(graph.Graph): ...@@ -314,10 +367,10 @@ class Env(object): #(graph.Graph):
raise Exception("what the fuck") raise Exception("what the fuck")
return node.inputs return node.inputs
def has_node(self, node):
return node in self.nodes
def check_integrity(self): def check_integrity(self):
"""
Call this for a diagnosis if things go awry.
"""
nodes = graph.ops(self.inputs, self.outputs) nodes = graph.ops(self.inputs, self.outputs)
if self.nodes != nodes: if self.nodes != nodes:
missing = nodes.difference(self.nodes) missing = nodes.difference(self.nodes)
......
#from features import Listener, Constraint, Orderings, Tool from collections import defaultdict
import graph import graph
import utils import utils
from utils import AbstractFunctionError import toolbox
from copy import copy from utils import AbstractFunctionError
from env import InconsistencyError from env import InconsistencyError
from toolbox import Bookkeeper
from collections import defaultdict
class DestroyHandler(toolbox.Bookkeeper):
class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
""" """
This feature ensures that an env represents a consistent data flow This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over when some Ops overwrite their inputs and/or provide "views" over
...@@ -29,14 +21,16 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -29,14 +21,16 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
Examples: Examples:
- (x += 1) + (x += 1) -> fails because the first += makes the second - (x += 1) + (x += 1) -> fails because the first += makes the second
invalid invalid
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
- (a += b) + (c += a) -> succeeds but we have to do c += a first - (a += b) + (c += a) -> succeeds but we have to do c += a first
- (a += b) + (b += c) + (c += a) -> fails because there's a cyclical - (a += b) + (b += c) + (c += a) -> fails because there's a cyclical
dependency (no possible ordering) dependency (no possible ordering)
This feature allows some optimizations (eg sub += for +) to be applied This feature allows some optimizations (eg sub += for +) to be applied
safely. safely.
@todo
- x += transpose_view(x) -> fails because the input that is destroyed
depends on an input that shares the same data
""" """
def __init__(self): def __init__(self):
...@@ -88,11 +82,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -88,11 +82,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
self.seen = set() self.seen = set()
Bookkeeper.on_attach(self, env) toolbox.Bookkeeper.on_attach(self, env)
# # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
# self.children[input] = set()
def on_detach(self, env): def on_detach(self, env):
del self.parent del self.parent
...@@ -105,19 +95,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -105,19 +95,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
del self.seen del self.seen
self.env = None self.env = None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def __path__(self, r): def __path__(self, r):
""" """
Returns a path from r to the result that it is ultimately Returns a path from r to the result that it is ultimately
...@@ -171,7 +148,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -171,7 +148,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def __pre__(self, op): def __pre__(self, op):
""" """
Returns all results that must be computed prior to computing Returns all results that must be computed prior to computing
this op. this node.
""" """
rval = set() rval = set()
if op is None: if op is None:
...@@ -222,7 +199,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -222,7 +199,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
users = set(self.__users__(start)) users = set(self.__users__(start))
users.add(start) users.add(start)
for user in users: for user in users:
for cycle in copy(self.cycles): for cycle in set(self.cycles):
if user in cycle: if user in cycle:
self.cycles.remove(cycle) self.cycles.remove(cycle)
if just_remove: if just_remove:
...@@ -234,7 +211,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -234,7 +211,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
""" """
@return: (vmap, dmap) where: @return: (vmap, dmap) where:
- vmap -> {output : [inputs output is a view of]} - vmap -> {output : [inputs output is a view of]}
- dmap -> {output : [inputs that are destroyed by the Op - dmap -> {output : [inputs that are destroyed by the node
(and presumably returned as that output)]} (and presumably returned as that output)]}
""" """
try: _vmap = node.op.view_map try: _vmap = node.op.view_map
...@@ -260,7 +237,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -260,7 +237,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def on_import(self, env, op): def on_import(self, env, op):
""" """
Recomputes the dependencies and search for inconsistencies given Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env. that we just added an node to the env.
""" """
self.seen.add(op) self.seen.add(op)
...@@ -303,7 +280,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -303,7 +280,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
def on_prune(self, env, op): def on_prune(self, env, op):
""" """
Recomputes the dependencies and searches for inconsistencies to remove Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op to the env. given that we just removed a node to the env.
""" """
view_map, destroy_map = self.get_maps(op) view_map, destroy_map = self.get_maps(op)
...@@ -400,7 +377,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -400,7 +377,7 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
""" """
Recomputes the dependencies and searches for inconsistencies to remove Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is a list of (node, i) pairs such that node.inputs[i] used to be r_1 and is
now r_2. now r_2.
""" """
path_1 = self.__path__(r_1) path_1 = self.__path__(r_1)
...@@ -485,53 +462,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool): ...@@ -485,53 +462,6 @@ class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
# class Destroyer:
# """
# Base class for Ops that destroy one or more of their inputs in an
# inplace operation, use them as temporary storage, puts garbage in
# them or anything else that invalidates the contents for use by other
# Ops.
# Usage of this class in an env requires DestroyHandler.
# """
# def destroyed_inputs(self):
# raise AbstractFunctionError()
# def destroy_map(self):
# """
# Returns the map {output: [list of destroyed inputs]}
# While it typically means that the storage of the output is
# shared with each of the destroyed inputs, it does necessarily
# have to be the case.
# """
# # compatibility
# return {self.out: self.destroyed_inputs()}
# __env_require__ = DestroyHandler
# class Viewer:
# """
# Base class for Ops that return one or more views over one or more inputs,
# which means that the inputs and outputs share their storage. Unless it also
# extends Destroyer, this Op does not modify the storage in any way and thus
# the input is safe for use by other Ops even after executing this one.
# """
# def view_map(self):
# """
# Returns the map {output: [list of viewed inputs]}
# It means that the output shares storage with each of the inputs
# in the list.
# Note: support for more than one viewed input is minimal, but
# this might improve in the future.
# """
# raise AbstractFunctionError()
def view_roots(r): def view_roots(r):
""" """
Utility function that returns the leaves of a search through Utility function that returns the leaves of a search through
......
...@@ -3,20 +3,9 @@ from copy import copy ...@@ -3,20 +3,9 @@ from copy import copy
from collections import deque from collections import deque
import utils import utils
from utils import object2
def deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'gof.graph.%s deprecated: April 29' % f.__name__
printme[0] = False
return f(*args, **kwargs)
return g
class Apply(utils.object2):
class Apply(object2):
""" """
Note: it is illegal for an output element to have an owner != self Note: it is illegal for an output element to have an owner != self
""" """
...@@ -74,18 +63,13 @@ class Apply(object2): ...@@ -74,18 +63,13 @@ class Apply(object2):
raise TypeError("Cannot change the type of this input.", curr, new) raise TypeError("Cannot change the type of this input.", curr, new)
new_node = self.clone() new_node = self.clone()
new_node.inputs = inputs new_node.inputs = inputs
# new_node.outputs = []
# for output in self.outputs:
# new_output = copy(output)
# new_output.owner = new_node
# new_node.outputs.append(new_output)
return new_node return new_node
nin = property(lambda self: len(self.inputs)) nin = property(lambda self: len(self.inputs))
nout = property(lambda self: len(self.outputs)) nout = property(lambda self: len(self.outputs))
class Result(object2): class Result(utils.object2):
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None): def __init__(self, type, owner = None, index = None, name = None):
self.type = type self.type = type
...@@ -111,9 +95,6 @@ class Result(object2): ...@@ -111,9 +95,6 @@ class Result(object2):
return "<?>::" + str(self.type) return "<?>::" + str(self.type)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@deprecated
def __asresult__(self):
return self
def clone(self): def clone(self):
return self.__class__(self.type, None, None, self.name) return self.__class__(self.type, None, None, self.name)
...@@ -137,7 +118,6 @@ class Constant(Value): ...@@ -137,7 +118,6 @@ class Constant(Value):
#__slots__ = ['data'] #__slots__ = ['data']
def __init__(self, type, data, name = None): def __init__(self, type, data, name = None):
Value.__init__(self, type, data, name) Value.__init__(self, type, data, name)
### self.indestructible = True
def equals(self, other): def equals(self, other):
# this does what __eq__ should do, but Result and Apply should always be hashable by id # this does what __eq__ should do, but Result and Apply should always be hashable by id
return type(other) == type(self) and self.signature() == other.signature() return type(other) == type(self) and self.signature() == other.signature()
...@@ -148,32 +128,6 @@ class Constant(Value): ...@@ -148,32 +128,6 @@ class Constant(Value):
return self.name return self.name
return str(self.data) #+ "::" + str(self.type) return str(self.data) #+ "::" + str(self.type)
@deprecated
def as_result(x):
if isinstance(x, Result):
return x
# elif isinstance(x, Type):
# return Result(x, None, None)
elif hasattr(x, '__asresult__'):
r = x.__asresult__()
if not isinstance(r, Result):
raise TypeError("%s.__asresult__ must return a Result instance" % x, (x, r))
return r
else:
raise TypeError("Cannot wrap %s in a Result" % x)
@deprecated
def as_apply(x):
if isinstance(x, Apply):
return x
elif hasattr(x, '__asapply__'):
node = x.__asapply__()
if not isinstance(node, Apply):
raise TypeError("%s.__asapply__ must return an Apply instance" % x, (x, node))
return node
else:
raise TypeError("Cannot map %s to Apply" % x)
def stack_search(start, expand, mode='bfs', build_inv = False): def stack_search(start, expand, mode='bfs', build_inv = False):
"""Search through L{Result}s, either breadth- or depth-first """Search through L{Result}s, either breadth- or depth-first
@type start: deque @type start: deque
...@@ -234,45 +188,6 @@ def inputs(result_list): ...@@ -234,45 +188,6 @@ def inputs(result_list):
return rval return rval
# def results_and_orphans(r_in, r_out, except_unreachable_input=False):
# r_in_set = set(r_in)
# class Dummy(object): pass
# dummy = Dummy()
# dummy.inputs = r_out
# def expand_inputs(io):
# if io in r_in_set:
# return None
# try:
# return [io.owner] if io.owner != None else None
# except AttributeError:
# return io.inputs
# ops_and_results, dfsinv = stack_search(
# deque([dummy]),
# expand_inputs, 'dfs', True)
# if except_unreachable_input:
# for r in r_in:
# if r not in dfsinv:
# raise Exception(results_and_orphans.E_unreached)
# clients = stack_search(
# deque(r_in),
# lambda io: dfsinv.get(io,None), 'dfs')
# ops_to_compute = [o for o in clients if is_op(o) and o is not dummy]
# results = []
# for o in ops_to_compute:
# results.extend(o.inputs)
# results.extend(r_out)
# op_set = set(ops_to_compute)
# assert len(ops_to_compute) == len(op_set)
# orphans = [r for r in results \
# if (r.owner not in op_set) and (r not in r_in_set)]
# return results, orphans
# results_and_orphans.E_unreached = 'there were unreachable inputs'
def results_and_orphans(i, o): def results_and_orphans(i, o):
""" """
""" """
...@@ -286,24 +201,6 @@ def results_and_orphans(i, o): ...@@ -286,24 +201,6 @@ def results_and_orphans(i, o):
return results, orphans return results, orphans
#def results_and_orphans(i, o):
# results = set()
# orphans = set()
# def helper(r):
# if r in results:
# return
# results.add(r)
# if r.owner is None:
# if r not in i:
# orphans.add(r)
# else:
# for r2 in r.owner.inputs:
# helper(r2)
# for output in o:
# helper(output)
# return results, orphans
def ops(i, o): def ops(i, o):
""" """
@type i: list @type i: list
...@@ -569,122 +466,3 @@ def as_string(i, o, ...@@ -569,122 +466,3 @@ def as_string(i, o,
return [describe(output) for output in o] return [describe(output) for output in o]
# class Graph:
# """
# Object-oriented wrapper for all the functions in this module.
# """
# def __init__(self, inputs, outputs):
# self.inputs = inputs
# self.outputs = outputs
# def ops(self):
# return ops(self.inputs, self.outputs)
# def values(self):
# return values(self.inputs, self.outputs)
# def orphans(self):
# return orphans(self.inputs, self.outputs)
# def io_toposort(self):
# return io_toposort(self.inputs, self.outputs)
# def toposort(self):
# return self.io_toposort()
# def clone(self):
# o = clone(self.inputs, self.outputs)
# return Graph(self.inputs, o)
# def __str__(self):
# return as_string(self.inputs, self.outputs)
if 0:
#these were the old implementations
# they were replaced out of a desire that graph search routines would not
# depend on the hash or id of any node, so that it would be deterministic
# and consistent between program executions.
@utils.deprecated('gof.graph', 'preserving only for review')
def _results_and_orphans(i, o, except_unreachable_input=False):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
Returns the pair (results, orphans). The former is the set of
L{Result}s that are involved in the subgraph that lies between i and
o. This includes i, o, orphans(i, o) and all results of all
intermediary steps from i to o. The second element of the returned
pair is orphans(i, o).
"""
results = set()
i = set(i)
results.update(i)
incomplete_paths = []
reached = set()
def helper(r, path):
if r in i:
reached.add(r)
results.update(path)
elif r.owner is None:
incomplete_paths.append(path)
else:
op = r.owner
for r2 in op.inputs:
helper(r2, path + [r2])
for output in o:
helper(output, [output])
orphans = set()
for path in incomplete_paths:
for r in path:
if r not in results:
orphans.add(r)
break
if except_unreachable_input and len(i) != len(reached):
raise Exception(results_and_orphans.E_unreached)
results.update(orphans)
return results, orphans
def _io_toposort(i, o, orderings = {}):
"""
@type i: list
@param i: input L{Result}s
@type o: list
@param o: output L{Result}s
@param orderings: {op: [requirements for op]} (defaults to {})
@rtype: ordered list
@return: L{Op}s that belong in the subgraph between i and o which
respects the following constraints:
- all inputs in i are assumed to be already computed
- the L{Op}s that compute an L{Op}'s inputs must be computed before it
- the orderings specified in the optional orderings parameter must be satisfied
Note that this function does not take into account ordering information
related to destructive operations or other special behavior.
"""
prereqs_d = copy(orderings)
all = ops(i, o)
for op in all:
asdf = set([input.owner for input in op.inputs if input.owner and input.owner in all])
prereqs_d.setdefault(op, set()).update(asdf)
return utils.toposort(prereqs_d)
from utils import AbstractFunctionError
import utils import utils
import graph
from graph import Value import sys, traceback
import sys
import traceback
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -67,7 +64,7 @@ class Linker: ...@@ -67,7 +64,7 @@ class Linker:
print new_e.data # 3.0 print new_e.data # 3.0
print e.data # 3.0 iff inplace == True (else unknown) print e.data # 3.0 iff inplace == True (else unknown)
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def make_function(self, unpack_single = True, **kwargs): def make_function(self, unpack_single = True, **kwargs):
""" """
...@@ -151,7 +148,7 @@ def map_storage(env, order, input_storage, output_storage): ...@@ -151,7 +148,7 @@ def map_storage(env, order, input_storage, output_storage):
for node in order: for node in order:
for r in node.inputs: for r in node.inputs:
if r not in storage_map: if r not in storage_map:
assert isinstance(r, Value) assert isinstance(r, graph.Value)
storage_map[r] = [r.data] storage_map[r] = [r.data]
for r in node.outputs: for r in node.outputs:
storage_map.setdefault(r, [None]) storage_map.setdefault(r, [None])
......
""" """
Contains the L{Op} class, which is the base interface for all operations Contains the L{Op} class, which is the base interface for all operations
compatible with gof's graph manipulation routines. compatible with gof's graph manipulation routines.
""" """
import utils import utils
from utils import AbstractFunctionError, object2
from copy import copy
class Op(object2): class Op(utils.object2):
default_output = None default_output = None
"""@todo """@todo
...@@ -22,9 +17,19 @@ class Op(object2): ...@@ -22,9 +17,19 @@ class Op(object2):
############# #############
def make_node(self, *inputs): def make_node(self, *inputs):
raise AbstractFunctionError() """
This function should return an Apply instance representing the
application of this Op on the provided inputs.
"""
raise utils.AbstractFunctionError()
def __call__(self, *inputs): def __call__(self, *inputs):
"""
Shortcut for:
self.make_node(*inputs).outputs[self.default_output] (if default_output is defined)
self.make_node(*inputs).outputs[0] (if only one output)
self.make_node(*inputs).outputs (if more than one output)
"""
node = self.make_node(*inputs) node = self.make_node(*inputs)
if self.default_output is not None: if self.default_output is not None:
return node.outputs[self.default_output] return node.outputs[self.default_output]
...@@ -44,6 +49,7 @@ class Op(object2): ...@@ -44,6 +49,7 @@ class Op(object2):
Calculate the function on the inputs and put the results in the Calculate the function on the inputs and put the results in the
output storage. output storage.
- node: Apply instance that contains the symbolic inputs and outputs
- inputs: sequence of inputs (immutable) - inputs: sequence of inputs (immutable)
- output_storage: list of mutable 1-element lists (do not change - output_storage: list of mutable 1-element lists (do not change
the length of these lists) the length of these lists)
...@@ -53,7 +59,7 @@ class Op(object2): ...@@ -53,7 +59,7 @@ class Op(object2):
by a previous call to impl and impl is free to reuse it as it by a previous call to impl and impl is free to reuse it as it
sees fit. sees fit.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
##################### #####################
# C code generation # # C code generation #
...@@ -62,9 +68,8 @@ class Op(object2): ...@@ -62,9 +68,8 @@ class Op(object2):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
"""Return the C implementation of an Op. """Return the C implementation of an Op.
Returns templated C code that does the computation associated Returns C code that does the computation associated to this L{Op},
to this L{Op}. You may assume that input validation and output given names for the inputs and outputs.
allocation have already been done.
@param inputs: list of strings. There is a string for each input @param inputs: list of strings. There is a string for each input
of the function, and the string is the name of a C of the function, and the string is the name of a C
...@@ -80,7 +85,7 @@ class Op(object2): ...@@ -80,7 +85,7 @@ class Op(object2):
'fail'). 'fail').
""" """
raise AbstractFunctionError('%s.c_code' \ raise utils.AbstractFunctionError('%s.c_code is not defined' \
% self.__class__.__name__) % self.__class__.__name__)
def c_code_cleanup(self, node, name, inputs, outputs, sub): def c_code_cleanup(self, node, name, inputs, outputs, sub):
...@@ -89,44 +94,33 @@ class Op(object2): ...@@ -89,44 +94,33 @@ class Op(object2):
This is a convenient place to clean up things allocated by c_code(). This is a convenient place to clean up things allocated by c_code().
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_compile_args(self): def c_compile_args(self):
""" """
Return a list of compile args recommended to manipulate this L{Op}. Return a list of compile args recommended to manipulate this L{Op}.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_headers(self): def c_headers(self):
""" """
Return a list of header files that must be included from C to manipulate Return a list of header files that must be included from C to manipulate
this L{Op}. this L{Op}.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_libraries(self): def c_libraries(self):
""" """
Return a list of libraries to link against to manipulate this L{Op}. Return a list of libraries to link against to manipulate this L{Op}.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
def c_support_code(self): def c_support_code(self):
""" """
Return utility code for use by this L{Op}. It may refer to support code Return utility code for use by this L{Op}. It may refer to support code
defined for its input L{Result}s. defined for its input L{Result}s.
""" """
raise AbstractFunctionError() raise utils.AbstractFunctionError()
class PropertiedOp(Op):
def __eq__(self, other):
return type(self) == type(other) and self.__dict__ == other.__dict__
def __str__(self):
if hasattr(self, 'name') and self.name:
return self.name
else:
return "%s{%s}" % (self.__class__.__name__, ", ".join("%s=%s" % (k, v) for k, v in self.__dict__.items() if k != "name"))
差异被折叠。
from random import shuffle
import utils
from functools import partial from functools import partial
import graph import graph
...@@ -14,51 +12,6 @@ class Bookkeeper: ...@@ -14,51 +12,6 @@ class Bookkeeper:
def on_detach(self, env): def on_detach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs): for node in graph.io_toposort(env.inputs, env.outputs):
self.on_prune(env, node) self.on_prune(env, node)
# class Toposorter:
# def on_attach(self, env):
# if hasattr(env, 'toposort'):
# raise Exception("Toposorter feature is already present or in conflict with another plugin.")
# env.toposort = partial(self.toposort, env)
# def on_detach(self, env):
# del env.toposort
# def toposort(self, env):
# ords = {}
# for feature in env._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings(env).items():
# ords.setdefault(op, set()).update(prereqs)
# order = graph.io_toposort(env.inputs, env.outputs, ords)
# return order
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
class History: class History:
...@@ -223,151 +176,3 @@ class PrintListener(object): ...@@ -223,151 +176,3 @@ class PrintListener(object):
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def all_bases(self, cls):
# return utils.all_bases(cls, lambda cls: cls is not object)
# def on_import(self, op):
# for base in self.all_bases(op.__class__):
# self.setdefault(base, set()).add(op)
# def on_prune(self, op):
# for base in self.all_bases(op.__class__):
# self[base].remove(op)
# if not self[base]:
# del self[base]
# def __query__(self, cls):
# all = [x for x in self.get(cls, [])]
# shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, cls):
# return self.__query__(cls)
# def publish(self):
# self.env.get_instances_of = self.query
# class DescFinder(Listener, Tool, dict):
# def __init__(self, env):
# self.env = env
# def on_import(self, op):
# self.setdefault(op.desc(), set()).add(op)
# def on_prune(self, op):
# desc = op.desc()
# self[desc].remove(op)
# if not self[desc]:
# del self[desc]
# def __query__(self, desc):
# all = [x for x in self.get(desc, [])]
# shuffle(all) # this helps for debugging because the order of the replacements will vary
# while all:
# next = all.pop()
# if next in self.env.ops():
# yield next
# def query(self, desc):
# return self.__query__(desc)
# def publish(self):
# self.env.get_from_desc = self.query
### UNUSED AND UNTESTED ###
# class ChangeListener(Listener):
# def __init__(self, env):
# self.change = False
# def on_import(self, op):
# self.change = True
# def on_prune(self, op):
# self.change = True
# def on_rewire(self, clients, r, new_r):
# self.change = True
# def __call__(self, value = "get"):
# if value == "get":
# return self.change
# else:
# self.change = value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论