提交 646f4c01 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

halfway through redoing env

上级 d27b991d
......@@ -2,7 +2,7 @@
from copy import copy
import graph
from features import Listener, Orderings, Constraint, Tool, uniq_features
##from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils
from utils import AbstractFunctionError
......@@ -15,13 +15,6 @@ class InconsistencyError(Exception):
pass
def require_set(x):
try:
req = x.env_require
except AttributeError:
req = []
return req
class Env(graph.Graph):
"""
......@@ -35,10 +28,6 @@ class Env(graph.Graph):
result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in omega.
An Env's functionality can be extended with features, which must
be subclasses of L{Listener}, L{Constraint}, L{Orderings} or
L{Tool}.
Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are
......@@ -50,160 +39,151 @@ class Env(graph.Graph):
### Special ###
def __init__(self, inputs, outputs): #, consistency_check = True):
def __init__(self, inputs, outputs):
"""
Create an Env which operates on the subgraph bound by the inputs and outputs
sets. If consistency_check is False, an illegal graph will be tolerated.
sets.
"""
# self._features = {}
# self._listeners = {}
# self._constraints = {}
# self._orderings = {}
# self._tools = {}
self._features = []
# The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = list(inputs)
self.outputs = list(outputs)
# All nodes in the subgraph defined by inputs and outputs are cached in nodes
self.nodes = set()
# Ditto for results
self.results = set(self.inputs)
# Set of all the results that are not an output of an op in the subgraph but
# are an input of an op in the subgraph.
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
# We initialize them to the set of outputs; if an output depends on an input,
# it will be removed from the set of orphans.
self.orphans = set(outputs).difference(inputs)
# for feature_class in uniq_features(features):
# self.add_feature(feature_class, False)
# Maps results to nodes that use them:
# if op.inputs[i] == v then (op, i) in self._clients[v]
self._clients = {}
# Ditto for results
self.results = set()
self.inputs = list(inputs)
for input in self.inputs:
if input.owner is not None:
raise ValueError("One of the provided inputs is the output of an already existing node. " \
"If that is okay, either discard that input's owner or use graph.clone.")
self.__setup_r__(input)
self.outputs = outputs
self.__import_r__(outputs)
# List of functions that undo the replace operations performed.
# e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
self.history = []
self.__import_r__(self.outputs)
# for op in self.nodes():
# self.satisfy(op)
# if consistency_check:
# self.validate()
### Public interface ###
### Setup a Result ###
def add_output(self, output):
"Add an output to the Env."
self.outputs.add(output)
self.orphans.add(output)
self.__import_r__([output])
def __setup_r__(self, r):
if hasattr(r, 'env') and r.env is not None and r.env is not self:
raise Exception("%s is already owned by another env" % r)
r.env = self
r.clients = []
def __setup_node__(self, node):
if hasattr(node, 'env') and node.env is not self:
raise Exception("%s is already owned by another env" % node)
node.env = self
node.deps = {}
### clients ###
def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r."
return self._clients.get(r, set())
return r.clients
def checkpoint(self):
"""
Returns an object that can be passed to self.revert in order to backtrack
to a previous state.
def __add_clients__(self, r, all):
"""
return len(self.history)
r -> result
all -> list of (op, i) pairs representing who r is an input of.
def consistent(self):
"""
Returns True iff the subgraph is consistent and does not violate the
constraints set by the listeners.
Updates the list of clients of r with all.
"""
try:
self.validate()
except InconsistencyError:
return False
return True
r.clients += all
# def satisfy(self, x):
# "Adds the features required by x unless they are already present."
# for feature_class in require_set(x):
# self.add_feature(feature_class)
def extend(self, feature, do_import = True, validate = False):
def __remove_clients__(self, r, all):
"""
@todo out of date
Adds an instance of the feature_class to this env's supported
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Nodes
already in the env.
r -> result
all -> list of (op, i) pairs representing who r is an input of.
Removes all from the clients list of r.
"""
if feature in self._features:
return # the feature is already present
self.__add_feature__(feature, do_import)
if validate:
self.validate()
for entry in all:
r.clients.remove(entry)
# remove from orphans?
def execute_callbacks(self, name, *args):
for feature in self._features:
try:
fn = getattr(feature, name)
except AttributeError:
continue
fn(*args)
def __add_feature__(self, feature, do_import):
self._features.append(feature)
publish = getattr(feature, 'publish', None)
if publish is not None:
publish()
if do_import:
try:
fn = feature.on_import
except AttributeError:
return
for node in self.io_toposort():
fn(node)
### import ###
def __del_feature__(self, feature):
try:
del self._features[feature]
except:
pass
unpublish = hasattr(feature, 'unpublish')
if unpublish is not None:
unpublish()
def __import_r__(self, results):
# Imports the owners of the results
for node in set(r.owner for r in results if r is not None):
self.__import__(node)
def get_feature(self, feature):
idx = self._features.index(feature)
return self._features[idx]
def __import__(self, node, check = True):
# We import the nodes in topological order. We only are interested
# in new nodes, so we use all results we know of as if they were the input set.
# (the functions in the graph module only use the input set to
# know where to stop going down)
new_nodes = graph.io_toposort(self.results, node.outputs)
if check:
for node in new_nodes:
if hasattr(node, 'env') and node.env is not self or \
any(hasattr(r, 'env') and r.env is not self or \
r.owner is None and not isinstance(r, Value) and r not in self.inputs
for r in node.inputs + node.outputs):
raise Exception("Could not import %s" % node)
for node in new_nodes:
self.__setup_node__(node)
self.nodes.add(node)
for output in node.outputs:
self.__setup_r__(output)
self.results.add(output)
for i, input in enumerate(node.inputs):
if input not in self.results:
if not isinstance(input, Value):
raise TypeError("The graph to import contains a leaf that is not an input and has no default value " \
"(graph state is bad now - use check = True)", input)
self.__setup_r__(input)
self.results.add(input)
self.__add_clients__(input, [(node, i)])
assert node.env is self
self.execute_callbacks('on_import', node)
def has_feature(self, feature):
return feature in self._features
def nclients(self, r):
"Same as len(self.clients(r))."
return len(self.clients(r))
### prune ###
def edge(self, r):
return r in self.inputs or r in self.orphans
def __prune_r__(self, results):
# Prunes the owners of the results.
for node in set(r.owner for r in results if r is not None):
self.__prune__(node)
for r in results:
if not r.clients and r in self.results:
self.results.remove(r)
def __prune__(self, node):
if node not in self.nodes:
raise Exception("%s does not belong to this Env and cannot be pruned." % node)
assert node.env is self
# If node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in node.outputs:
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return
self.nodes.remove(node)
self.results.difference_update(node.outputs)
self.execute_callbacks('on_prune', node)
for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)])
self.__prune_r__(node.inputs)
def follow(self, r):
node = r.owner
if self.edge(r):
return None
else:
if node is None:
raise Exception("what the fuck")
return node.inputs
def has_node(self, node):
return node in self.nodes
### replace ###
def replace(self, r, new_r, consistency_check = True):
"""
......@@ -222,6 +202,8 @@ class Env(graph.Graph):
even if there is an inconsistency, unless the replacement
violates hard constraints on the types involved.
"""
if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
assert r in self.results
# Save where we are so we can backtrack
......@@ -290,6 +272,108 @@ class Env(graph.Graph):
self.revert(chk)
raise
def checkpoint(self):
"""
Returns an object that can be passed to self.revert in order to backtrack
to a previous state.
"""
return len(self.history)
def consistent(self):
"""
Returns True iff the subgraph is consistent and does not violate the
constraints set by the listeners.
"""
try:
self.validate()
except InconsistencyError:
return False
return True
def extend(self, feature, do_import = True, validate = False):
"""
@todo out of date
Adds an instance of the feature_class to this env's supported
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Nodes
already in the env.
"""
if feature in self._features:
return # the feature is already present
self.__add_feature__(feature, do_import)
if validate:
self.validate()
def execute_callbacks(self, name, *args):
for feature in self._features:
try:
fn = getattr(feature, name)
except AttributeError:
continue
fn(*args)
def collect_callbacks(self, name, *args):
d = {}
for feature in self._features:
try:
fn = getattr(feature, name)
except AttributeError:
continue
d[feature] = fn(*args)
return d
def __add_feature__(self, feature, do_import):
self._features.append(feature)
publish = getattr(feature, 'publish', None)
if publish is not None:
publish()
if do_import:
try:
fn = feature.on_import
except AttributeError:
return
for node in self.io_toposort():
fn(node)
def __del_feature__(self, feature):
try:
del self._features[feature]
except:
pass
unpublish = hasattr(feature, 'unpublish')
if unpublish is not None:
unpublish()
def get_feature(self, feature):
idx = self._features.index(feature)
return self._features[idx]
def has_feature(self, feature):
return feature in self._features
def nclients(self, r):
"Same as len(self.clients(r))."
return len(self.clients(r))
def edge(self, r):
return r in self.inputs or r in self.orphans
def follow(self, r):
node = r.owner
if self.edge(r):
return None
else:
if node is None:
raise Exception("what the fuck")
return node.inputs
def has_node(self, node):
return node in self.nodes
def revert(self, checkpoint):
"""
Reverts the graph to whatever it was at the provided
......@@ -336,113 +420,6 @@ class Env(graph.Graph):
### Private interface ###
def __add_clients__(self, r, all):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
Updates the list of clients of r with all.
"""
self._clients.setdefault(r, set()).update(all)
def __remove_clients__(self, r, all):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
Removes all from the clients list of r.
"""
if not all:
return
self._clients[r].difference_update(all)
if not self._clients[r]:
del self._clients[r]
if r in self.orphans:
self.orphans.remove(r)
def __import_r_satisfy__(self, results):
# Satisfies the owners of the results.
for op in graph.ops(self.results, results):
self.satisfy(op)
def __import_r__(self, results):
# Imports the owners of the results
for result in results:
owner = result.owner
if owner:
self.__import__(result.owner)
if result not in self.results:
self.results.add(result)
self.orphans.add(result)
def __import__(self, op):
# We import the nodes in topological order. We only are interested
# in new nodes, so we use all results we know of as if they were the input set.
# (the functions in the graph module only use the input set to
# know where to stop going down)
new_nodes = graph.io_toposort(self.results.difference(self.orphans), op.outputs)
for op in new_nodes:
self.nodes.add(op)
self.results.update(op.outputs)
self.orphans.difference_update(op.outputs)
for i, input in enumerate(op.inputs):
self.__add_clients__(input, [(op, i)])
if input not in self.results:
# This input is an orphan because if the op that
# produced it was in the subgraph, io_toposort
# would have placed it before, so we would have
# seen it (or it would already be in the graph)
self.orphans.add(input)
self.results.add(input)
self.execute_callbacks('on_import', op)
# for listener in self._listeners.values():
# try:
# listener.on_import(op)
# except AbstractFunctionError:
# pass
__import__.E_output = 'op output in Env.inputs'
def __prune_r__(self, results):
# Prunes the owners of the results.
for result in set(results):
if result in self.inputs:
continue
owner = result.owner
if owner:
self.__prune__(owner)
# if result in self.results:
# self.results.remove(result)
# if result in self.orphans:
# self.orphans.remove(result)
def __prune__(self, op):
# If op's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in op.outputs:
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return
if op not in self.nodes: # this can happen from replacing an orphan
return
self.nodes.remove(op)
self.results.difference_update(op.outputs)
self.execute_callbacks('on_prune', op)
# for listener in self._listeners.values():
# try:
# listener.on_prune(op)
# except AbstractFunctionError:
# pass
for i, input in enumerate(op.inputs):
self.__remove_clients__(input, [(op, i)])
self.__prune_r__(op.inputs)
def __move_clients__(self, clients, r, new_r):
if not (r.type == new_r.type):
......
......@@ -6,11 +6,11 @@ from utils import object2
def deprecated(f):
printme = True
printme = [True]
def g(*args, **kwargs):
if printme:
if printme[0]:
print 'gof.graph.%s deprecated: April 29' % f.__name__
printme = False
printme[0] = False
return f(*args, **kwargs)
return g
......@@ -28,8 +28,6 @@ class Apply(object2):
for input in inputs:
if isinstance(input, Result):
self.inputs.append(input)
# elif isinstance(input, Type):
# self.inputs.append(Result(input, None, None))
else:
raise TypeError("The 'inputs' argument to Apply must contain Result instances, not %s" % input)
self.outputs = []
......@@ -42,12 +40,10 @@ class Apply(object2):
elif output.owner is not self or output.index != i:
raise ValueError("All output results passed to Apply must belong to it.")
self.outputs.append(output)
# elif isinstance(output, Type):
# self.outputs.append(Result(output, self, i))
else:
raise TypeError("The 'outputs' argument to Apply must contain Result instances with no owner, not %s" % output)
@deprecated
def default_output(self):
print 'default_output deprecated: April 29'
"""
Returns the default output for this Node, typically self.outputs[0].
Depends on the value of node.op.default_output
......@@ -66,6 +62,22 @@ class Apply(object2):
return str(self)
def __asapply__(self):
return self
def clone(self):
return self.__class__(self.op, self.inputs, [output.clone() for output in self.outputs])
def clone_with_new_inputs(self, inputs, check_type = True):
if check_type:
for curr, new in zip(self.inputs, inputs):
if not curr.type == new.type:
raise TypeError("Cannot change the type of this input.", curr, new)
new_node = self.clone()
new_node.inputs = inputs
# new_node.outputs = []
# for output in self.outputs:
# new_output = copy(output)
# new_output.owner = new_node
# new_node.outputs.append(new_output)
return new_node
nin = property(lambda self: len(self.inputs))
nout = property(lambda self: len(self.outputs))
......@@ -77,7 +89,6 @@ class Result(object2):
self.owner = owner
self.index = index
self.name = name
def __str__(self):
if self.name is not None:
return self.name
......@@ -88,22 +99,35 @@ class Result(object2):
else:
return str(self.owner.op) + "." + str(self.index)
else:
return "?::" + str(self.type)
return "<?>::" + str(self.type)
def __repr__(self):
return str(self)
@deprecated
def __asresult__(self):
return self
def clone(self):
return self.__class__(self.type, None, None, self.name)
class Constant(Result):
class Value(Result):
#__slots__ = ['data']
def __init__(self, type, data, name = None):
Result.__init__(self, type, None, None, name)
self.data = type.filter(data)
self.indestructible = True
def __str__(self):
if self.name is not None:
return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self):
return self.__class__(self.type, self.data)
class Constant(Value):
#__slots__ = ['data']
def __init__(self, type, data, name = None):
Value.__init__(self, type, data, name)
### self.indestructible = True
def equals(self, other):
# this does what __eq__ should do, but Result and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature()
return type(other) == type(self) and self.signature() == other.signature()
def signature(self):
return (self.type, self.data)
def __str__(self):
......@@ -314,23 +338,12 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
if node is None: # result is an orphan
if copy_inputs_and_orphans:
cpy = copy(result)
cpy.owner = None
cpy.index = None
d[result] = cpy
else:
d[result] = result
return d[result]
else:
new_node = copy(node)
new_node.inputs = [clone_helper(input) for input in node.inputs]
new_node.outputs = []
for output in node.outputs:
new_output = copy(output)
new_output.owner = new_node
new_node.outputs.append(new_output)
# new_node = Apply(node.op,
# [clone_helper(input) for input in node.inputs],
# [output.type for output in node.outputs])
new_node = node.clone_with_new_inputs([clone_helper(input) for input in node.inputs])
d[node] = new_node
for output, new_output in zip(node.outputs, new_node.outputs):
d[output] = new_output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论