提交 65e08101 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

env redone, toolbox redone

上级 646f4c01
...@@ -2,13 +2,15 @@ ...@@ -2,13 +2,15 @@
import unittest import unittest
from type import Type from type import Type
import graph
from graph import Result, as_result, Apply from graph import Result, as_result, Apply
from op import Op from op import Op
from opt import PatternOptimizer, OpSubOptimizer from opt import PatternOptimizer, OpSubOptimizer
from ext import * from ext import *
from env import Env, InconsistencyError from env import Env, InconsistencyError
from toolbox import EquivTool #from toolbox import EquivTool
from toolbox import ReplaceValidate
from copy import copy from copy import copy
...@@ -65,8 +67,11 @@ def inputs(): ...@@ -65,8 +67,11 @@ def inputs():
_Env = Env _Env = Env
def Env(inputs, outputs, validate = True): def Env(inputs, outputs, validate = True):
e = _Env(inputs, outputs) e = _Env(inputs, outputs)
e.extend(EquivTool(e)) ##e.extend(EquivTool(e))
e.extend(DestroyHandler(e), validate = validate) e.extend(DestroyHandler())
e.extend(ReplaceValidate())
if validate:
e.validate()
return e return e
...@@ -108,19 +113,19 @@ class _test_all(unittest.TestCase): ...@@ -108,19 +113,19 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e1, e2]) g = Env([x,y,z], [e1, e2])
chk = g.checkpoint() chk = g.checkpoint()
assert g.consistent() assert g.consistent()
g.replace(e1, add_in_place(x, y)) g.replace_validate(e1, add_in_place(x, y))
assert g.consistent() assert g.consistent()
try: try:
g.replace(e2, add_in_place(y, x)) g.replace_validate(e2, add_in_place(y, x))
self.fail() self.fail()
except InconsistencyError: except InconsistencyError:
pass pass
assert g.consistent() assert g.consistent()
g.revert(chk) g.revert(chk)
g.replace(e2, add_in_place(y, x)) g.replace_validate(e2, add_in_place(y, x))
assert g.consistent() assert g.consistent()
try: try:
g.replace(e1, add_in_place(x, y)) g.replace_validate(e1, add_in_place(x, y))
self.fail() self.fail()
except InconsistencyError: except InconsistencyError:
pass pass
...@@ -136,7 +141,7 @@ class _test_all(unittest.TestCase): ...@@ -136,7 +141,7 @@ class _test_all(unittest.TestCase):
assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that! assert str(g) != "[Dot(Dot(AddInPlace(x, y), AddInPlace(y, z)), AddInPlace(z, x))]" # we don't want to see that!
e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x)) e2 = dot(dot(add_in_place(x,y), add_in_place(y,z)), add_in_place(z,x))
try: try:
g2 = Env([x,y,z], [e2]) g2 = Env(*graph.clone([x,y,z], [e2]))
self.fail() self.fail()
except InconsistencyError: except InconsistencyError:
pass pass
...@@ -154,16 +159,18 @@ class _test_all(unittest.TestCase): ...@@ -154,16 +159,18 @@ class _test_all(unittest.TestCase):
e = dot(aip, transpose_view(x)) e = dot(aip, transpose_view(x))
g = Env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(aip, add(x, z)) g.replace_validate(aip, add(x, z))
assert g.consistent() assert g.consistent()
def test_usage_loop_through_views_2(self): def test_usage_loop_through_views_2(self):
x, y, z = inputs() x, y, z = inputs()
e0 = transpose_view(transpose_view(transpose_view(sigmoid(x)))) e0 = transpose_view(transpose_view(sigmoid(x)))
e = dot(add_in_place(x,y), transpose_view(e0)) e = dot(add_in_place(x,y), transpose_view(e0))
g = Env([x,y,z], [e]) g = Env([x,y,z], [e])
assert g.consistent() # because sigmoid can do the copy assert g.consistent() # because sigmoid can do the copy
g.replace(e0, x, False) # print g
# print g.destroy_handler.children
g.replace(e0, x)
assert not g.consistent() # we cut off the path to the sigmoid assert not g.consistent() # we cut off the path to the sigmoid
def test_usage_loop_insert_views(self): def test_usage_loop_insert_views(self):
...@@ -184,10 +191,10 @@ class _test_all(unittest.TestCase): ...@@ -184,10 +191,10 @@ class _test_all(unittest.TestCase):
chk = g.checkpoint() chk = g.checkpoint()
PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g) PatternOptimizer((transpose_view, (transpose_view, 'x')), 'x').optimize(g)
assert str(g) == "[x]" assert str(g) == "[x]"
g.replace(g.equiv(e), add(x,y)) new_e = add(x,y)
print g g.replace_validate(x, new_e)
assert str(g) == "[Add(x, y)]" assert str(g) == "[Add(x, y)]"
g.replace(g.equiv(e), dot(add_in_place(x,y), transpose_view(x)), False) g.replace(new_e, dot(add_in_place(x,y), transpose_view(x)))
assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]" assert str(g) == "[Dot(AddInPlace(x, y), TransposeView(x))]"
assert not g.consistent() assert not g.consistent()
g.revert(chk) g.revert(chk)
...@@ -202,7 +209,7 @@ class _test_all(unittest.TestCase): ...@@ -202,7 +209,7 @@ class _test_all(unittest.TestCase):
e = add_in_place(x, y) e = add_in_place(x, y)
g = Env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(e, add(x, y)) g.replace_validate(e, add(x, y))
assert g.consistent() assert g.consistent()
def test_indestructible_through_views(self): def test_indestructible_through_views(self):
...@@ -212,7 +219,7 @@ class _test_all(unittest.TestCase): ...@@ -212,7 +219,7 @@ class _test_all(unittest.TestCase):
e = add_in_place(tv, y) e = add_in_place(tv, y)
g = Env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
g.replace(tv, sigmoid(x)) g.replace_validate(tv, sigmoid(x))
assert g.consistent() assert g.consistent()
def test_repair_destroy_path(self): def test_repair_destroy_path(self):
...@@ -223,7 +230,7 @@ class _test_all(unittest.TestCase): ...@@ -223,7 +230,7 @@ class _test_all(unittest.TestCase):
e4 = add_in_place(e1, z) e4 = add_in_place(e1, z)
g = Env([x,y,z], [e3, e4], False) g = Env([x,y,z], [e3, e4], False)
assert not g.consistent() assert not g.consistent()
g.replace(e2, transpose_view(x), False) g.replace(e2, transpose_view(x))
assert not g.consistent() assert not g.consistent()
def test_indirect(self): def test_indirect(self):
...@@ -233,9 +240,9 @@ class _test_all(unittest.TestCase): ...@@ -233,9 +240,9 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(x, y) new_e0 = add(x, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0)
assert g.consistent() assert g.consistent()
g.replace(new_e0, add_in_place(x, y), False) g.replace(new_e0, add_in_place(x, y))
assert not g.consistent() assert not g.consistent()
def test_indirect_2(self): def test_indirect_2(self):
...@@ -245,12 +252,12 @@ class _test_all(unittest.TestCase): ...@@ -245,12 +252,12 @@ class _test_all(unittest.TestCase):
g = Env([x,y,z], [e], False) g = Env([x,y,z], [e], False)
assert not g.consistent() assert not g.consistent()
new_e0 = add(e0, y) new_e0 = add(e0, y)
g.replace(e0, new_e0, False) g.replace(e0, new_e0)
assert g.consistent() assert g.consistent()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() #unittest.main()
_test_all('test_usage_loop_through_views').debug()
...@@ -59,19 +59,19 @@ def inputs(): ...@@ -59,19 +59,19 @@ def inputs():
return x, y, z return x, y, z
class _test_EquivTool(unittest.TestCase): # class _test_EquivTool(unittest.TestCase):
def test_straightforward(self): # def test_straightforward(self):
x, y, z = inputs() # x, y, z = inputs()
sx = sigmoid(x) # sx = sigmoid(x)
e = add(sx, sigmoid(y)) # e = add(sx, sigmoid(y))
g = Env([x, y, z], [e]) # g = Env([x, y, z], [e])
g.extend(EquivTool(g)) # g.extend(EquivTool(g))
assert hasattr(g, 'equiv') # assert hasattr(g, 'equiv')
assert g.equiv(sx) is sx # assert g.equiv(sx) is sx
g.replace(sx, dot(x, z)) # g.replace(sx, dot(x, z))
assert g.equiv(sx) is not sx # assert g.equiv(sx) is not sx
assert g.equiv(sx).owner.op is dot # assert g.equiv(sx).owner.op is dot
class _test_NodeFinder(unittest.TestCase): class _test_NodeFinder(unittest.TestCase):
...@@ -81,7 +81,7 @@ class _test_NodeFinder(unittest.TestCase): ...@@ -81,7 +81,7 @@ class _test_NodeFinder(unittest.TestCase):
e0 = dot(y, z) e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0)) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
g.extend(NodeFinder(g)) g.extend(NodeFinder())
assert hasattr(g, 'get_nodes') assert hasattr(g, 'get_nodes')
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)): for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if not len([x for x in g.get_nodes(type)]) == num: if not len([x for x in g.get_nodes(type)]) == num:
...@@ -100,7 +100,7 @@ class _test_NodeFinder(unittest.TestCase): ...@@ -100,7 +100,7 @@ class _test_NodeFinder(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z))) e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), dot(y, z)))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
g.extend(NodeFinder(g)) g.extend(NodeFinder())
gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances gen = g.get_nodes(sigmoid) # I want to get Sigmoid instances
g.replace(e, add(x, y)) # but here I prune them all g.replace(e, add(x, y)) # but here I prune them all
assert len([x for x in gen]) == 0 # the generator should not yield them assert len([x for x in gen]) == 0 # the generator should not yield them
......
差异被折叠。
from features import Listener, Constraint, Orderings, Tool #from features import Listener, Constraint, Orderings, Tool
import utils
from utils import AbstractFunctionError from utils import AbstractFunctionError
from copy import copy from copy import copy
from env import InconsistencyError from env import InconsistencyError
from toolbox import Bookkeeper
class DestroyHandler(Listener, Constraint, Orderings, Tool): from collections import defaultdict
class DestroyHandler(Bookkeeper): #(Listener, Constraint, Orderings, Tool):
""" """
This feature ensures that an env represents a consistent data flow This feature ensures that an env represents a consistent data flow
when some Ops overwrite their inputs and/or provide "views" over when some Ops overwrite their inputs and/or provide "views" over
...@@ -28,13 +36,31 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -28,13 +36,31 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
safely. safely.
""" """
def __init__(self, env): def __init__(self):
self.env = None
def on_attach(self, env):
if self.env is not None:
raise Exception("A DestroyHandler instance can only serve one Env.")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(env, attr):
raise Exception("DestroyHandler feature is already present or in conflict with another plugin.")
def __destroyers(r):
ret = self.destroyers.get(r, {})
ret = ret.keys()
return ret
env.destroyers = __destroyers
env.destroy_handler = self
self.env = env
# For an Op that has a view_map, {output : input it is a view of} # For an Op that has a view_map, {output : input it is a view of}
self.parent = {} self.parent = {}
# Reverse mapping of parent: {input : outputs that are a view of it} # Reverse mapping of parent: {input : outputs that are a view of it}
self.children = {} self.children = defaultdict(set)
# {foundation : {op that destroys it : path }} # {foundation : {op that destroys it : path }}
# where foundation is a result such that (not self.parent[result]) # where foundation is a result such that (not self.parent[result])
...@@ -57,25 +83,37 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -57,25 +83,37 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
# indestructible by the user. # indestructible by the user.
self.illegal = set() self.illegal = set()
self.env = env
self.seen = set() self.seen = set()
# Initialize the children if the inputs and orphans. Bookkeeper.on_attach(self, env)
for input in env.orphans.union(env.inputs):
self.children[input] = set() # # Initialize the children if the inputs and orphans.
# for input in env.inputs: # env.orphans.union(env.inputs):
def publish(self): # self.children[input] = set()
"""
Publishes the following on the env: def on_detach(self, env):
- destroyers(r) -> returns all L{Op}s that destroy the result r del self.parent
- destroy_handler -> self del self.children
""" del self.destroyers
def __destroyers(r): del self.paths
ret = self.destroyers.get(r, {}) del self.dups
ret = ret.keys() del self.cycles
return ret del self.illegal
self.env.destroyers = __destroyers del self.seen
self.env.destroy_handler = self self.env = None
# def publish(self):
# """
# Publishes the following on the env:
# - destroyers(r) -> returns all L{Op}s that destroy the result r
# - destroy_handler -> self
# """
# def __destroyers(r):
# ret = self.destroyers.get(r, {})
# ret = ret.keys()
# return ret
# self.env.destroyers = __destroyers
# self.env.destroy_handler = self
def __path__(self, r): def __path__(self, r):
""" """
...@@ -105,12 +143,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -105,12 +143,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
""" """
children = self.children[r] children = self.children[r]
if not children: if not children:
return set([r]) return [r]
else: else:
rval = set([r]) rval = [r]
for child in children: for child in children:
rval.update(self.__views__(child)) rval += self.__views__(child)
return rval return utils.uniq(rval)
def __users__(self, r): def __users__(self, r):
""" """
...@@ -120,12 +158,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -120,12 +158,12 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
is returned. is returned.
""" """
views = self.__views__(r) views = self.__views__(r)
rval = set() rval = [] # set()
for view in views: for view in views:
for op, i in self.env.clients(view): for node, i in view.clients: #self.env.clients(view):
if op in self.seen: if node != 'output':
rval.update(op.outputs) rval += node.outputs
return rval return utils.uniq(rval)
def __pre__(self, op): def __pre__(self, op):
""" """
...@@ -178,7 +216,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -178,7 +216,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
just_remove is True, we return immediately after removing the just_remove is True, we return immediately after removing the
cycles. cycles.
""" """
users = self.__users__(start) users = set(self.__users__(start))
users.add(start) users.add(start)
for user in users: for user in users:
for cycle in copy(self.cycles): for cycle in copy(self.cycles):
...@@ -208,13 +246,14 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -208,13 +246,14 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs] dmap[node.outputs[oidx]] = [node.inputs[iidx] for iidx in iidxs]
return vmap, dmap return vmap, dmap
def on_import(self, op): def on_import(self, env, op):
""" """
Recomputes the dependencies and search for inconsistencies given Recomputes the dependencies and search for inconsistencies given
that we just added an op to the env. that we just added an op to the env.
""" """
self.seen.add(op) self.seen.add(op)
op.deps['destroy'] = []
view_map, destroy_map = self.get_maps(op) view_map, destroy_map = self.get_maps(op)
for input in op.inputs: for input in op.inputs:
...@@ -251,7 +290,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -251,7 +290,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.__detect_cycles_helper__(output, []) self.__detect_cycles_helper__(output, [])
def on_prune(self, op): def on_prune(self, env, op):
""" """
Recomputes the dependencies and searches for inconsistencies to remove Recomputes the dependencies and searches for inconsistencies to remove
given that we just removed an op to the env. given that we just removed an op to the env.
...@@ -295,6 +334,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -295,6 +334,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
del self.children[output] del self.children[output]
self.seen.remove(op) self.seen.remove(op)
del op.deps['destroy']
def __add_destroyer__(self, path): def __add_destroyer__(self, path):
...@@ -305,10 +345,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -305,10 +345,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation = path[0] foundation = path[0]
target = path[-1] target = path[-1]
op = target.owner node = target.owner
destroyers = self.destroyers.setdefault(foundation, {}) destroyers = self.destroyers.setdefault(foundation, {})
path = destroyers.setdefault(op, path) path = destroyers.setdefault(node, path)
print "add", path
node.deps['destroy'] += [user.owner for user in self.__users__(foundation) if user not in node.outputs]
# for foundation, destroyers in self.destroyers.items():
# for op in destroyers.keys():
# ords.setdefault(op, set()).update([user.owner for user in self.__users__(foundation) if user not in op.outputs])
if len(destroyers) > 1: if len(destroyers) > 1:
self.dups.add(foundation) self.dups.add(foundation)
...@@ -325,10 +372,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -325,10 +372,17 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
foundation = path[0] foundation = path[0]
target = path[-1] target = path[-1]
op = target.owner node = target.owner
print "rm", path
print node.deps['destroy']
for user in self.__users__(foundation):
print " -- ", user
if user not in node.outputs:
node.deps['destroy'].remove(user.owner)
destroyers = self.destroyers[foundation] destroyers = self.destroyers[foundation]
del destroyers[op] del destroyers[node]
if not destroyers: if not destroyers:
if foundation in self.illegal: if foundation in self.illegal:
...@@ -338,14 +392,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -338,14 +392,18 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.dups.remove(foundation) self.dups.remove(foundation)
def on_rewire(self, clients, r_1, r_2): def on_change_input(self, env, node, i, r, new_r):
if node != 'output':
self.on_rewire(env, [(node, i)], r, new_r)
def on_rewire(self, env, clients, r_1, r_2):
""" """
Recomputes the dependencies and searches for inconsistencies to remove Recomputes the dependencies and searches for inconsistencies to remove
given that all the clients are moved from r_1 to r_2, clients being given that all the clients are moved from r_1 to r_2, clients being
a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is a list of (op, i) pairs such that op.inputs[i] used to be r_1 and is
now r_2. now r_2.
""" """
path_1 = self.__path__(r_1) path_1 = self.__path__(r_1)
path_2 = self.__path__(r_2) path_2 = self.__path__(r_2)
...@@ -396,7 +454,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -396,7 +454,7 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
self.children.setdefault(r_2, set()) self.children.setdefault(r_2, set())
self.__detect_cycles__(r_2) self.__detect_cycles__(r_2)
def validate(self): def validate(self, env):
""" """
Raises an L{InconsistencyError} on any of the following conditions: Raises an L{InconsistencyError} on any of the following conditions:
- Some results are destroyed by more than one L{Op} - Some results are destroyed by more than one L{Op}
...@@ -412,9 +470,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -412,9 +470,9 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
else: else:
return True return True
def orderings(self): def orderings(self, env):
""" """
Returns a dict of {op : set(ops that must be computed before it)} according Returns a dict of {node : set(nodes that must be computed before it)} according
to L{DestroyHandler}. to L{DestroyHandler}.
In particular, all the users of a destroyed result have priority over the In particular, all the users of a destroyed result have priority over the
L{Op} that destroys the result. L{Op} that destroys the result.
...@@ -426,6 +484,8 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool): ...@@ -426,6 +484,8 @@ class DestroyHandler(Listener, Constraint, Orderings, Tool):
return ords return ords
class Destroyer: class Destroyer:
""" """
Base class for Ops that destroy one or more of their inputs in an Base class for Ops that destroy one or more of their inputs in an
...@@ -493,3 +553,4 @@ def view_roots(r): ...@@ -493,3 +553,4 @@ def view_roots(r):
return [r] return [r]
else: else:
return [r] return [r]
...@@ -202,7 +202,7 @@ def results_and_orphans(i, o, except_unreachable_input=False): ...@@ -202,7 +202,7 @@ def results_and_orphans(i, o, except_unreachable_input=False):
""" """
results = set() results = set()
i = set(i) i = set(i)
results.update(i) # results.update(i)
incomplete_paths = [] incomplete_paths = []
reached = set() reached = set()
...@@ -287,7 +287,7 @@ def orphans(i, o): ...@@ -287,7 +287,7 @@ def orphans(i, o):
return results_and_orphans(i, o)[1] return results_and_orphans(i, o)[1]
def clone(i, o, copy_inputs = False): def clone(i, o, copy_inputs = True):
""" """
@type i: list @type i: list
@param i: input L{Result}s @param i: input L{Result}s
...@@ -299,8 +299,8 @@ def clone(i, o, copy_inputs = False): ...@@ -299,8 +299,8 @@ def clone(i, o, copy_inputs = False):
Copies the subgraph contained between i and o and returns the Copies the subgraph contained between i and o and returns the
outputs of that copy (corresponding to o). outputs of that copy (corresponding to o).
""" """
equiv = clone_get_equiv(i, o) equiv = clone_get_equiv(i, o, copy_inputs)
return [equiv[output] for output in o] return [equiv[input] for input in i], [equiv[output] for output in o]
def clone_get_equiv(i, o, copy_inputs_and_orphans = False): def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
...@@ -324,7 +324,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False): ...@@ -324,7 +324,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
for input in i: for input in i:
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
cpy = copy(input) cpy = input.clone()
cpy.owner = None cpy.owner = None
cpy.index = None cpy.index = None
d[input] = cpy d[input] = cpy
...@@ -337,7 +337,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False): ...@@ -337,7 +337,7 @@ def clone_get_equiv(i, o, copy_inputs_and_orphans = False):
node = result.owner node = result.owner
if node is None: # result is an orphan if node is None: # result is an orphan
if copy_inputs_and_orphans: if copy_inputs_and_orphans:
cpy = copy(result) cpy = result.clone()
d[result] = cpy d[result] = cpy
else: else:
d[result] = result d[result] = result
......
...@@ -115,7 +115,10 @@ class OpSpecificOptimizer(LocalOptimizer): ...@@ -115,7 +115,10 @@ class OpSpecificOptimizer(LocalOptimizer):
""" """
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env)) try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def candidates(self, env): def candidates(self, env):
""" """
...@@ -135,7 +138,10 @@ class OpSubOptimizer(Optimizer): ...@@ -135,7 +138,10 @@ class OpSubOptimizer(Optimizer):
""" """
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env)) try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def __init__(self, op1, op2, failure_callback = None): def __init__(self, op1, op2, failure_callback = None):
""" """
...@@ -163,7 +169,7 @@ class OpSubOptimizer(Optimizer): ...@@ -163,7 +169,7 @@ class OpSubOptimizer(Optimizer):
repl = self.op2.make_node(*node.inputs) repl = self.op2.make_node(*node.inputs)
assert len(node.outputs) == len(repl.outputs) assert len(node.outputs) == len(repl.outputs)
for old, new in zip(node.outputs, repl.outputs): for old, new in zip(node.outputs, repl.outputs):
env.replace(old, new) env.replace_validate(old, new)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(node, repl, e) self.failure_callback(node, repl, e)
...@@ -182,7 +188,10 @@ class OpRemover(Optimizer): ...@@ -182,7 +188,10 @@ class OpRemover(Optimizer):
""" """
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.NodeFinder(env)) try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def __init__(self, op, failure_callback = None): def __init__(self, op, failure_callback = None):
""" """
......
from random import shuffle from random import shuffle
import utils import utils
from functools import partial
import graph
class EquivTool(dict): class Bookkeeper:
def on_attach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs):
self.on_import(env, node)
def on_deattach(self, env):
for node in graph.io_toposort(env.inputs, env.outputs):
self.on_prune(env, node)
class History:
def __init__(self):
self.history = {}
def on_attach(self, env):
if hasattr(env, 'checkpoint') or hasattr(env, 'revert'):
raise Exception("History feature is already present or in conflict with another plugin.")
self.history[env] = []
env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env)
def on_deattach(self, env):
del env.checkpoint
del env.revert
del self.history[env]
def on_change_input(self, env, node, i, r, new_r):
if self.history[env] is None:
return
h = self.history[env]
h.append(lambda: env.change_input(node, i, r))
def revert(self, env, checkpoint):
"""
Reverts the graph to whatever it was at the provided
checkpoint (undoes all replacements). A checkpoint at any
given time can be obtained using self.checkpoint().
"""
h = self.history[env]
self.history[env] = None
while len(h) > checkpoint:
f = h.pop()
f()
self.history[env] = h
class Validator:
def on_attach(self, env):
if hasattr(env, 'validate'):
raise Exception("Validator feature is already present or in conflict with another plugin.")
env.validate = lambda: env.execute_callbacks('validate')
def consistent():
try:
env.validate()
return True
except:
return False
env.consistent = consistent
def on_deattach(self, env):
del env.validate
del env.consistent
class ReplaceValidate(History, Validator):
def on_attach(self, env):
History.on_attach(self, env)
Validator.on_attach(self, env)
for attr in ('replace_validate', 'replace_all_validate'):
if hasattr(env, attr):
raise Exception("ReplaceValidate feature is already present or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env)
def on_deattach(self, env):
History.on_deattach(self, env)
Validator.on_deattach(self, env)
del env.replace_validate
del env.replace_all_validate
def replace_validate(self, env, r, new_r):
self.replace_all_validate(env, [(r, new_r)])
def replace_all_validate(self, env, replacements):
chk = env.checkpoint()
for r, new_r in replacements:
env.replace(r, new_r)
try:
env.validate()
except:
env.revert(chk)
raise
def __init__(self, env):
self.env = env
def on_rewire(self, clients, r, new_r): class NodeFinder(dict, Bookkeeper):
repl = self(new_r)
if repl is r: def __init__(self):
self.ungroup(r, new_r) self.env = None
elif repl is not new_r:
raise Exception("Improper use of EquivTool!") def on_attach(self, env):
else: if self.env is not None:
self.group(new_r, r) raise Exception("A NodeFinder instance can only serve one Env.")
if hasattr(env, 'get_nodes'):
def publish(self): raise Exception("NodeFinder is already present or in conflict with another plugin.")
self.env.equiv = self
self.env.set_equiv = self.set_equiv
def unpublish(self):
del self.env.equiv
del self.env.set_equiv
def set_equiv(self, d):
self.update(d)
def group(self, main, *keys):
"Marks all the keys as having been replaced by the Result main."
keys = [key for key in keys if key is not main]
if self.has_key(main):
raise Exception("Only group results that have not been grouped before.")
for key in keys:
if self.has_key(key):
raise Exception("Only group results that have not been grouped before.")
if key is main:
continue
self.setdefault(key, main)
def ungroup(self, main, *keys):
"Undoes group(main, *keys)"
keys = [key for key in keys if key is not main]
for key in keys:
if self[key] is main:
del self[key]
def __call__(self, key):
"Returns the currently active replacement for the given key."
next = self.get(key, None)
while next:
key = next
next = self.get(next, None)
return key
class NodeFinder(dict):
def __init__(self, env):
self.env = env self.env = env
env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env)
def on_import(self, node): def on_deattach(self, env):
if self.env is not env:
raise Exception("This NodeFinder instance was not attached to the provided env.")
self.env = None
del env.get_nodes
Bookkeeper.on_deattach(self, env)
def on_import(self, env, node):
try: try:
self.setdefault(node.op, set()).add(node) self.setdefault(node.op, []).append(node)
except TypeError: except TypeError: #node.op is unhashable
pass return
def on_prune(self, node): def on_prune(self, env, node):
try: try:
self[node.op].remove(node) nodes = self[node.op]
except TypeError: except TypeError: #node.op is unhashable
return return
if not self[node.op]: nodes.remove(node)
if not nodes:
del self[node.op] del self[node.op]
def query(self, op): def query(self, env, op):
try: try:
all = self.get(op, []) all = self.get(op, [])
except TypeError: except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op) raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op)
all = [x for x in all] all = list(all)
shuffle(all) # this helps a lot for debugging because the order of the replacements will vary
while all: while all:
next = all.pop() next = all.pop()
if self.env.has_node(next): if next in env.nodes:
yield next yield next
def publish(self):
self.env.get_nodes = self.query
def __eq__(self, other): class PrintListener(object):
return isinstance(other, NodeFinder) and self.env is other.env
def __init__(self, active = True):
self.active = active
def on_attach(self, env):
if self.active:
print "-- attaching to: ", env
def on_deattach(self, env):
if self.active:
print "-- deattaching from: ", env
def on_import(self, env, node):
if self.active:
print "-- importing: %s" % node
def on_prune(self, env, node):
if self.active:
print "-- pruning: %s" % node
def on_change_input(self, env, node, i, r, new_r):
if self.active:
print "-- changing (%s.inputs[%s]) from %s to %s" % (node, i, r, new_r)
# class EquivTool(dict):
# def __init__(self, env):
# self.env = env
# def on_rewire(self, clients, r, new_r):
# repl = self(new_r)
# if repl is r:
# self.ungroup(r, new_r)
# elif repl is not new_r:
# raise Exception("Improper use of EquivTool!")
# else:
# self.group(new_r, r)
# def publish(self):
# self.env.equiv = self
# self.env.set_equiv = self.set_equiv
# def unpublish(self):
# del self.env.equiv
# del self.env.set_equiv
# def set_equiv(self, d):
# self.update(d)
# def group(self, main, *keys):
# "Marks all the keys as having been replaced by the Result main."
# keys = [key for key in keys if key is not main]
# if self.has_key(main):
# raise Exception("Only group results that have not been grouped before.")
# for key in keys:
# if self.has_key(key):
# raise Exception("Only group results that have not been grouped before.")
# if key is main:
# continue
# self.setdefault(key, main)
# def ungroup(self, main, *keys):
# "Undoes group(main, *keys)"
# keys = [key for key in keys if key is not main]
# for key in keys:
# if self[key] is main:
# del self[key]
# def __call__(self, key):
# "Returns the currently active replacement for the given key."
# next = self.get(key, None)
# while next:
# key = next
# next = self.get(next, None)
# return key
# class InstanceFinder(Listener, Tool, dict): # class InstanceFinder(Listener, Tool, dict):
...@@ -158,28 +302,6 @@ class NodeFinder(dict): ...@@ -158,28 +302,6 @@ class NodeFinder(dict):
class PrintListener(object):
def __init__(self, env, active = True):
self.env = env
self.active = active
if active:
print "-- initializing"
def on_import(self, node):
if self.active:
print "-- importing: %s" % node
def on_prune(self, node):
if self.active:
print "-- pruning: %s" % node
def on_rewire(self, clients, r, new_r):
if self.active:
if r.owner is not None: r = r.owner
if new_r.owner is not None: new_r = new_r.owner
print "-- moving from %s to %s" % (r, new_r)
### UNUSED AND UNTESTED ### ### UNUSED AND UNTESTED ###
......
...@@ -26,6 +26,8 @@ class object2(object): ...@@ -26,6 +26,8 @@ class object2(object):
if hasattr(self, '__eq__') or hasattr(self, '__cmp__'): if hasattr(self, '__eq__') or hasattr(self, '__cmp__'):
raise TypeError("unhashable object: %s" % self) raise TypeError("unhashable object: %s" % self)
return id(self) return id(self)
def __ne__(self, other):
return not self == other
class scratchpad: class scratchpad:
def clear(self): def clear(self):
......
...@@ -71,7 +71,7 @@ class Tensor(Type): ...@@ -71,7 +71,7 @@ class Tensor(Type):
def __init__(self, dtype, broadcastable): def __init__(self, dtype, broadcastable):
self.dtype = str(dtype) self.dtype = str(dtype)
self.broadcastable = broadcastable self.broadcastable = tuple(broadcastable)
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
def filter(self, data, strict = False): def filter(self, data, strict = False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论