提交 65e08101 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

env redone, toolbox redone

上级 646f4c01
......@@ -2,13 +2,15 @@
import unittest
from type import Type
import graph
from graph import Result, as_result, Apply
from op import Op
from opt import PatternOptimizer, OpSubOptimizer
from ext import *
from env import Env, InconsistencyError
from toolbox import EquivTool
#from toolbox import EquivTool
from toolbox import ReplaceValidate
from copy import copy
......@@ -65,8 +67,11 @@ def inputs():
_Env = Env
def Env(inputs, outputs, validate = True):
e = _Env(inputs, outputs)
e.extend(EquivTool(e))
e.extend(DestroyHandler(e), validate = validate)
##e.extend(EquivTool(e))
e.extend(DestroyHandler())
e.extend(ReplaceValidate())
if validate:
e.validate()
return e
......@@ -108,19 +113,19 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e1, e2])
chk = g.checkpoint()
assert g.consistent()
g.replace(e1, add_in_place(x, y))
g.replace_validate(e1, add_in_place(x, y))
assert g.consistent()
try:
g.replace(e2, add_in_place(y, x))
g.replace_validate(e2, add_in_place(y, x))
self.fail()
except InconsistencyError:
pass
assert g.consistent()
g.revert(chk)
g.replace(e2, add_in_place(y, x))
g.replace_validate(e2, add_in_place(y, x))
assert g.consistent()
try:
g.replace(e1, add_in_place(x, y))
g.replace_validate(e1, add_in_place(x, y))
self.fail()
except InconsistencyError:
pass
......@@ -136,7 +141,7 @@ class _test_all(unittest.TestCase):
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x))
try:
g2 = Env([x,y,z], [e2])
g2 = Env(*graph.clone([x,y,z], [e2]))
self.fail()
except InconsistencyError:
pass
......@@ -154,16 +159,18 @@ class _test_all(unittest.TestCase):
e = dot(aip, transpose_view(x))
g = Env([x,y,z], [e], False)
assert not g.consistent()
g.replace(aip, add(x, z))
g.replace_validate(aip, add(x, z))
assert g.consistent()
def test_usage_loop_through_views_2(self):
x, y, z = inputs()
e0 = transpose_view(transpose_view(transpose_view(sigmoid(x))))
e0 = transpose_view(transpose_view(sigmoid(x)))
e = dot(add_in_place(x,y), transpose_view(e0))
g = Env([x,y,z], [e])
assert g.consistent() # because sigmoid can do the copy
g.replace(e0, x, False)
# print g
# print g.destroy_handler.children
g.replace(e0, x)
assert not g.consistent() # we cut off the path to the sigmoid
def test_usage_loop_insert_views(self):
......@@ -184,10 +191,10 @@ class _test_all(unittest.TestCase):
chk = g.checkpoint()
PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]"
g.replace(g.equiv(e), add(x,y))
print g
new_e = add(x,y)
g.replace_validate(x, new_e)
assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False)
g.replace(new_e, dot(add_in_place(x,y), transpose_view(x)))
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent()
g.revert(chk)
......@@ -202,7 +209,7 @@ class _test_all(unittest.TestCase):
e = add_in_place(x, y)
g = Env([x,y,z], [e], False)
assert not g.consistent()
g.replace(e, add(x, y))
g.replace_validate(e, add(x, y))
assert g.consistent()
def test_indestructible_through_views(self):
......@@ -212,7 +219,7 @@ class _test_all(unittest.TestCase):
e = add_in_place(tv, y)
g = Env([x,y,z], [e], False)
assert not g.consistent()
g.replace(tv, sigmoid(x))
g.replace_validate(tv, sigmoid(x))
assert g.consistent()
def test_repair_destroy_path(self):
......@@ -223,7 +230,7 @@ class _test_all(unittest.TestCase):
e4 = add_in_place(e1, z)
g = Env([x,y,z], [e3, e4], False)
assert not g.consistent()
g.replace(e2, transpose_view(x), False)
g.replace(e2, transpose_view(x))
assert not g.consistent()
def test_indirect(self):
......@@ -233,9 +240,9 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e], False)
assert not g.consistent()
new_e0 = add(x, y)
g.replace(e0, new_e0, False)
g.replace(e0, new_e0)
assert g.consistent()
g.replace(new_e0, add_in_place(x, y), False)
g.replace(new_e0, add_in_place(x, y))
assert not g.consistent()
def test_indirect_2(self):
......@@ -245,12 +252,12 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e], False)
assert not g.consistent()
new_e0 = add(e0, y)
g.replace(e0, new_e0, False)
g.replace(e0, new_e0)
assert g.consistent()
if __name__ == '__main__':
unittest.main()
#unittest.main()
_test_all('test_usage_loop_through_views').debug()
......@@ -59,19 +59,19 @@ def inputs():
return x, y, z
class _test_EquivTool(unittest.TestCase):
# class _test_EquivTool(unittest.TestCase):
def test_straightforward(self):
x, y, z = inputs()
sx = sigmoid(x)
e = add(sx, sigmoid(y))
g = Env([x, y, z], [e])
g.extend(EquivTool(g))
assert hasattr(g, 'equiv')
assert g.equiv(sx) is sx
g.replace(sx, dot(x, z))
assert g.equiv(sx) is not sx
assert g.equiv(sx).owner.op is dot
# def test_straightforward(self):
# x, y, z = inputs()
# sx = sigmoid(x)
# e = add(sx, sigmoid(y))
# g = Env([x, y, z], [e])
# g.extend(EquivTool(g))
# assert hasattr(g, 'equiv')
# assert g.equiv(sx) is sx
# g.replace(sx, dot(x, z))
# assert g.equiv(sx) is not sx
# assert g.equiv(sx).owner.op is dot
class _test_NodeFinder(unittest.TestCase):
......@@ -81,7 +81,7 @@ class _test_NodeFinder(unittest.TestCase):
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = Env([x, y, z], [e])
g.extend(NodeFinder(g))
g.extend(NodeFinder())
assert hasattr(g, 'get_nodes')
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if not len([x for x in g.get_nodes(type)]) == num:
......@@ -100,7 +100,7 @@ class _test_NodeFinder(unittest.TestCase):
x, y, z = inputs()
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g = Env([x, y, z], [e])
g.extend(NodeFinder(g))
g.extend(NodeFinder())
gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
g.replace(e, add(x, y)) # but here I prune them all
assert len([x for x in gen]) == 0 # the generator should not yield them
......
......@@ -16,7 +16,7 @@ class InconsistencyError(Exception):
class Env(graph.Graph):
class Env(object): #(graph.Graph):
"""
An Env represents a subgraph bound by a set of input results and a
set of output results. An op is in the subgraph iff it depends on
......@@ -59,14 +59,19 @@ class Env(graph.Graph):
raise ValueError("One of the provided inputs is the output of an already existing node. " \
"If that is okay, either discard that input's owner or use graph.clone.")
self.__setup_r__(input)
self.results.add(input)
self.outputs = outputs
self.__import_r__(outputs)
self.outputs = outputs
for i, output in enumerate(outputs):
output.clients.append(('output', i))
self.node_locks = {}
self.result_locks = {}
# List of functions that undo the replace operations performed.
# e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
self.history = []
# # List of functions that undo the replace operations performed.
# # e.g. to recover the initial graph one could write: for u in self.history.__reversed__(): u()
# self.history = []
### Setup a Result ###
......@@ -99,7 +104,7 @@ class Env(graph.Graph):
"""
r.clients += all
def __remove_clients__(self, r, all):
def __remove_clients__(self, r, all, prune = True):
"""
r -> result
all -> list of (op, i) pairs representing who r is an input of.
......@@ -109,14 +114,24 @@ class Env(graph.Graph):
for entry in all:
r.clients.remove(entry)
# remove from orphans?
if not r.clients:
if prune:
self.__prune_r__([r])
return False
return True
return False
### import ###
def __import_r__(self, results):
# Imports the owners of the results
for node in set(r.owner for r in results if r is not None):
for node in set(r.owner for r in results if r.owner is not None):
self.__import__(node)
for r in results:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r)
self.results.add(r)
def __import__(self, node, check = True):
# We import the nodes in topological order. We only are interested
......@@ -127,11 +142,13 @@ class Env(graph.Graph):
if check:
for node in new_nodes:
if hasattr(node, 'env') and node.env is not self or \
any(hasattr(r, 'env') and r.env is not self or \
r.owner is None and not isinstance(r, Value) and r not in self.inputs
for r in node.inputs + node.outputs):
raise Exception("Could not import %s" % node)
if hasattr(node, 'env') and node.env is not self:
raise Exception("%s is already owned by another env" % node)
for r in node.inputs:
if hasattr(r, 'env') and r.env is not self:
raise Exception("%s is already owned by another env" % r)
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r)
for node in new_nodes:
self.__setup_node__(node)
......@@ -141,9 +158,6 @@ class Env(graph.Graph):
self.results.add(output)
for i, input in enumerate(node.inputs):
if input not in self.results:
if not isinstance(input, Value):
raise TypeError("The graph to import contains a leaf that is not an input and has no default value " \
"(graph state is bad now - use check = True)", input)
self.__setup_r__(input)
self.results.add(input)
self.__add_clients__(input, [(node, i)])
......@@ -155,7 +169,7 @@ class Env(graph.Graph):
def __prune_r__(self, results):
# Prunes the owners of the results.
for node in set(r.owner for r in results if r is not None):
for node in set(r.owner for r in results if r.owner is not None):
self.__prune__(node)
for r in results:
if not r.clients and r in self.results:
......@@ -179,78 +193,99 @@ class Env(graph.Graph):
for i, input in enumerate(node.inputs):
self.__remove_clients__(input, [(node, i)])
self.__prune_r__(node.inputs)
#self.__prune_r__(node.inputs)
### replace ###
### change input ###
def change_input(self, node, i, new_r):
if node == 'output':
r = self.outputs[i]
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
self.outputs[i] = new_r
else:
if node.env is not self:
raise Exception("Cannot operate on %s because it does not belong to this Env" % node)
r = node.inputs[i]
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
node.inputs[i] = new_r
self.__import_r__([new_r])
self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False)
self.execute_callbacks('on_change_input', node, i, r, new_r)
if prune:
self.__prune_r__([r])
def replace(self, r, new_r, consistency_check = True):
### replace ###
def replace(self, r, new_r):
"""
This is the main interface to manipulate the subgraph in Env.
For every op that uses r as input, makes it use new_r instead.
This may raise an error if the new result violates type
constraints for one of the target nodes. In that case, no
changes are made.
If the replacement makes the graph inconsistent and the value
of consistency_check is True, this function will raise an
InconsistencyError and will undo the operation, leaving the
graph the way it was before the call to replace.
If consistency_check is False, the replacement will succeed
even if there is an inconsistency, unless the replacement
violates hard constraints on the types involved.
"""
if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
assert r in self.results
# Save where we are so we can backtrack
if consistency_check:
chk = self.checkpoint()
for node, i in r.clients:
assert node == 'output' and self.outputs[i] is r or node.inputs[i] is r
self.change_input(node, i, new_r)
# # Save where we are so we can backtrack
# if consistency_check:
# chk = self.checkpoint()
# The copy is required so undo can know what clients to move back!
clients = copy(self.clients(r))
# # The copy is required so undo can know what clients to move back!
# clients = copy(self.clients(r))
# Messy checks so we know what to do if we are replacing an output
# result. Note that if v is an input result, we do nothing at all for
# now (it's not clear what it means to replace an input result).
was_output = False
if r in self.outputs:
was_output = True
self.outputs[self.outputs.index(r)] = new_r
# # Messy checks so we know what to do if we are replacing an output
# # result. Note that if v is an input result, we do nothing at all for
# # now (it's not clear what it means to replace an input result).
# was_output = False
# if r in self.outputs:
# was_output = True
# self.outputs[self.outputs.index(r)] = new_r
was_input = False
if r in self.inputs:
was_input = True
self.inputs[self.inputs.index(r)] = new_r
# was_input = False
# if r in self.inputs:
# was_input = True
# self.inputs[self.inputs.index(r)] = new_r
# The actual replacement operation occurs here. This might raise
# an error.
self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs
# # The actual replacement operation occurs here. This might raise
# # an error.
# self.__move_clients__(clients, r, new_r) # not sure how to order this wrt to adjusting the outputs
# This function undoes the replacement.
def undo():
# Restore self.outputs
if was_output:
self.outputs[self.outputs.index(new_r)] = r
# # This function undoes the replacement.
# def undo():
# # Restore self.outputs
# if was_output:
# self.outputs[self.outputs.index(new_r)] = r
# Restore self.inputs
if was_input:
self.inputs[self.inputs.index(new_r)] = r
# # Restore self.inputs
# if was_input:
# self.inputs[self.inputs.index(new_r)] = r
# Move back the clients. This should never raise an error.
self.__move_clients__(clients, new_r, r)
# # Move back the clients. This should never raise an error.
# self.__move_clients__(clients, new_r, r)
self.history.append(undo)
# self.history.append(undo)
if consistency_check:
try:
self.validate()
except InconsistencyError, e:
self.revert(chk)
raise
# if consistency_check:
# try:
# self.validate()
# except InconsistencyError, e:
# self.revert(chk)
# raise
def replace_all(self, d):
"""
......@@ -259,42 +294,47 @@ class Env(graph.Graph):
graph is not consistent. If an error is raised, the graph is
restored to what it was before.
"""
chk = self.checkpoint()
try:
for r, new_r in d.items():
self.replace(r, new_r, False)
except Exception, e:
self.revert(chk)
raise
try:
self.validate()
except InconsistencyError, e:
self.revert(chk)
raise
for r, new_r in d.items():
self.replace(r, new_r, False)
# chk = self.checkpoint()
# try:
# for r, new_r in d.items():
# self.replace(r, new_r, False)
# except Exception, e:
# self.revert(chk)
# raise
# try:
# self.validate()
# except InconsistencyError, e:
# self.revert(chk)
# raise
def checkpoint(self):
"""
Returns an object that can be passed to self.revert in order to backtrack
to a previous state.
"""
return len(self.history)
# def checkpoint(self):
# """
# Returns an object that can be passed to self.revert in order to backtrack
# to a previous state.
# """
# return len(self.history)
def consistent(self):
"""
Returns True iff the subgraph is consistent and does not violate the
constraints set by the listeners.
"""
try:
self.validate()
except InconsistencyError:
return False
return True
# def consistent(self):
# """
# Returns True iff the subgraph is consistent and does not violate the
# constraints set by the listeners.
# """
# try:
# self.validate()
# except InconsistencyError:
# return False
# return True
def extend(self, feature, do_import = True, validate = False):
### features ###
def extend(self, feature):
"""
@todo out of date
Adds an instance of the feature_class to this env's supported
......@@ -304,17 +344,34 @@ class Env(graph.Graph):
"""
if feature in self._features:
return # the feature is already present
self.__add_feature__(feature, do_import)
if validate:
self.validate()
self._features.append(feature)
attach = getattr(feature, 'on_attach', None)
if attach is not None:
try:
attach(self)
except:
self._features.pop()
raise
def remove_feature(self, feature):
try:
self._features.remove(feature)
except:
return
deattach = getattr(feature, 'on_deattach', None)
if deattach is not None:
deattach(self)
### callback utils ###
def execute_callbacks(self, name, *args):
for feature in self._features:
try:
fn = getattr(feature, name)
except AttributeError:
continue
fn(*args)
fn(self, *args)
def collect_callbacks(self, name, *args):
d = {}
......@@ -326,35 +383,9 @@ class Env(graph.Graph):
d[feature] = fn(*args)
return d
def __add_feature__(self, feature, do_import):
self._features.append(feature)
publish = getattr(feature, 'publish', None)
if publish is not None:
publish()
if do_import:
try:
fn = feature.on_import
except AttributeError:
return
for node in self.io_toposort():
fn(node)
def __del_feature__(self, feature):
try:
del self._features[feature]
except:
pass
unpublish = hasattr(feature, 'unpublish')
if unpublish is not None:
unpublish()
def get_feature(self, feature):
idx = self._features.index(feature)
return self._features[idx]
def has_feature(self, feature):
return feature in self._features
### misc ###
def nclients(self, r):
"Same as len(self.clients(r))."
return len(self.clients(r))
......@@ -374,114 +405,156 @@ class Env(graph.Graph):
def has_node(self, node):
return node in self.nodes
def revert(self, checkpoint):
"""
Reverts the graph to whatever it was at the provided
checkpoint (undoes all replacements). A checkpoint at any
given time can be obtained using self.checkpoint().
"""
while len(self.history) > checkpoint:
f = self.history.pop()
f()
def check_integrity(self):
nodes = graph.ops(self.inputs, self.outputs)
if self.nodes != nodes:
missing = nodes.difference(self.nodes)
excess = self.nodes.difference(nodes)
raise Exception("The nodes are inappropriately cached. missing, in excess: ", missing, excess)
for node in nodes:
if node.env is not self:
raise Exception("Node should belong to the env.", node)
for i, result in enumerate(node.inputs):
if result.env is not self:
raise Exception("Input of node should belong to the env.", result, (node, i))
if (node, i) not in result.clients:
raise Exception("Inconsistent clients list.", (node, i), result.clients)
results = graph.results(self.inputs, self.outputs)
if self.results != results:
missing = results.difference(self.results)
excess = self.results.difference(results)
raise Exception("The results are inappropriately cached. missing, in excess: ", missing, excess)
for result in results:
if result.owner is None and result not in self.inputs and not isinstance(result, graph.Value):
raise Exception("Undeclared input.", result)
if result.env is not self:
raise Exception("Result should belong to the env.", result)
for node, i in result.clients:
if node == 'output':
if self.outputs[i] is not result:
raise Exception("Inconsistent clients list.", result, self.outputs[i])
continue
if node not in nodes:
raise Exception("Client not in env.", result, (node, i))
if node.inputs[i] is not result:
raise Exception("Inconsistent clients list.", result, node.inputs[i])
def supplemental_orderings(self):
"""
Returns a dictionary of {op: set(prerequisites)} that must
be satisfied in addition to the order defined by the structure
of the graph (returns orderings that not related to input/output
relationships).
"""
ords = {}
for feature in self._features:
if hasattr(feature, 'orderings'):
for op, prereqs in feature.orderings().items():
ords.setdefault(op, set()).update(prereqs)
return ords
def toposort(self):
"""
Returns a list of nodes in the order that they must be executed
in order to preserve the semantics of the graph and respect
the constraints put forward by the listeners.
"""
ords = self.supplemental_orderings()
order = graph.io_toposort(self.inputs, self.outputs, ords)
return order
def validate(self):
"""
Raises an error if the graph is inconsistent.
"""
self.execute_callbacks('validate')
# for constraint in self._constraints.values():
# constraint.validate()
return True
# def revert(self, checkpoint):
# """
# Reverts the graph to whatever it was at the provided
# checkpoint (undoes all replacements). A checkpoint at any
# given time can be obtained using self.checkpoint().
# """
# while len(self.history) > checkpoint:
# f = self.history.pop()
# f()
# def supplemental_orderings(self):
# """
# Returns a dictionary of {op: set(prerequisites)} that must
# be satisfied in addition to the order defined by the structure
# of the graph (returns orderings that not related to input/output
# relationships).
# """
# ords = {}
# for feature in self._features:
# if hasattr(feature, 'orderings'):
# for op, prereqs in feature.orderings().items():
# ords.setdefault(op, set()).update(prereqs)
# return ords
# def toposort(self):
# """
# Returns a list of nodes in the order that they must be executed
# in order to preserve the semantics of the graph and respect
# the constraints put forward by the listeners.
# """
# ords = self.supplemental_orderings()
# order = graph.io_toposort(self.inputs, self.outputs, ords)
# return order
# def validate(self):
# """
# Raises an error if the graph is inconsistent.
# """
# self.execute_callbacks('validate')
# # for constraint in self._constraints.values():
# # constraint.validate()
# return True
### Private interface ###
def __move_clients__(self, clients, r, new_r):
# def __move_clients__(self, clients, r, new_r):
if not (r.type == new_r.type):
raise TypeError("Cannot move clients between Results that have different types.", r, new_r)
# if not (r.type == new_r.type):
# raise TypeError("Cannot move clients between Results that have different types.", r, new_r)
# We import the new result in the fold
self.__import_r__([new_r])
for op, i in clients:
op.inputs[i] = new_r
# try:
# # Try replacing the inputs
# for op, i in clients:
# op.set_input(i, new_r)
# except:
# # Oops!
# for op, i in clients:
# op.set_input(i, r)
# self.__prune_r__([new_r])
# raise
self.__remove_clients__(r, clients)
self.__add_clients__(new_r, clients)
# # We import the new result in the fold
# # why was this line AFTER the set_inputs???
# # if we do it here then satisfy in import fucks up...
# self.__import_r__([new_r])
self.execute_callbacks('on_rewire', clients, r, new_r)
# for listener in self._listeners.values():
# try:
# listener.on_rewire(clients, r, new_r)
# except AbstractFunctionError:
# pass
# We try to get rid of the old one
self.__prune_r__([r])
# for op, i in clients:
# op.inputs[i] = new_r
# # try:
# # # Try replacing the inputs
# # for op, i in clients:
# # op.set_input(i, new_r)
# # except:
# # # Oops!
# # for op, i in clients:
# # op.set_input(i, r)
# # self.__prune_r__([new_r])
# # raise
# self.__remove_clients__(r, clients)
# self.__add_clients__(new_r, clients)
# # # We import the new result in the fold
# # # why was this line AFTER the set_inputs???
# # # if we do it here then satisfy in import fucks up...
# # self.__import_r__([new_r])
# self.execute_callbacks('on_rewire', clients, r, new_r)
# # for listener in self._listeners.values():
# # try:
# # listener.on_rewire(clients, r, new_r)
# # except AbstractFunctionError:
# # pass
# # We try to get rid of the old one
# self.__prune_r__([r])
def __str__(self):
return "[%s]" % ", ".join(graph.as_string(self.inputs, self.outputs))
def clone_get_equiv(self, clone_inputs = True):
equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
new = self.__class__([equiv[input] for input in self.inputs],
[equiv[output] for output in self.outputs])
for feature in self._features:
new.extend(feature)
return new, equiv
# def clone_get_equiv(self, clone_inputs = True):
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
# new = self.__class__([equiv[input] for input in self.inputs],
# [equiv[output] for output in self.outputs])
# for feature in self._features:
# new.extend(feature)
# return new, equiv
# def clone(self, clone_inputs = True):
# equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
# new = self.__class__([equiv[input] for input in self.inputs],
# [equiv[output] for output in self.outputs])
# for feature in self._features:
# new.extend(feature)
# try:
# new.set_equiv(equiv)
# except AttributeError:
# pass
# return new
# def __copy__(self):
# return self.clone()
def clone(self, clone_inputs = True):
equiv = graph.clone_get_equiv(self.inputs, self.outputs, clone_inputs)
new = self.__class__([equiv[input] for input in self.inputs],
[equiv[output] for output in self.outputs])
for feature in self._features:
new.extend(feature)
try:
new.set_equiv(equiv)
except AttributeError:
pass
return new
def __copy__(self):
return self.clone()
from features import Listener, Constraint, Orderings, Tool
#from features import Listener, Constraint, Orderings, Tool
import utils
from utils import AbstractFunctionError
from copy import copy
from env import InconsistencyError
from toolbox import Bookkeeper
class DestroyHandler(Listener, Constraint, Orderings, Tool):
from collections import defaultdict
class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
"""
This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over
......@@ -27,14 +35,32 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
This feature allows some optimizations (eg sub += for +) to be applied
safely.
"""
def __init__(self, env):
def __init__(self):
self.env = None
def on_attach(self, env):
if self.env is not None:
raise Exception("A DestroyHandler instance can only serve one Env.")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(env, attr):
raise Exception("DestroyHandler feature is already present or in conflict with another plugin.")
def __destroyers(r):
ret = self.destroyers.get(r, {})
ret = ret.keys()
return ret
env.destroyers = __destroyers
env.destroy_handler = self
self.env = env
# For an Op that has a view_map, {output : input it is a view of}
self.parent = {}
# Reverse mapping of parent: {input : outputs that are a view of it}
self.children = {}
self.children = defaultdict(set)
# {foundation : {op that destroys it : path }}
# where foundation is a result such that (not self.parent[result])
......@@ -57,25 +83,37 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
# indestructible by the user.
self.illegal = set()
self.env = env
self.seen = set()
# Initialize the children if the inputs and orphans.
for input in env.orphans.union(env.inputs):
self.children[input] = set()
def publish(self):
"""
Publishes the following on the env:
- destroyers(r) -> returns all L{Op}s that destroy the result r
- destroy_handler -> self
"""
def __destroyers(r):
ret = self.destroyers.get(r, {})
ret = ret.keys()
return ret
self.env.destroyers = __destroyers
self.env.destroy_handler = self
Bookkeeper.on_attach(self, env)
# # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
# self.children[input] = set()
def on_detach(self, env):
del self.parent
del self.children
del self.destroyers
del self.paths
del self.dups
del self.cycles
del self.illegal
del self.seen
self.env = None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def __path__(self, r):
"""
......@@ -105,12 +143,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
"""
children = self.children[r]
if not children:
return set([r])
return [r]
else:
rval = set([r])
rval = [r]
for child in children:
rval.update(self.__views__(child))
return rval
rval += self.__views__(child)
return utils.uniq(rval)
def __users__(self, r):
"""
......@@ -120,12 +158,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
is returned.
"""
views = self.__views__(r)
rval = set()
rval = [] # set()
for view in views:
for op, i in self.env.clients(view):
if op in self.seen:
rval.update(op.outputs)
return rval
for node, i in view.clients: #self.env.clients(view):
if node != 'output':
rval += node.outputs
return utils.uniq(rval)
def __pre__(self, op):
"""
......@@ -178,7 +216,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
just_remove is True, we return immediately after removing the
cycles.
"""
users = self.__users__(start)
users = set(self.__users__(start))
users.add(start)
for user in users:
for cycle in copy(self.cycles):
......@@ -208,13 +246,14 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
return vmap, dmap
def on_import(self, op):
def on_import(self, env, op):
"""
Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env.
"""
self.seen.add(op)
op.deps['destroy'] = []
view_map, destroy_map = self.get_maps(op)
for input in op.inputs:
......@@ -251,7 +290,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.__detect_cycles_helper__(output, [])
def on_prune(self, op):
def on_prune(self, env, op):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op to the env.
......@@ -295,6 +334,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
del self.children[output]
self.seen.remove(op)
del op.deps['destroy']
def __add_destroyer__(self, path):
......@@ -305,11 +345,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation = path[0]
target = path[-1]
op = target.owner
node = target.owner
destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path)
path = destroyers.setdefault(node, path)
print "add", path
node.deps['destroy'] += [user.owner for user in self.__users__(foundation) if user not in node.outputs]
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
if len(destroyers) > 1:
self.dups.add(foundation)
......@@ -325,10 +372,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation = path[0]
target = path[-1]
op = target.owner
node = target.owner
print "rm", path
print node.deps['destroy']
for user in self.__users__(foundation):
print " -- ", user
if user not in node.outputs:
node.deps['destroy'].remove(user.owner)
destroyers = self.destroyers[foundation]
del destroyers[op]
del destroyers[node]
if not destroyers:
if foundation in self.illegal:
......@@ -338,14 +392,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2):
def on_change_input(self, env, node, i, r, new_r):
if node != 'output':
self.on_rewire(env, [(node, i)], r, new_r)
def on_rewire(self, env, clients, r_1, r_2):
"""
Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is
now r_2.
"""
path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2)
......@@ -396,7 +454,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.children.setdefault(r_2, set())
self.__detect_cycles__(r_2)
def validate(self):
def validate(self, env):
"""
Raises an L{InconsistencyError} on any of the following conditions:
- Some results are destroyed by more than one L{Op}
......@@ -412,9 +470,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
else:
return True
def orderings(self):
def orderings(self, env):
"""
Returns a dict of {op : set(ops that must be computed before it)} according
Returns a dict of {node : set(nodes that must be computed before it)} according
to L{DestroyHandler}.
In particular, all the users of a destroyed result have priority over the
L{Op} that destroys the result.
......@@ -426,6 +484,8 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return ords
class Destroyer:
"""
Base class for Ops that destroy one or more of their inputs in an
......@@ -493,3 +553,4 @@ def view_roots(r):
return [r]
else:
return [r]
......@@ -202,7 +202,7 @@ def results_and_orphans(i, o, except_unreachable_input=False):
"""
results = set()
i = set(i)
results.update(i)
# results.update(i)
incomplete_paths = []
reached = set()
......@@ -287,7 +287,7 @@ def orphans(i, o):
return results_and_orphans(i, o)[1]
def clone(i, o, copy_inputs = False):
def clone(i, o, copy_inputs = True):
"""
@type i: list
@param i: input L{Result}s
......@@ -299,8 +299,8 @@ def clone(i, o, copy_inputs = False):
Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o).
"""
equiv = clone_get_equiv(i, o)
return [equiv[output] for output in o]
equiv = clone_get_equiv(i, o, copy_inputs)
return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
......@@ -324,7 +324,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
for input in i:
if copy_inputs_and_orphans:
cpy = copy(input)
cpy = input.clone()
cpy.owner = None
cpy.index = None
d[input] = cpy
......@@ -337,7 +337,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
node = result.owner
if node is None: # result is an orphan
if copy_inputs_and_orphans:
cpy = copy(result)
cpy = result.clone()
d[result] = cpy
else:
d[result] = result
......
......@@ -115,7 +115,10 @@ class OpSpecificOptimizer(LocalOptimizer):
"""
def add_requirements(self, env):
    """Attach the tools this optimizer relies on: node lookup by op
    (NodeFinder) and validated, revertible replacement (ReplaceValidate).

    @param env: the Env being optimized.
    """
    # extend() raises when an equivalent tool is already attached to the
    # env; that situation is fine for us, so the error is deliberately
    # ignored (narrowed from a bare except so interrupts still propagate).
    try:
        env.extend(toolbox.NodeFinder())
        env.extend(toolbox.ReplaceValidate())
    except Exception:
        pass
def candidates(self, env):
"""
......@@ -135,7 +138,10 @@ class OpSubOptimizer(Optimizer):
"""
def add_requirements(self, env):
    """Attach the tools this optimizer relies on: node lookup by op
    (NodeFinder) and validated, revertible replacement (ReplaceValidate).

    @param env: the Env being optimized.
    """
    # extend() raises when an equivalent tool is already attached to the
    # env; that situation is fine for us, so the error is deliberately
    # ignored (narrowed from a bare except so interrupts still propagate).
    try:
        env.extend(toolbox.NodeFinder())
        env.extend(toolbox.ReplaceValidate())
    except Exception:
        pass
def __init__(self, op1, op2, failure_callback = None):
"""
......@@ -163,7 +169,7 @@ class OpSubOptimizer(Optimizer):
repl = self.op2.make_node(*node.inputs)
assert len(node.outputs) == len(repl.outputs)
for old, new in zip(node.outputs, repl.outputs):
env.replace(old, new)
env.replace_validate(old, new)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(node, repl, e)
......@@ -182,7 +188,10 @@ class OpRemover(Optimizer):
"""
def add_requirements(self, env):
    """Attach the tools this optimizer relies on: node lookup by op
    (NodeFinder) and validated, revertible replacement (ReplaceValidate).

    @param env: the Env being optimized.
    """
    # extend() raises when an equivalent tool is already attached to the
    # env; that situation is fine for us, so the error is deliberately
    # ignored (narrowed from a bare except so interrupts still propagate).
    try:
        env.extend(toolbox.NodeFinder())
        env.extend(toolbox.ReplaceValidate())
    except Exception:
        pass
def __init__(self, op, failure_callback = None):
"""
......
from random import shuffle
import utils
from functools import partial
import graph
class Bookkeeper:
    """
    Mixin for env extensions that maintain per-node state.

    On attach (resp. deattach) it replays on_import (resp. on_prune) --
    both supplied by the subclass -- for every node already present in
    the env, in topological order, so the subclass sees a consistent
    stream of events regardless of when it was attached.
    """
    def on_attach(self, env):
        # Simulate an import event for each pre-existing node.
        for node in graph.io_toposort(env.inputs, env.outputs):
            self.on_import(env, node)
    def on_deattach(self, env):
        # Simulate a prune event for each node still in the env.
        for node in graph.io_toposort(env.inputs, env.outputs):
            self.on_prune(env, node)
