提交 646f4c01 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

halfway through redoing env

上级 d27b991d
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from copy import copy from copy import copy
import graph import graph
from features import Listener, Orderings, Constraint, Tool, uniq_features ##from features import Listener, Orderings, Constraint, Tool, uniq_features
import utils import utils
from utils import AbstractFunctionError from utils import AbstractFunctionError
...@@ -15,13 +15,6 @@ class InconsistencyError(Exception): ...@@ -15,13 +15,6 @@ class InconsistencyError(Exception):
pass pass
def require_set(x):
    """Return the requirements declared by `x`.

    The value of ``x.env_require`` is returned as is; an object that has
    no such attribute yields a fresh empty list.
    """
    return getattr(x, 'env_require', [])
class Env(graph.Graph): class Env(graph.Graph):
""" """
...@@ -35,10 +28,6 @@ class Env(graph.Graph): ...@@ -35,10 +28,6 @@ class Env(graph.Graph):
result in the subgraph by another, e.g. replace (x + x).out by (2 result in the subgraph by another, e.g. replace (x + x).out by (2
* x).out. This is the basis for optimization in omega. * x).out. This is the basis for optimization in omega.
An Env's functionality can be extended with features, which must
be subclasses of L{Listener}, L{Constraint}, L{Orderings} or
L{Tool}.
Regarding inputs and orphans: Regarding inputs and orphans:
In the context of a computation graph, the inputs and orphans are In the context of a computation graph, the inputs and orphans are
...@@ -50,160 +39,151 @@ class Env(graph.Graph): ...@@ -50,160 +39,151 @@ class Env(graph.Graph):
### Special ### ### Special ###
def __init__(self, inputs, outputs): #, consistency_check = True): def __init__(self, inputs, outputs):
""" """
Create an Env which operates on the subgraph bound by the inputs and outputs Create an Env which operates on the subgraph bound by the inputs and outputs
sets. If consistency_check is False, an illegal graph will be tolerated. sets.
""" """
# self._features = {}
# self._listeners = {}
# self._constraints = {}
# self._orderings = {}
# self._tools = {}
self._features = [] self._features = []
# The inputs and outputs set bound the subgraph this Env operates on.
self.inputs = list(inputs)
self.outputs = list(outputs)
# All nodes in the subgraph defined by inputs and outputs are cached in nodes # All nodes in the subgraph defined by inputs and outputs are cached in nodes
self.nodes = set() self.nodes = set()
# Ditto for results
self.results = set(self.inputs)
# Set of all the results that are not an output of an op in the subgraph but # Ditto for results
# are an input of an op in the subgraph. self.results = set()
# e.g. z for inputs=(x, y) and outputs=(x + (y - z),)
# We initialize them to the set of outputs; if an output depends on an input,
# it will be removed from the set of orphans.
self.orphans = set(outputs).difference(inputs)
# for feature_class in uniq_features(features):
# self.add_feature(feature_class, False)
# Maps results to nodes that use them:
# if op.inputs[i] == v then (op, i) in self._clients[v]
self._clients = {}
self.inputs = list(inputs)
for input in self.inputs:
if input.owner is not None:
raise ValueError("One of the provided inputs is the output of an already existing node. " \
"If that is okay, either discard that input's owner or use graph.clone.")
self.__setup_r__(input)
self.outputs = outputs
self.__import_r__(outputs)
# List of functions that undo the replace operations performed. # List of functions that undo the replace operations performed.
# e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u() # e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
self.history = [] self.history = []
self.__import_r__(self.outputs)
# for op in self.nodes():
# self.satisfy(op)
# if consistency_check:
# self.validate()
### Public interface ### ### Setup a Result ###
def add_output(self, output): def __setup_r__(self, r):
"Add an output to the Env." if hasattr(r, 'env') and r.env is not None and r.env is not self:
self.outputs.add(output) raise Exception("%s is already owned by another env" % r)
self.orphans.add(output) r.env = self
self.__import_r__([output]) r.clients = []
def __setup_node__(self, node):
if hasattr(node, 'env') and node.env is not self:
raise Exception("%s is already owned by another env" % node)
node.env = self
node.deps = {}
### clients ###
def clients(self, r): def clients(self, r):
"Set of all the (op, i) pairs such that op.inputs[i] is r." "Set of all the (op, i) pairs such that op.inputs[i] is r."
return self._clients.get(r, set()) return r.clients
def checkpoint(self): def __add_clients__(self, r, all):
"""
Returns an object that can be passed to self.revert in order to backtrack
to a previous state.
""" """
return len(self.history) r -> result
all -> list of (op, i) pairs representing who r is an input of.
def consistent(self): Updates the list of clients of r with all.
"""
Returns True iff the subgraph is consistent and does not violate the
constraints set by the listeners.
""" """
try: r.clients += all
self.validate()
except InconsistencyError:
return False
return True
# def satisfy(self, x): def __remove_clients__(self, r, all):
# "Adds the features required by x unless they are already present."
# for feature_class in require_set(x):
# self.add_feature(feature_class)
def extend(self, feature, do_import = True, validate = False):
""" """
@todo out of date r -> result
Adds an instance of the feature_class to this env's supported all -> list of (op, i) pairs representing who r is an input of.
features. If do_import is True and feature_class is a subclass
of Listener, its on_import method will be called on all the Nodes Removes all from the clients list of r.
already in the env.
""" """
if feature in self._features: for entry in all:
return # the feature is already present r.clients.remove(entry)
self.__add_feature__(feature, do_import) # remove from orphans?
if validate:
self.validate()
def execute_callbacks(self, name, *args):
for feature in self._features:
try:
fn = getattr(feature, name)
except AttributeError:
continue
fn(*args)
def __add_feature__(self, feature, do_import): ### import ###
self._features.append(feature)
publish = getattr(feature, 'publish', None)
if publish is not None:
publish()
if do_import:
try:
fn = feature.on_import
except AttributeError:
return
for node in self.io_toposort():
fn(node)
def __del_feature__(self, feature): def __import_r__(self, results):
try: # Imports the owners of the results
del self._features[feature] for node in set(r.owner for r in results if r is not None):
except: self.__import__(node)
pass
unpublish = hasattr(feature, 'unpublish')
if unpublish is not None:
unpublish()
def get_feature(self, feature): def __import__(self, node, check = True):
idx = self._features.index(feature) # We import the nodes in topological order. We only are interested
return self._features[idx] # in new nodes, so we use all results we know of as if they were the input set.
# (the functions in the graph module only use the input set to
# know where to stop going down)
new_nodes = graph.io_toposort(self.results, node.outputs)
if check:
for node in new_nodes:
if hasattr(node, 'env') and node.env is not self or \
any(hasattr(r, 'env') and r.env is not self or \
r.owner is None and not isinstance(r, Value) and r not in self.inputs
for r in node.inputs + node.outputs):
raise Exception("Could not import %s" % node)
for node in new_nodes:
self.__setup_node__(node)
self.nodes.add(node)
for output in node.outputs:
self.__setup_r__(output)
self.results.add(output)
for i, input in enumerate(node.inputs):
if input not in self.results:
if not isinstance(input, Value):
raise TypeError("The graph to import contains a leaf that is not an input and has no default value " \
"(graph state is bad now - use check = True)", input)
self.__setup_r__(input)
self.results.add(input)
self.__add_clients__(input, [(node, i)])
assert node.env is self
self.execute_callbacks('on_import', node)
def has_feature(self, feature):
return feature in self._features
def nclients(self, r): ### prune ###
"Same as len(self.clients(r))."
return len(self.clients(r))
def edge(self, r): def __prune_r__(self, results):
return r in self.inputs or r in self.orphans # Prunes the owners of the results.
for node in set(r.owner for r in results if r is not None):
self.__prune__(node)
for r in results:
if not r.clients and r in self.results:
self.results.remove(r)
def __prune__(self, node):
if node not in self.nodes:
raise Exception("%s does not belong to this Env and cannot be pruned." % node)
assert node.env is self
# If node's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in node.outputs:
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return
self.nodes.remove(node)
self.results.difference_update(node.outputs)
self.execute_callbacks('on_prune', node)
for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)])
self.__prune_r__(node.inputs)
def follow(self, r):
node = r.owner
if self.edge(r):
return None
else:
if node is None:
raise Exception("what the fuck")
return node.inputs
def has_node(self, node):
return node in self.nodes ### replace ###
def replace(self, r, new_r, consistency_check = True): def replace(self, r, new_r, consistency_check = True):
""" """
...@@ -222,6 +202,8 @@ class Env(graph.Graph): ...@@ -222,6 +202,8 @@ class Env(graph.Graph):
even if there is an inconsistency, unless the replacement even if there is an inconsistency, unless the replacement
violates hard constraints on the types involved. violates hard constraints on the types involved.
""" """
if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
assert r in self.results assert r in self.results
# Save where we are so we can backtrack # Save where we are so we can backtrack
...@@ -290,6 +272,108 @@ class Env(graph.Graph): ...@@ -290,6 +272,108 @@ class Env(graph.Graph):
self.revert(chk) self.revert(chk)
raise raise
def checkpoint(self):
    """
    Snapshot the current position in the undo history.

    The token returned is the number of entries currently held in
    self.history; pass it to self.revert to backtrack to this state.
    """
    marker = len(self.history)
    return marker