class History:
    """
    Env extension providing checkpoint/revert (undo of input changes).

    Attaching publishes two callables on the env:
      - env.checkpoint(): an opaque token denoting the current state
        (implemented as the current length of the undo log);
      - env.revert(checkpoint): undo every input change recorded since
        that checkpoint was taken.
    """
    def __init__(self):
        # Maps each attached env to its list of undo callbacks.
        # The value is temporarily None while a revert is in progress,
        # which suppresses re-recording of the undo operations themselves.
        self.history = {}
    def on_attach(self, env):
        if hasattr(env, 'checkpoint') or hasattr(env, 'revert'):
            raise Exception("History feature is already present or in conflict with another plugin.")
        self.history[env] = []
        # A checkpoint is simply the length of the undo log at that time.
        env.checkpoint = lambda: len(self.history[env])
        env.revert = partial(self.revert, env)
    def on_deattach(self, env):
        del env.checkpoint
        del env.revert
        del self.history[env]
    def on_change_input(self, env, node, i, r, new_r):
        if self.history[env] is None:
            # We are inside revert(): the changes being made ARE the
            # undo operations, so they must not be recorded.
            return
        h = self.history[env]
        # Record the inverse operation: put the old result back.
        h.append(lambda: env.change_input(node, i, r))
    def revert(self, env, checkpoint):
        """
        Reverts the graph to whatever it was at the provided
        checkpoint (undoes all replacements). A checkpoint at any
        given time can be obtained using self.checkpoint().
        """
        h = self.history[env]
        self.history[env] = None
        while len(h) > checkpoint:
            f = h.pop()
            f()
        self.history[env] = h
class Validator:
    """
    Env extension publishing validate/consistent on the env.

    env.validate() runs every registered 'validate' callback and raises
    if any of them detects a problem; env.consistent() is the boolean
    (non-raising) form of the same check.
    """
    def on_attach(self, env):
        if hasattr(env, 'validate'):
            raise Exception("Validator feature is already present or in conflict with another plugin.")
        env.validate = lambda: env.execute_callbacks('validate')
        def consistent():
            # Boolean wrapper around validate(): True iff no callback raised.
            try:
                env.validate()
                return True
            # Narrowed from a bare except: a bare except would also swallow
            # KeyboardInterrupt/SystemExit and report them as "inconsistent".
            except Exception:
                return False
        env.consistent = consistent
    def on_deattach(self, env):
        del env.validate
        del env.consistent
class ReplaceValidate(History, Validator):
    """
    Env extension combining History and Validator to provide validated,
    atomically-revertible replacements. Attaching publishes:

      - env.replace_validate(r, new_r)
      - env.replace_all_validate([(r, new_r), ...])

    Each call applies the replacement(s), runs env.validate(), and
    reverts everything if any replacement or the validation fails.
    """
    def on_attach(self, env):
        # Check for conflicts BEFORE attaching the parent features, so a
        # failure here does not leave History/Validator half-attached.
        for attr in ('replace_validate', 'replace_all_validate'):
            if hasattr(env, attr):
                raise Exception("ReplaceValidate feature is already present or in conflict with another plugin.")
        History.on_attach(self, env)
        Validator.on_attach(self, env)
        env.replace_validate = partial(self.replace_validate, env)
        env.replace_all_validate = partial(self.replace_all_validate, env)
    def on_deattach(self, env):
        History.on_deattach(self, env)
        Validator.on_deattach(self, env)
        del env.replace_validate
        del env.replace_all_validate
    def replace_validate(self, env, r, new_r):
        # Single replacement: delegate to the batch form.
        self.replace_all_validate(env, [(r, new_r)])
    def replace_all_validate(self, env, replacements):
        chk = env.checkpoint()
        try:
            # The replace loop is inside the try so that a failure partway
            # through also reverts the replacements already applied (the
            # original only guarded the validate() call).
            for r, new_r in replacements:
                env.replace(r, new_r)
            env.validate()
        except:
            # Deliberately bare: revert on ANY failure, then re-raise the
            # original error unchanged for the caller.
            env.revert(chk)
            raise