def consistent(self):
    """
    Report whether self.validate() accepts the current subgraph.

    Returns False when validation raises InconsistencyError and True
    when it completes; any other exception propagates to the caller.
    """
    try:
        self.validate()
    except InconsistencyError:
        is_valid = False
    else:
        is_valid = True
    return is_valid
def extend(self, feature, do_import = True, validate = False):
    """
    @todo out of date
    Register `feature` with this env unless it is already present.

    A feature already found in self._features is ignored (nothing is
    added and no validation runs). Otherwise registration is delegated
    to self.__add_feature__(feature, do_import) and, when `validate`
    is true, self.validate() is called afterwards.
    """
    already_there = feature in self._features
    if not already_there:
        self.__add_feature__(feature, do_import)
        if validate:
            self.validate()
def execute_callbacks(self, name, *args):
    """
    Call the attribute `name` of every feature that has one.

    Features are visited in the order of self._features and each
    callback receives *args; features lacking the attribute are
    skipped. Return values are discarded.
    """
    for feature in self._features:
        try:
            callback = getattr(feature, name)
        except AttributeError:
            pass
        else:
            callback(*args)
def collect_callbacks(self, name, *args):
    """
    Call the attribute `name` of every feature that has one and gather
    the results.

    Returns a dict mapping each feature that provides the callback to
    the value its callback returned for *args; features lacking the
    attribute are left out of the dict.
    """
    gathered = {}
    for feature in self._features:
        try:
            callback = getattr(feature, name)
        except AttributeError:
            pass
        else:
            gathered[feature] = callback(*args)
    return gathered