class NodeFinder(dict, Bookkeeper):
    """
    Env extension that indexes the env's nodes by their op and publishes
    env.get_nodes(op): a generator over the env's current nodes with that
    op. The index is maintained incrementally through on_import/on_prune;
    nodes with unhashable ops are simply not indexed.

    A NodeFinder instance serves exactly one Env at a time.
    """
    def __init__(self):
        self.env = None  # the single Env this instance currently serves
    def on_attach(self, env):
        if self.env is not None:
            raise Exception("A NodeFinder instance can only serve one Env.")
        if hasattr(env, 'get_nodes'):
            raise Exception("NodeFinder is already present or in conflict with another plugin.")
        self.env = env
        env.get_nodes = partial(self.query, env)
        # Replay on_import for the nodes already present in the env.
        Bookkeeper.on_attach(self, env)
    def on_deattach(self, env):
        if self.env is not env:
            raise Exception("This NodeFinder instance was not attached to the provided env.")
        self.env = None
        del env.get_nodes
        # Replay on_prune for the nodes still present in the env.
        Bookkeeper.on_deattach(self, env)
    def on_import(self, env, node):
        try:
            self.setdefault(node.op, []).append(node)
        except TypeError: # node.op is unhashable: not indexed
            return
    def on_prune(self, env, node):
        try:
            nodes = self[node.op]
        except TypeError: # node.op is unhashable: was never indexed
            return
        nodes.remove(node)
        if not nodes:
            del self[node.op]
    def query(self, env, op):
        """Yield the env's current nodes whose op is `op`."""
        try:
            all = self.get(op, [])
        except TypeError:
            # Error-message typo fixed: "in unhashable" -> "is unhashable".
            raise TypeError("%s is unhashable and cannot be queried by the optimizer" % op)
        # Work on a copy: the index may mutate while the caller consumes
        # the generator (e.g. replacements made between yields).
        all = list(all)
        while all:
            next = all.pop()
            # Skip nodes that were pruned since they were indexed.
            if next in env.nodes:
                yield next
    def __eq__(self, other):
        return isinstance(other, NodeFinder) and self.env is other.env
class PrintListener(object):
    """Env listener that logs attach/deattach/import/prune/input-change
    events to stdout. Intended for debugging; attach it to an Env to
    trace every structural change as it happens."""
    def __init__(self, active = True):
        # active: when False, every callback is a silent no-op.
        self.active = active
    def on_attach(self, env):
        if self.active:
            print "-- attaching to: ", env
    def on_deattach(self, env):
        if self.active:
            print "-- deattaching from: ", env
    def on_import(self, env, node):
        # Called when a node enters the env's graph.
        if self.active:
            print "-- importing: %s" % node
    def on_prune(self, env, node):
        # Called when a node leaves the env's graph.
        if self.active:
            print "-- pruning: %s" % node
    def on_change_input(self, env, node, i, r, new_r):
        # Called when node.inputs[i] is rewired from r to new_r.
        if self.active:
            print "-- changing (%s.inputs[%s]) from %s to %s" % (node, i, r, new_r)
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict):
......@@ -158,28 +302,6 @@ class NodeFinder(dict):
class PrintListener(object):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, node):
if self.active:
print "-- importing: %s" % node
def on_prune(self, node):
if self.active:
print "-- pruning: %s" % node
def on_rewire(self, clients, r, new_r):
if self.active:
if r.owner is not None: r = r.owner
if new_r.owner is not None: new_r = new_r.owner
print "-- moving from %s to %s" % (r, new_r)
### UNUSED AND UNTESTED ###
......
......@@ -26,6 +26,8 @@ class object2(object):
if hasattr(self, '__eq__') or hasattr(self, '__cmp__'):
raise TypeError("unhashable object: %s" % self)
return id(self)
def __ne__(self, other):
return not self == other
class scratchpad:
def clear(self):
......
......@@ -71,7 +71,7 @@ class Tensor(Type):
def __init__(self, dtype, broadcastable):
    """Store the dtype (as a string) and the broadcastable pattern.

    The pattern is normalized to a tuple so it is hashable and cannot
    be mutated after construction, regardless of the sequence type the
    caller passed in.
    """
    self.dtype = str(dtype)
    self.broadcastable = tuple(broadcastable)
    self.dtype_specs() # error checking is done there
def filter(self, data, strict = False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论