def __add_feature__(self, feature, do_import):
    """
    Append `feature` to self._features and run its setup hooks.

    The feature's publish() hook, if it has one, is called first.
    When `do_import` is true and the feature defines on_import, that
    callback is then invoked on every node of self.io_toposort().
    """
    self._features.append(feature)
    announce = getattr(feature, 'publish', None)
    if announce is not None:
        announce()
    if not do_import:
        return
    try:
        on_import = feature.on_import
    except AttributeError:
        # the feature does not listen to imports
        return
    for node in self.io_toposort():
        on_import(node)
def __del_feature__(self, feature):
    """
    Unregister `feature` and run its teardown hook.

    The feature is removed from self._features; a feature that was
    never registered is silently tolerated. Its unpublish() hook, if
    it has one, is then called.
    """
    try:
        # self._features is a list, so the feature must be removed by
        # value; indexing the list with the feature raises TypeError.
        self._features.remove(feature)
    except ValueError:
        # not registered: nothing to remove
        pass
    # Fetch the hook itself; hasattr() would only give a bool, which
    # is never None and cannot be called.
    unpublish = getattr(feature, 'unpublish', None)
    if unpublish is not None:
        unpublish()
def get_feature(self, feature):
    """
    Return the entry of self._features that matches `feature`.

    The lookup goes through list.index, so a ValueError is raised when
    no registered feature matches.
    """
    registered = self._features
    return registered[registered.index(feature)]
def has_feature(self, feature):
    "True iff `feature` is found in self._features."
    registered = self._features
    return feature in registered
def nclients(self, r):
    "Number of entries in self.clients(r)."
    users = self.clients(r)
    return len(users)
def edge(self, r):
    "True iff `r` is one of self.inputs or one of self.orphans."
    if r in self.inputs:
        return True
    return r in self.orphans
def follow(self, r):
    """
    Return the results that lie immediately upstream of `r`.

    Returns None when self.edge(r) is true, otherwise the inputs of
    the node that owns `r`.

    Raises Exception when `r` is not on the edge and has no owner;
    the message names the offending result.
    """
    node = r.owner
    if self.edge(r):
        return None
    if node is None:
        # An ownerless result that is not on the edge has nothing upstream
        # to walk to; say which result it is so the graph can be debugged.
        raise Exception("%s has no owner but is not on the edge of this Env" % r)
    return node.inputs
def has_node(self, node):
    "True iff `node` is found in self.nodes."
    known = self.nodes
    return node in known
def revert(self, checkpoint): def revert(self, checkpoint):
""" """
Reverts the graph to whatever it was at the provided Reverts the graph to whatever it was at the provided
...@@ -336,113 +420,6 @@ class Env(graph.Graph): ...@@ -336,113 +420,6 @@ class Env(graph.Graph):
### Private interface ### ### Private interface ###
def __add_clients__(self, r, all):
    """
    r -> result
    all -> list of (op, i) pairs representing who r is an input of.

    Merges `all` into the set stored for r in self._clients, creating
    that set first when r has no entry yet.
    """
    bucket = self._clients.setdefault(r, set())
    bucket.update(all)
# Drops the given (op, i) pairs from the client set recorded for r in
# self._clients, and forgets r once that set becomes empty.
# NOTE(review): indentation was lost in this rendering, so whether the
# orphan check at the end is nested under the emptiness check cannot be
# told from here -- confirm against the real file.
def __remove_clients__(self, r, all):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
Removes all from the clients list of r.
"""
# Nothing to remove: return before touching self._clients.
if not all:
return
# self._clients[r] is indexed directly, so r must already have an entry.
self._clients[r].difference_update(all)
# No client left for r: delete its entry.
if not self._clients[r]:
del self._clients[r]
# r is also taken out of the orphan set when it is in there.
if r in self.orphans:
self.orphans.remove(r)
def __import_r_satisfy__(self, results):
    """
    Call self.satisfy on every op that graph.ops(self.results, results)
    yields, i.e. on the owners of the given results.
    """
    pending_ops = graph.ops(self.results, results)
    for pending in pending_ops:
        self.satisfy(pending)
# For each given result: imports its owner (when it has one) through
# self.__import__, and records the result itself if self.results does
# not know it yet.
# NOTE(review): indentation was lost in this rendering; the nesting of
# the second `if` relative to `if owner:` should be confirmed against
# the real file.
def __import_r__(self, results):
# Imports the owners of the results
for result in results:
owner = result.owner
if owner:
self.__import__(result.owner)
# A result still unknown here is added to both self.results and
# self.orphans.
if result not in self.results:
self.results.add(result)
self.orphans.add(result)
# Imports `op` and every not-yet-known op it depends on into this Env:
# each new op is added to self.nodes, its outputs to self.results, its
# inputs get (op, i) client entries, and the 'on_import' callbacks are
# executed for it.
# NOTE(review): indentation was lost in this rendering; the exact nesting
# of the statements inside the loops should be confirmed against the real
# file.
def __import__(self, op):
# We import the nodes in topological order. We only are interested
# in new nodes, so we use all results we know of as if they were the input set.
# (the functions in the graph module only use the input set to
# know where to stop going down)
# Orphans are excluded from that input set.
new_nodes = graph.io_toposort(self.results.difference(self.orphans), op.outputs)
# NOTE: the loop variable below rebinds the `op` parameter.
for op in new_nodes:
self.nodes.add(op)
self.results.update(op.outputs)
# An output produced inside the subgraph is no longer an orphan.
self.orphans.difference_update(op.outputs)
for i, input in enumerate(op.inputs):
# Record that op uses `input` as its i-th input.
self.__add_clients__(input, [(op, i)])
if input not in self.results:
# This input is an orphan because if the op that
# produced it was in the subgraph, io_toposort
# would have placed it before, so we would have
# seen it (or it would already be in the graph)
self.orphans.add(input)
self.results.add(input)
# Notify the features that define an on_import callback.
self.execute_callbacks('on_import', op)
# for listener in self._listeners.values():
# try:
# listener.on_import(op)
# except AbstractFunctionError:
# pass
# Message string attached to the function object as an attribute.
__import__.E_output = 'op output in Env.inputs'
# Tries to prune the owner of each distinct result given, through
# self.__prune__. Results listed in self.inputs and results without an
# owner are skipped.
def __prune_r__(self, results):
# Prunes the owners of the results.
# set() makes each result be considered only once.
for result in set(results):
if result in self.inputs:
continue
owner = result.owner
if owner:
self.__prune__(owner)
# if result in self.results:
# self.results.remove(result)
# if result in self.orphans:
# self.orphans.remove(result)
# Removes `op` from this Env when none of its outputs is still needed,
# then unregisters op as a client of its inputs and tries to prune the
# owners of those inputs in turn.
# NOTE(review): indentation was lost in this rendering; whether the final
# __prune_r__ call sits inside or after the loop over op.inputs should be
# confirmed against the real file.
def __prune__(self, op):
# If op's outputs have no clients, removes it from the graph
# and recursively tries to prune its inputs. If at least one
# of the op's outputs is an output to the graph or has a client
# then __prune__ is a no-op.
for output in op.outputs:
# Cannot prune an op which is an output or used somewhere
if self.clients(output) or output in self.outputs: #output in self.outputs or self.clients(output):
return
if op not in self.nodes: # this can happen from replacing an orphan
return
self.nodes.remove(op)
self.results.difference_update(op.outputs)
# Notify the features that define an on_prune callback.
self.execute_callbacks('on_prune', op)
# for listener in self._listeners.values():
# try:
# listener.on_prune(op)
# except AbstractFunctionError:
# pass
# op no longer uses its inputs: drop the matching (op, i) client entries.
for i, input in enumerate(op.inputs):
self.__remove_clients__(input, [(op, i)])
self.__prune_r__(op.inputs)
def __move_clients__(self, clients, r, new_r): def __move_clients__(self, clients, r, new_r):
if not (r.type == new_r.type): if not (r.type == new_r.type):
......
...@@ -6,11 +6,11 @@ from utils import object2 ...@@ -6,11 +6,11 @@ from utils import object2
def deprecated(f): def deprecated(f):
printme = True printme = [True]
def g(*args, **kwargs): def g(*args, **kwargs):
if printme: if printme[0]:
print 'gof.graph.%s deprecated: April 29' % f.__name__ print 'gof.graph.%s deprecated: April 29' % f.__name__
printme = False printme[0] = False
return f(*args, **kwargs) return f(*args, **kwargs)
return g return g
...@@ -28,8 +28,6 @@ class Apply(object2): ...@@ -28,8 +28,6 @@ class Apply(object2):
for input in inputs: for input in inputs:
if isinstance(input, Result): if isinstance(input, Result):
self.inputs.append(input) self.inputs.append(input)
# elif isinstance(input, Type):
# self.inputs.append(Result(input, None, None))
else: else:
raise TypeError("The 'inputs' argument to Apply must contain Result instances, not %s" % input) raise TypeError("The 'inputs' argument to Apply must contain Result instances, not %s" % input)
self.outputs = [] self.outputs = []
...@@ -42,12 +40,10 @@ class Apply(object2): ...@@ -42,12 +40,10 @@ class Apply(object2):
elif output.owner is not self or output.index != i: elif output.owner is not self or output.index != i:
raise ValueError("All output results passed to Apply must belong to it.") raise ValueError("All output results passed to Apply must belong to it.")
self.outputs.append(output) self.outputs.append(output)
# elif isinstance(output, Type):
# self.outputs.append(Result(output, self, i))
else: else:
raise TypeError("The 'outputs' argument to Apply must contain Result instances with no owner, not %s" % output) raise TypeError("The 'outputs' argument to Apply must contain Result instances with no owner, not %s" % output)
@deprecated
def default_output(self): def default_output(self):
print 'default_output deprecated: April 29'
""" """
Returns the default output for this Node, typically self.outputs[0]. Returns the default output for this Node, typically self.outputs[0].
Depends on the value of node.op.default_output Depends on the value of node.op.default_output
...@@ -66,6 +62,22 @@ class Apply(object2): ...@@ -66,6 +62,22 @@ class Apply(object2):
return str(self) return str(self)
def __asapply__(self): def __asapply__(self):
return self return self
def clone(self):
    """
    Return a new node of the same class, built from the same op and the
    same inputs, whose outputs are clone()s of this node's outputs.
    """
    op, inputs = self.op, self.inputs
    fresh_outputs = [out.clone() for out in self.outputs]
    return self.__class__(op, inputs, fresh_outputs)
def clone_with_new_inputs(self, inputs, check_type = True):
    """
    Clone this node and make `inputs` the inputs of the copy.

    With `check_type` on (the default), the current and the new inputs
    are compared pairwise and a TypeError is raised at the first pair
    whose types are not equal. The copy comes from self.clone(); its
    `inputs` attribute is set to the very object passed in.
    """
    pairs = zip(self.inputs, inputs) if check_type else ()
    for current, replacement in pairs:
        if current.type == replacement.type:
            continue
        raise TypeError("Cannot change the type of this input.", current, replacement)
    duplicate = self.clone()
    duplicate.inputs = inputs
    return duplicate
nin = property(lambda self: len(self.inputs)) nin = property(lambda self: len(self.inputs))
nout = property(lambda self: len(self.outputs)) nout = property(lambda self: len(self.outputs))
...@@ -77,7 +89,6 @@ class Result(object2): ...@@ -77,7 +89,6 @@ class Result(object2):
self.owner = owner self.owner = owner
self.index = index self.index = index
self.name = name self.name = name
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
...@@ -88,22 +99,35 @@ class Result(object2): ...@@ -88,22 +99,35 @@ class Result(object2):
else: else:
return str(self.owner.op) + "." + str(self.index) return str(self.owner.op) + "." + str(self.index)
else: else:
return "?::" + str(self.type) return "<?>::" + str(self.type)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@deprecated @deprecated
def __asresult__(self): def __asresult__(self):
return self return self
def clone(self):
    """
    Return a new instance of the same class with this result's type and
    name, constructed with None for the two arguments in between.
    """
    cls = self.__class__
    return cls(self.type, None, None, self.name)
class Constant(Result): class Value(Result):
#__slots__ = ['data'] #__slots__ = ['data']
def __init__(self, type, data, name = None): def __init__(self, type, data, name = None):
Result.__init__(self, type, None, None, name) Result.__init__(self, type, None, None, name)
self.data = type.filter(data) self.data = type.filter(data)
self.indestructible = True def __str__(self):
if self.name is not None:
return self.name
return "<" + str(self.data) + ">" #+ "::" + str(self.type)
def clone(self):
    """
    Return a new instance of the same class carrying this value's type,
    data and name.

    The name is forwarded so that a clone keeps its label; without it
    the constructor would fall back to its default name of None.
    """
    return self.__class__(self.type, self.data, self.name)
class Constant(Value):
#__slots__ = ['data']
def __init__(self, type, data, name = None):
Value.__init__(self, type, data, name)
### self.indestructible = True
def equals(self, other): def equals(self, other):
# this does what __eq__ should do, but Result and Apply should always be hashable by id # this does what __eq__ should do, but Result and Apply should always be hashable by id
return isinstance(other, Constant) and self.signature() == other.signature() return type(other) == type(self) and self.signature() == other.signature()
def signature(self): def signature(self):
return (self.type, self.data) return (self.type, self.data)
def __str__(self): def __str__(self):
...@@ -314,23 +338,12 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False): ...@@ -314,23 +338,12 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
if node is None: # result is an orphan if node is None: # result is an orphan
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
cpy = copy(result) cpy = copy(result)
cpy.owner = None
cpy.index = None
d[result] = cpy d[result] = cpy
else: else:
d[result] = result d[result] = result
return d[result] return d[result]
else: else:
new_node = copy(node) new_node = node.clone_with_new_inputs([clone_helper(input) for input in node.inputs])
new_node.inputs = [clone_helper(input) for input in node.inputs]
new_node.outputs = []
for output in node.outputs:
new_output = copy(output)
new_output.owner = new_node
new_node.outputs.append(new_output)
# new_node = Apply(node.op,
# [clone_helper(input) for input in node.inputs],
# [output.type for output in node.outputs])
d[node] = new_node d[node] = new_node
for output, new_output in zip(node.outputs, new_node.outputs): for output, new_output in zip(node.outputs, new_node.outputs):
d[output] = new_output d[output] = new_output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论