提交 cca46fc2 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

wrote the NavigatorOptimizer interface to apply local optimizations to nodes…

wrote the NavigatorOptimizer interface to apply local optimizations to nodes following various traversal orders - fixes #79 with TopoOptimizer
上级 b801ba52
from cc import CLinker, OpWiseCLinker, DualLinker from cc import \
from env import InconsistencyError, Env CLinker, OpWiseCLinker, DualLinker
from ext import DestroyHandler, view_roots
from graph import Apply, Result, Constant, Value from env import \
from link import Linker, LocalLinker, PerformLinker, MetaLinker, Profiler InconsistencyError, Env
from op import Op, Macro
from opt import Optimizer, DummyOpt, SeqOptimizer, LocalOptimizer, OpSpecificOptimizer, OpSubOptimizer, OpRemover, PatternOptimizer, MergeOptimizer, MergeOptMerge from ext import \
from toolbox import Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener DestroyHandler, view_roots
from type import Type, Generic, generic
from utils import object2, AbstractFunctionError from graph import \
Apply, Result, Constant, Value
from link import \
Linker, LocalLinker, PerformLinker, MetaLinker, Profiler
from op import \
Op, Macro
from opt import \
Optimizer, SeqOptimizer, \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, LocalOptGroup, LocalOpKeyOptGroup, \
ExpandMacro, OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, \
expand_macros
from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
from type import \
Type, Generic, generic
from utils import \
object2, AbstractFunctionError
...@@ -10,7 +10,8 @@ from toolbox import * ...@@ -10,7 +10,8 @@ from toolbox import *
def as_result(x): def as_result(x):
assert isinstance(x, Result) if not isinstance(x, Result):
raise TypeError("not a Result", x)
return x return x
...@@ -69,6 +70,9 @@ def inputs(): ...@@ -69,6 +70,9 @@ def inputs():
return x, y, z return x, y, z
PatternOptimizer = lambda p1, p2, ign=False: OpKeyOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
TopoPatternOptimizer = lambda p1, p2, ign=True: TopoOptimizer(PatternSub(p1, p2), ignore_newtrees=ign)
class _test_PatternOptimizer(unittest.TestCase): class _test_PatternOptimizer(unittest.TestCase):
def test_replace_output(self): def test_replace_output(self):
...@@ -116,13 +120,14 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -116,13 +120,14 @@ class _test_PatternOptimizer(unittest.TestCase):
assert str(g) == "[Op1(Op1(y, x), z)]" assert str(g) == "[Op1(Op1(y, x), z)]"
def test_no_recurse(self): def test_no_recurse(self):
# if the out pattern is an acceptable in pattern, # if the out pattern is an acceptable in pattern
# and that the ignore_newtrees flag is True,
# it should do the replacement and stop # it should do the replacement and stop
x, y, z = inputs() x, y, z = inputs()
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((op2, '1', '2'), PatternOptimizer((op2, '1', '2'),
(op2, '2', '1')).optimize(g) (op2, '2', '1'), ign=True).optimize(g)
assert str(g) == "[Op1(Op2(y, x), z)]" assert str(g) == "[Op1(Op2(y, x), z)]"
def test_multiple(self): def test_multiple(self):
...@@ -157,20 +162,19 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -157,20 +162,19 @@ class _test_PatternOptimizer(unittest.TestCase):
e = op1(op1(op1(x))) e = op1(op1(op1(x)))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
PatternOptimizer((op1, '1'), PatternOptimizer((op1, '1'),
(op2, (op1, '1'))).optimize(g) (op2, (op1, '1')), ign=True).optimize(g)
assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]" assert str(g) == "[Op2(Op1(Op2(Op1(Op2(Op1(x))))))]"
# def test_ambiguous(self): def test_ambiguous(self):
# # this test is known to fail most of the time # this test should always work with TopoOptimizer and the
# # the reason is that PatternOptimizer doesn't go through # ignore_newtrees flag set to False. Behavior with ignore_newtrees
# # the ops in topological order. The order is random and # = True or with other NavigatorOptimizers may differ.
# # it does not visit ops that it creates. x, y, z = inputs()
# x, y, z = inputs() e = op1(op1(op1(op1(op1(x)))))
# e = op1(op1(op1(op1(op1(x))))) g = Env([x, y, z], [e])
# g = Env([x, y, z], [e]) TopoPatternOptimizer((op1, (op1, '1')),
# PatternOptimizer((op1, (op1, '1')), (op1, '1'), ign=False).optimize(g)
# (op1, '1')).optimize(g) assert str(g) == "[Op1(x)]"
# assert str(g) == "[Op1(x)]"
def test_constant_unification(self): def test_constant_unification(self):
x = Constant(MyType(), 2, name = 'x') x = Constant(MyType(), 2, name = 'x')
...@@ -186,7 +190,7 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -186,7 +190,7 @@ class _test_PatternOptimizer(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = op4(op1(op2(x, y)), op1(op1(x, y))) e = op4(op1(op2(x, y)), op1(op1(x, y)))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
def constraint(env, r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
return r.owner.op == op2 return r.owner.op == op2
PatternOptimizer((op1, {'pattern': '1', PatternOptimizer((op1, {'pattern': '1',
...@@ -206,7 +210,7 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -206,7 +210,7 @@ class _test_PatternOptimizer(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = op2(op1(x, x), op1(x, y)) e = op2(op1(x, x), op1(x, y))
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
def constraint(env, r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
return r.owner.inputs[0] is not r.owner.inputs[1] return r.owner.inputs[0] is not r.owner.inputs[1]
PatternOptimizer({'pattern': (op1, 'x', 'y'), PatternOptimizer({'pattern': (op1, 'x', 'y'),
...@@ -263,6 +267,9 @@ class _test_PatternOptimizer(unittest.TestCase): ...@@ -263,6 +267,9 @@ class _test_PatternOptimizer(unittest.TestCase):
# assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]" # assert str(g) == "[Op1(Op3(y, x), Op2(OpZ(y, z)))]"
OpSubOptimizer = lambda op1, op2: TopoOptimizer(OpSub(op1, op2))
OpSubOptimizer = lambda op1, op2: OpKeyOptimizer(OpSub(op1, op2))
class _test_OpSubOptimizer(unittest.TestCase): class _test_OpSubOptimizer(unittest.TestCase):
def test_straightforward(self): def test_straightforward(self):
...@@ -413,7 +420,20 @@ class _test_MergeOptimizer(unittest.TestCase): ...@@ -413,7 +420,20 @@ class _test_MergeOptimizer(unittest.TestCase):
# assert not getattr(x, 'constant', False) and z.constant # assert not getattr(x, 'constant', False) and z.constant
# MergeOptimizer().optimize(g) # MergeOptimizer().optimize(g)
reenter = Exception("Re-Entered")
class LoopyMacro(Macro):
def __init__(self):
self.counter = 0
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
x, y = node.inputs
if self.counter > 0:
raise reenter
self.counter += 1
return [self(y, x)]
def __str__(self):
return "loopy_macro"
class _test_ExpandMacro(unittest.TestCase): class _test_ExpandMacro(unittest.TestCase):
...@@ -423,26 +443,31 @@ class _test_ExpandMacro(unittest.TestCase): ...@@ -423,26 +443,31 @@ class _test_ExpandMacro(unittest.TestCase):
return Apply(self, [x, y], [MyType()()]) return Apply(self, [x, y], [MyType()()])
def expand(self, node): def expand(self, node):
return [op1(y, x)] return [op1(y, x)]
def __str__(self):
return "macro"
x, y, z = inputs() x, y, z = inputs()
e = Macro1()(x, y) e = Macro1()(x, y)
g = Env([x, y], [e]) g = Env([x, y], [e])
print g
expand_macros.optimize(g) expand_macros.optimize(g)
print g assert str(g) == "[Op1(y, x)]"
def test_loopy(self): def test_loopy_1(self):
class Macro1(Macro):
def make_node(self, x, y):
return Apply(self, [x, y], [MyType()()])
def expand(self, node):
return [Macro1()(y, x)]
x, y, z = inputs() x, y, z = inputs()
e = Macro1()(x, y) e = LoopyMacro()(x, y)
g = Env([x, y], [e])
TopoOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g)
assert str(g) == "[loopy_macro(y, x)]"
def test_loopy_2(self):
x, y, z = inputs()
e = LoopyMacro()(x, y)
g = Env([x, y], [e]) g = Env([x, y], [e])
print g try:
#expand_macros.optimize(g) TopoOptimizer(ExpandMacro(), ignore_newtrees = False).optimize(g)
TopDownOptimizer(ExpandMacro(), ignore_newtrees = True).optimize(g) self.fail("should not arrive here")
print g except Exception, e:
if e is not reenter:
raise
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from copy import copy from copy import copy
import graph import graph
import utils import utils
import toolbox
class InconsistencyError(Exception): class InconsistencyError(Exception):
...@@ -273,14 +274,13 @@ class Env(utils.object2): ...@@ -273,14 +274,13 @@ class Env(utils.object2):
""" """
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
self._features.append(feature)
attach = getattr(feature, 'on_attach', None) attach = getattr(feature, 'on_attach', None)
if attach is not None: if attach is not None:
try: try:
attach(self) attach(self)
except: except toolbox.AlreadyThere:
self._features.pop() return
raise self._features.append(feature)
def remove_feature(self, feature): def remove_feature(self, feature):
""" """
......
...@@ -72,7 +72,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper): ...@@ -72,7 +72,7 @@ class DestroyHandlerHelper(toolbox.Bookkeeper):
raise Exception("A DestroyHandler instance can only serve one Env.") raise Exception("A DestroyHandler instance can only serve one Env.")
for attr in ('destroyers', 'destroy_handler'): for attr in ('destroyers', 'destroy_handler'):
if hasattr(env, attr): if hasattr(env, attr):
raise Exception("DestroyHandler feature is already present or in conflict with another plugin.") raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
def __destroyers(r): def __destroyers(r):
ret = self.destroyers.get(r, {}) ret = self.destroyers.get(r, {})
......
...@@ -10,6 +10,8 @@ import utils ...@@ -10,6 +10,8 @@ import utils
import unify import unify
import toolbox import toolbox
import op import op
from copy import copy
from collections import deque
class Optimizer: class Optimizer:
...@@ -54,9 +56,6 @@ class Optimizer: ...@@ -54,9 +56,6 @@ class Optimizer:
pass pass
DummyOpt = Optimizer()
DummyOpt.__doc__ = "Does nothing."
class SeqOptimizer(Optimizer, list): class SeqOptimizer(Optimizer, list):
""" """
...@@ -84,167 +83,212 @@ class SeqOptimizer(Optimizer, list): ...@@ -84,167 +83,212 @@ class SeqOptimizer(Optimizer, list):
class LocalOptimizer(Optimizer): class _metadict:
# dict that accepts unhashable keys
# uses an associative list
# for internal use only
def __init__(self):
self.d = {}
self.l = []
def __getitem__(self, item):
return self.get(item, None)
def __setitem__(self, item, value):
try:
self.d[item] = value
except:
self.l.append((item, value))
def get(self, item, default):
try:
return self.d[item]
except:
for item2, value in self.l:
try:
if item == item2:
return value
if item.equals(item2):
return value
except:
if item is item2:
return value
else:
return default
def clear(self):
self.d = {}
self.l = []
def __str__(self):
return "(%s, %s)" % (self.d, self.l)
class MergeOptimizer(Optimizer):
""" """
Generic L{Optimizer} class that considers local parts of Merges parts of the graph that are identical, i.e. parts that
the L{Env}. It must be subclassed and should override the take the same inputs and carry out the asme computations so we
following two methods: can avoid doing them more than once. Also merges results that
- candidates(env) -> returns a set of ops that can be are constant.
optimized
- apply_on_node(env, node) -> for each node in candidates,
this function will be called to perform the actual
optimization.
""" """
def candidates(self, env): def add_requirements(self, env):
""" try:
Must return a set of nodes that can be optimized. env.extend(toolbox.ReplaceValidate())
""" except: pass
raise utils.AbstractFunctionError()
def apply_on_node(self, env, node):
"""
For each node in candidates, this function will be called to
perform the actual optimization.
"""
raise utils.AbstractFunctionError()
def apply(self, env): def apply(self, env):
""" cid = _metadict() #result -> result.desc() (for constants)
Calls self.apply_on_op(env, op) for each op in self.candidates(env). inv_cid = _metadict() #desc -> result (for constants)
""" for i, r in enumerate([r for r in env.results if isinstance(r, graph.Constant)]):
for node in self.candidates(env): sig = r.signature()
if node in env.nodes: other_r = inv_cid.get(sig, None)
self.apply_on_node(env, node) if other_r is not None:
env.replace(r, other_r)
else:
cid[r] = sig
inv_cid[sig] = r
# we clear the dicts because the Constants signatures are not necessarily hashable
# and it's more efficient to give them an integer cid like the other Results
cid.clear()
inv_cid.clear()
for i, r in enumerate(r for r in env.results if r.owner is None):
cid[r] = i
inv_cid[i] = r
for node in graph.io_toposort(env.inputs, env.outputs):
node_cid = (node.op, tuple([cid[input] for input in node.inputs]))
dup = inv_cid.get(node_cid, None)
success = False
if dup is not None:
success = True
try:
env.replace_all_validate(zip(node.outputs, dup.outputs))
except InconsistencyError, e:
success = False
if not success:
cid[node] = node_cid
inv_cid[node_cid] = node
for i, output in enumerate(node.outputs):
ref = (i, node_cid)
cid[output] = ref
inv_cid[ref] = output
class OpSpecificOptimizer(LocalOptimizer): def MergeOptMerge(opt):
""" """
Generic L{Optimizer} that applies only to ops of a certain Returns an Optimizer that merges the graph then applies the
type. The type in question is accessed through L{self.op}. optimizer in opt and then merges the graph again in case the
op can also be a class variable of the subclass. opt introduced additional similarities.
""" """
merger = MergeOptimizer()
return SeqOptimizer([merger, opt, merger])
def add_requirements(self, env):
try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def candidates(self, env):
"""
Returns all nodes that have L{self.op} in their op field.
"""
return env.get_nodes(self.op)
########################
### Local Optimizers ###
########################
class LocalOptimizer(utils.object2):
def transform(self, node):
raise utils.AbstractFunctionError()
class LocalOptGroup(LocalOptimizer):
def __init__(self, optimizers):
self.opts = optimizers
self.reentrant = any(getattr(opt, 'reentrant', True), optimizers)
self.retains_inputs = all(getattr(opt, 'retains_inputs', False), optimizers)
def transform(self, node):
for opt in self.opts:
repl = opt.transform(node)
if repl is not False:
return repl
class LocalOpKeyOptGroup(LocalOptGroup):
def __init__(self, optimizers):
if any(not hasattr(opt, 'op_key'), optimizers):
raise TypeError("All LocalOptimizers passed here must have an op_key method.")
CompositeLocalOptimizer.__init__(self, optimizers)
def op_key(self):
return [opt.op_key() for opt in self.opts]
class ExpandMacro(LocalOptimizer):
def transform(self, node):
if not isinstance(node.op, op.Macro):
return False
return node.op.expand(node)
class OpSubOptimizer(Optimizer): class OpSub(LocalOptimizer):
""" """
Replaces all applications of a certain op by the application of Replaces the application of a certain op by the application of
another op that take the same inputs as what they are replacing. another op that take the same inputs as what they are replacing.
e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x)) e.g. OpSub(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
OpSubOptimizer requires the following features:
- NodeFinder
- ReplaceValidate
""" """
def add_requirements(self, env): reentrant = False # an OpSub does not apply to the nodes it produces
""" retains_inputs = True # all the inputs of the original node are transferred to the outputs
Requires the following features:
- NodeFinder
- ReplaceValidate
"""
try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def __init__(self, op1, op2, failure_callback = None): def __init__(self, op1, op2, transfer_tags = True):
""" """
op1.make_node and op2.make_node must take the same number of op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs. inputs and have the same number of outputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to do a replacement in the graph. The
arguments to the callback are: (node, replacement, exception)
""" """
self.op1 = op1 self.op1 = op1
self.op2 = op2 self.op2 = op2
self.failure_callback = failure_callback self.transfer_tags = transfer_tags
def apply(self, env): def op_key(self):
""" return self.op1
Replaces all applications of self.op1 by applications of self.op2
with the same inputs. def transform(self, node):
""" if node.op != self.op1:
candidates = env.get_nodes(self.op1) return False
repl = self.op2.make_node(*node.inputs)
for node in candidates: if self.transfer_tags:
try: repl.tag = copy(node.tag)
repl = self.op2.make_node(*node.inputs) for output, new_output in zip(node.outputs, repl.outputs):
assert len(node.outputs) == len(repl.outputs) new_output.tag = copy(output.tag)
for old, new in zip(node.outputs, repl.outputs): return repl.outputs
env.replace_validate(old, new)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(node, repl, e)
def str(self): def str(self):
return "%s -> %s" % (self.op1, self.op2) return "%s -> %s" % (self.op1, self.op2)
class OpRemove(LocalOptimizer):
class OpRemover(Optimizer):
""" """
@todo untested
Removes all applications of an op by transferring each of its Removes all applications of an op by transferring each of its
outputs to the corresponding input. outputs to the corresponding input.
""" """
def add_requirements(self, env): reentrant = False # no nodes are added at all
try:
env.extend(toolbox.NodeFinder())
env.extend(toolbox.ReplaceValidate())
except: pass
def __init__(self, op, failure_callback = None): def __init__(self, op):
""" """
Applications of the op must have as many inputs as outputs. op1.make_node and op2.make_node must take the same number of
inputs and have the same number of outputs.
If failure_callback is not None, it will be called whenever
the Optimizer fails to remove an operation in the graph. The
arguments to the callback are: (node, exception)
""" """
self.op = op self.op = op
self.failure_callback = failure_callback
def apply(self, env): def op_key(self):
""" return self.op
Removes all applications of self.op.
"""
candidates = env.get_nodes(self.op)
for node in candidates:
try:
assert len(node.inputs) == len(node.outputs)
for input, output in zip(node.inputs, node.outputs):
env.replace(output, input)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(node, e)
pass
def str(self): def transform(self, node):
return "f(%s(x)) -> f(x)" % self.op if node.op != self.op:
return False
return node.inputs
def str(self):
return "%s(x) -> x" % (self.op)
class PatternOptimizer(OpSpecificOptimizer): class PatternSub(LocalOptimizer):
""" """
@todo update @todo update
...@@ -289,14 +333,10 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -289,14 +333,10 @@ class PatternOptimizer(OpSpecificOptimizer):
(scrabble, 'x')) (scrabble, 'x'))
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None): def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False):
""" """
Creates a PatternOptimizer that replaces occurrences of Creates a PatternOptimizer that replaces occurrences of
in_pattern by occurrences of out_pattern. in_pattern by occurrences of out_pattern.
If failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (result_to_replace, replacement, exception).
If allow_multiple_clients is False, he pattern matching will If allow_multiple_clients is False, he pattern matching will
fail if one of the subpatterns has more than one client. fail if one of the subpatterns has more than one client.
...@@ -310,19 +350,23 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -310,19 +350,23 @@ class PatternOptimizer(OpSpecificOptimizer):
else: else:
raise TypeError("The pattern to search for must start with a specific Op instance.") raise TypeError("The pattern to search for must start with a specific Op instance.")
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n" self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.failure_callback = failure_callback
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
def apply_on_node(self, env, node): def op_key(self):
return self.op
def transform(self, node):
""" """
Checks if the graph from node corresponds to in_pattern. If it does, Checks if the graph from node corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement. constructs out_pattern and performs the replacement.
""" """
if node.op != self.op:
return False
def match(pattern, expr, u, first = False): def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1): if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and len(expr.clients) > 1):
return False return False
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return False return False
...@@ -336,8 +380,10 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -336,8 +380,10 @@ class PatternOptimizer(OpSpecificOptimizer):
constraint = pattern['constraint'] constraint = pattern['constraint']
except KeyError: except KeyError:
raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern) raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
if constraint(env, expr): if constraint(expr):
return match(real_pattern, expr, u, False) return match(real_pattern, expr, u, False)
else:
return False
elif isinstance(pattern, str): elif isinstance(pattern, str):
v = unify.Var(pattern) v = unify.Var(pattern)
if u[v] is not v and u[v] is not expr: if u[v] is not v and u[v] is not expr:
...@@ -361,16 +407,11 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -361,16 +407,11 @@ class PatternOptimizer(OpSpecificOptimizer):
u = match(self.in_pattern, node.out, unify.Unification(), True) u = match(self.in_pattern, node.out, unify.Unification(), True)
if u: if u:
try: p = self.out_pattern
# note: only replaces the default 'out' port if it exists new = build(p, u)
p = self.out_pattern return [new]
new = 'unassigned' # this is for the callback if build fails else:
new = build(p, u) return False
env.replace(node.out, new)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(node.out, new, e)
pass
def __str__(self): def __str__(self):
def pattern_to_str(pattern): def pattern_to_str(pattern):
...@@ -384,166 +425,561 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -384,166 +425,561 @@ class PatternOptimizer(OpSpecificOptimizer):
##################
### Navigators ###
##################
class _metadict: # Use the following classes to apply LocalOptimizers
# dict that accepts unhashable keys
# uses an associative list
# for internal use only class NavigatorOptimizer(Optimizer):
def __init__(self):
self.d = {} def __init__(self, local_opt, ignore_newtrees = 'auto', failure_callback = None):
self.l = [] self.local_opt = local_opt
def __getitem__(self, item): if ignore_newtrees == 'auto':
return self.get(item, None) self.ignore_newtrees = not getattr(local_opt, 'reentrant', True)
def __setitem__(self, item, value): else:
try: self.ignore_newtrees = ignore_newtrees
self.d[item] = value self.failure_callback = failure_callback
except:
self.l.append((item, value)) def attach_updater(self, env, importer, pruner):
def get(self, item, default): if self.ignore_newtrees:
importer = None
if importer is None and pruner is None:
return None
class Updater:
if importer is not None:
def on_import(self, env, node):
importer(node)
if pruner is not None:
def on_prune(self, env, node):
pruner(node)
u = Updater()
env.extend(u)
return u
def detach_updater(self, env, u):
if u is not None:
env.remove_feature(u)
def process_node(self, env, node):
replacements = self.local_opt.transform(node)
if replacements is False:
return
repl_pairs = zip(node.outputs, replacements)
try: try:
return self.d[item] env.replace_all_validate(repl_pairs)
except: except Exception, e:
for item2, value in self.l: if self.failure_callback is not None:
try: self.failure_callback(e, self, repl_pairs)
if item == item2:
return value
if item.equals(item2):
return value
except:
if item is item2:
return value
else: else:
return default raise
def clear(self):
self.d = {}
self.l = []
def __str__(self):
return "(%s, %s)" % (self.d, self.l)
def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate())
class MergeOptimizer(Optimizer):
"""
Merges parts of the graph that are identical, i.e. parts that
take the same inputs and carry out the asme computations so we
can avoid doing them more than once. Also merges results that
are constant.
"""
def add_requirements(self, env):
class TopoOptimizer(NavigatorOptimizer):
def __init__(self, local_opt, order = 'out_to_in', ignore_newtrees = False, failure_callback = None):
if order not in ['out_to_in', 'in_to_out']:
raise ValueError("order must be 'out_to_in' or 'in_to_out'")
self.order = order
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
def apply(self, env):
q = deque(graph.io_toposort(env.inputs, env.outputs))
def importer(node):
q.append(node)
def pruner(node):
if node is not current_node:
q.remove(node)
u = self.attach_updater(env, importer, pruner)
try: try:
env.extend(toolbox.ReplaceValidate()) while q:
except: pass if self.order == 'out_to_in':
node = q.pop()
else:
node = q.popleft()
current_node = node
self.process_node(env, node)
except:
self.detach_updater(env, u)
raise
class OpKeyOptimizer(NavigatorOptimizer):
def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None):
if not hasattr(local_opt, 'op_key'):
raise TypeError("LocalOptimizer for OpKeyOptimizer must have an 'op_key' method.")
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
def apply(self, env): def apply(self, env):
cid = _metadict() #result -> result.desc() (for constants) op = self.local_opt.op_key()
inv_cid = _metadict() #desc -> result (for constants) if isinstance(op, (list, tuple)):
for i, r in enumerate([r for r in env.results if isinstance(r, graph.Constant)]): q = reduce(list.__iadd__, map(env.get_nodes, op))
sig = r.signature() else:
other_r = inv_cid.get(sig, None) q = list(env.get_nodes(op))
if other_r is not None: def importer(node):
env.replace(r, other_r) if node.op == op: q.append(node)
else: def pruner(node):
cid[r] = sig if node is not current_node and node.op == op:
inv_cid[sig] = r q.remove(node)
# we clear the dicts because the Constants signatures are not necessarily hashable u = self.attach_updater(env, importer, pruner)
# and it's more efficient to give them an integer cid like the other Results try:
cid.clear() while q:
inv_cid.clear() node = q.pop()
for i, r in enumerate(r for r in env.results if r.owner is None): current_node = node
cid[r] = i self.process_node(env, node)
inv_cid[i] = r except:
self.detach_updater(env, u)
raise
for node in graph.io_toposort(env.inputs, env.outputs): def add_requirements(self, env):
node_cid = (node.op, tuple([cid[input] for input in node.inputs])) """
dup = inv_cid.get(node_cid, None) Requires the following features:
success = False - NodeFinder
if dup is not None: - ReplaceValidate
success = True """
try: NavigatorOptimizer.add_requirements(self, env)
env.replace_all_validate(zip(node.outputs, dup.outputs)) env.extend(toolbox.NodeFinder())
except InconsistencyError, e:
success = False
if not success:
cid[node] = node_cid
inv_cid[node_cid] = node
for i, output in enumerate(node.outputs):
ref = (i, node_cid)
cid[output] = ref
inv_cid[ref] = output
def MergeOptMerge(opt): ##############################
""" ### Pre-defined optimizers ###
Returns an Optimizer that merges the graph then applies the ##############################
optimizer in opt and then merges the graph again in case the
opt introduced additional similarities.
"""
merger = MergeOptimizer()
return SeqOptimizer([merger, opt, merger])
expand_macros = TopoOptimizer(ExpandMacro(), ignore_newtrees = False)
class LocalOptimizer:
def applies(self, node):
raise utils.AbstractFunctionError()
def transform(self, node):
raise utils.AbstractFunctionError()
class ExpandMacro:
def applies(self, node):
return isinstance(node.op, op.Macro)
def transform(self, node):
return node.op.expand(node)
from collections import deque
class TopDownOptimizer(Optimizer):
    """
    Applies a local optimization to every node of the env.

    Nodes are seeded in io_toposort order but pushed/popped from the
    same end of a deque, so they are visited outputs-first; nodes
    imported during optimization are visited next unless
    ignore_newtrees is True.
    """

    def __init__(self, local_opt, ignore_newtrees = False):
        """
        local_opt: an object with applies(node) and transform(node)
            (see LocalOptimizer).
        ignore_newtrees: if True, nodes imported into the env while
            this optimizer is running are not scheduled.
        """
        self.local_opt = local_opt
        self.ignore_newtrees = ignore_newtrees

    def apply(self, env):
        """Visit env's nodes and replace each one local_opt applies to."""
        ignore_newtrees = self.ignore_newtrees
        q = deque()
        # Node currently being transformed.  Its own pruning (a direct
        # consequence of replacing its outputs) must not touch the queue.
        # Initialized before attaching the Updater: on_prune may fire
        # before the first loop iteration, and without this the closure
        # would raise NameError on an unbound free variable.
        current_node = None
        class Updater:
            def on_attach(self, env):
                for node in graph.io_toposort(env.inputs, env.outputs):
                    q.appendleft(node)
            # Conditional def: on_import only exists when newly imported
            # trees should be scheduled as well.
            if not ignore_newtrees:
                def on_import(self, env, node):
                    q.appendleft(node)
            def on_prune(self, env, node):
                if node is not current_node:
                    # A pruned node may already have been processed and
                    # popped from the queue; deque.remove raises
                    # ValueError in that case, which is harmless here.
                    try:
                        q.remove(node)
                    except ValueError:
                        pass
        u = Updater()
        env.extend(u)
        while q:
            node = q.popleft()
            current_node = node
            if not self.local_opt.applies(node):
                continue
            replacements = self.local_opt.transform(node)
            # Pair each old output with its replacement and validate.
            for output, replacement in zip(node.outputs, replacements):
                env.replace_validate(output, replacement)
        env.remove_feature(u)

    def add_requirements(self, env):
        """Install the ReplaceValidate feature this optimizer relies on."""
        try:
            env.extend(toolbox.ReplaceValidate())
        except Exception:
            # Best effort: toolbox raises when the feature (or a
            # conflicting one) is already present.  Narrowed from a bare
            # except so KeyboardInterrupt/SystemExit still propagate.
            pass
# Default macro-expansion pass: expands every Macro node, also visiting
# trees introduced by the expansion (ignore_newtrees defaults to False).
expand_macros = TopDownOptimizer(ExpandMacro())
# class TopoOptimizer(Optimizer):
# def __init__(self, local_opt, order = 'out_to_in', ignore_newtrees = False, failure_callback = None):
# self.local_opt = local_opt
# if order not in ['out_to_in', 'in_to_out']:
# raise ValueError("order must be 'out_to_in' or 'in_to_out'")
# self.order = order
# self.ignore_newtrees = ignore_newtrees
# self.failure_callback = failure_callback
# def apply(self, env):
# ignore_newtrees = self.ignore_newtrees
# q = deque()
# class Updater:
# def on_attach(self, env):
# for node in graph.io_toposort(env.inputs, env.outputs):
# q.append(node)
# if not ignore_newtrees:
# def on_import(self, env, node):
# q.append(node)
# def on_prune(self, env, node):
# if node is not current_node:
# q.remove(node)
# u = Updater()
# env.extend(u)
# while q:
# if self.order == 'out_to_in':
# node = q.pop()
# else:
# node = q.popleft()
# current_node = node
# if not self.local_opt.applies(node):
# continue
# replacements = self.local_opt.transform(node)
# repl_pairs = zip(node.outputs, replacements)
# try:
# env.replace_all_validate(repl_pairs)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(e, self, repl_pairs)
# else:
# raise
# env.remove_feature(u)
# def add_requirements(self, env):
# try:
# env.extend(toolbox.ReplaceValidate())
# except: pass
# class OpKeyOptimizer(Optimizer):
# def __init__(self, local_opt, ignore_newtrees = False, failure_callback = None):
# self.local_opt = local_opt
# if not hasattr(local_opt, 'op_key'):
# raise TypeError("LocalOptimizer for OpKeyOptimizer must have an 'op_key' method.")
# self.ignore_newtrees = ignore_newtrees
# self.failure_callback = failure_callback
# def apply(self, env):
# ignore_newtrees = self.ignore_newtrees
# op = self.local_opt.op_key()
# q = []
# class Updater:
# def on_attach(self, env):
# for node in graph.io_toposort(env.inputs, env.outputs):
# if node.op == op: q.append(node)
# if not ignore_newtrees:
# def on_import(self, env, node):
# if node.op == op: q.append(node)
# def on_prune(self, env, node):
# if node is not current_node:
# q.remove(node)
# u = Updater()
# env.extend(u)
# q = list(env.get_nodes(op))
# while q:
# node = q.pop()
# current_node = node
# if not self.local_opt.applies(node):
# continue
# replacements = self.local_opt.transform(node)
# repl_pairs = zip(node.outputs, replacements)
# try:
# env.replace_all_validate(repl_pairs)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(e, self, repl_pairs)
# else:
# raise
# env.remove_feature(u)
# def add_requirements(self):
# """
# Requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# class OpSpecificOptimizer(LocalOptimizer):
# """
# Generic L{Optimizer} that applies only to ops of a certain
# type. The type in question is accessed through L{self.op}.
# op can also be a class variable of the subclass.
# """
# def add_requirements(self, env):
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def candidates(self, env):
# """
# Returns all nodes that have L{self.op} in their op field.
# """
# return env.get_nodes(self.op)
# class OpSubOptimizer(Optimizer):
# """
# Replaces all applications of a certain op by the application of
# another op that take the same inputs as what they are replacing.
# e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
# OpSubOptimizer requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# def add_requirements(self, env):
# """
# Requires the following features:
# - NodeFinder
# - ReplaceValidate
# """
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def __init__(self, op1, op2, failure_callback = None):
# """
# op1.make_node and op2.make_node must take the same number of
# inputs and have the same number of outputs.
# If failure_callback is not None, it will be called whenever
# the Optimizer fails to do a replacement in the graph. The
# arguments to the callback are: (node, replacement, exception)
# """
# self.op1 = op1
# self.op2 = op2
# self.failure_callback = failure_callback
# def apply(self, env):
# """
# Replaces all applications of self.op1 by applications of self.op2
# with the same inputs.
# """
# candidates = env.get_nodes(self.op1)
# for node in candidates:
# try:
# repl = self.op2.make_node(*node.inputs)
# assert len(node.outputs) == len(repl.outputs)
# for old, new in zip(node.outputs, repl.outputs):
# env.replace_validate(old, new)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node, repl, e)
# def str(self):
# return "%s -> %s" % (self.op1, self.op2)
# class OpRemover(Optimizer):
# """
# @todo untested
# Removes all applications of an op by transferring each of its
# outputs to the corresponding input.
# """
# def add_requirements(self, env):
# try:
# env.extend(toolbox.NodeFinder())
# env.extend(toolbox.ReplaceValidate())
# except: pass
# def __init__(self, op, failure_callback = None):
# """
# Applications of the op must have as many inputs as outputs.
# If failure_callback is not None, it will be called whenever
# the Optimizer fails to remove an operation in the graph. The
# arguments to the callback are: (node, exception)
# """
# self.op = op
# self.failure_callback = failure_callback
# def apply(self, env):
# """
# Removes all applications of self.op.
# """
# candidates = env.get_nodes(self.op)
# for node in candidates:
# try:
# assert len(node.inputs) == len(node.outputs)
# for input, output in zip(node.inputs, node.outputs):
# env.replace(output, input)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node, e)
# pass
# def str(self):
# return "f(%s(x)) -> f(x)" % self.op
# class PatternOptimizer(OpSpecificOptimizer):
# """
# @todo update
# Replaces all occurrences of the input pattern by the output pattern:
# input_pattern ::= (op, <sub_pattern1>, <sub_pattern2>, ...)
# input_pattern ::= dict(pattern = <input_pattern>,
# constraint = <constraint>)
# sub_pattern ::= input_pattern
# sub_pattern ::= string
# sub_pattern ::= a Constant instance
# constraint ::= lambda env, expr: additional matching condition
# output_pattern ::= (op, <output_pattern1>, <output_pattern2>, ...)
# output_pattern ::= string
# Each string in the input pattern is a variable that will be set to
# whatever expression is found in its place. If the same string is
# used more than once, the same expression must be found in those
# places. If a string used in the input pattern is used in the
# output pattern, the matching expression will be inserted in its
# place. The input pattern cannot just be a string but the output
# pattern can.
# If you put a constant result in the input pattern, there will be a
# match iff a constant result with the same value and the same type
# is found in its place.
# You can add a constraint to the match by using the dict(...) form
# described above with a 'constraint' key. The constraint must be a
# function that takes the env and the current Result that we are
# trying to match and returns True or False according to an
# arbitrary criterion.
# Examples:
# PatternOptimizer((add, 'x', 'y'), (add, 'y', 'x'))
# PatternOptimizer((multiply, 'x', 'x'), (square, 'x'))
# PatternOptimizer((subtract, (add, 'x', 'y'), 'y'), 'x')
# PatternOptimizer((power, 'x', Constant(double, 2.0)), (square, 'x'))
# PatternOptimizer((boggle, {'pattern': 'x',
# 'constraint': lambda env, expr: expr.type == scrabble}),
# (scrabble, 'x'))
# """
# def __init__(self, in_pattern, out_pattern, allow_multiple_clients = False, failure_callback = None):
# """
# Creates a PatternOptimizer that replaces occurrences of
# in_pattern by occurrences of out_pattern.
# If failure_callback is not None, if there is a match but a
# replacement fails to occur, the callback will be called with
# arguments (result_to_replace, replacement, exception).
# If allow_multiple_clients is False, he pattern matching will
# fail if one of the subpatterns has more than one client.
# """
# self.in_pattern = in_pattern
# self.out_pattern = out_pattern
# if isinstance(in_pattern, (list, tuple)):
# self.op = self.in_pattern[0]
# elif isinstance(in_pattern, dict):
# self.op = self.in_pattern['pattern'][0]
# else:
# raise TypeError("The pattern to search for must start with a specific Op instance.")
# self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
# self.failure_callback = failure_callback
# self.allow_multiple_clients = allow_multiple_clients
# def apply_on_node(self, env, node):
# """
# Checks if the graph from node corresponds to in_pattern. If it does,
# constructs out_pattern and performs the replacement.
# """
# def match(pattern, expr, u, first = False):
# if isinstance(pattern, (list, tuple)):
# if expr.owner is None:
# return False
# if not (expr.owner.op == pattern[0]) or (not self.allow_multiple_clients and not first and env.nclients(expr) > 1):
# return False
# if len(pattern) - 1 != len(expr.owner.inputs):
# return False
# for p, v in zip(pattern[1:], expr.owner.inputs):
# u = match(p, v, u)
# if not u:
# return False
# elif isinstance(pattern, dict):
# try:
# real_pattern = pattern['pattern']
# constraint = pattern['constraint']
# except KeyError:
# raise KeyError("Malformed pattern: %s (expected keys pattern and constraint)" % pattern)
# if constraint(env, expr):
# return match(real_pattern, expr, u, False)
# elif isinstance(pattern, str):
# v = unify.Var(pattern)
# if u[v] is not v and u[v] is not expr:
# return False
# else:
# u = u.merge(expr, v)
# elif isinstance(pattern, graph.Constant) and isinstance(expr, graph.Constant) and pattern.equals(expr):
# return u
# else:
# return False
# return u
# def build(pattern, u):
# if isinstance(pattern, (list, tuple)):
# args = [build(p, u) for p in pattern[1:]]
# return pattern[0](*args)
# elif isinstance(pattern, str):
# return u[unify.Var(pattern)]
# else:
# return pattern
# u = match(self.in_pattern, node.out, unify.Unification(), True)
# if u:
# try:
# # note: only replaces the default 'out' port if it exists
# p = self.out_pattern
# new = 'unassigned' # this is for the callback if build fails
# new = build(p, u)
# env.replace(node.out, new)
# except Exception, e:
# if self.failure_callback is not None:
# self.failure_callback(node.out, new, e)
# pass
# def __str__(self):
# def pattern_to_str(pattern):
# if isinstance(pattern, (list, tuple)):
# return "%s(%s)" % (str(pattern[0]), ", ".join([pattern_to_str(p) for p in pattern[1:]]))
# elif isinstance(pattern, dict):
# return "%s subject to %s" % (pattern_to_str(pattern['pattern']), str(pattern['constraint']))
# else:
# return str(pattern)
# return "%s -> %s" % (pattern_to_str(self.in_pattern), pattern_to_str(self.out_pattern))
# class LocalOptimizer(Optimizer):
# """
# Generic L{Optimizer} class that considers local parts of
# the L{Env}. It must be subclassed and should override the
# following two methods:
# - candidates(env) -> returns a set of ops that can be
# optimized
# - apply_on_node(env, node) -> for each node in candidates,
# this function will be called to perform the actual
# optimization.
# """
# def candidates(self, env):
# """
# Must return a set of nodes that can be optimized.
# """
# raise utils.AbstractFunctionError()
# def apply_on_node(self, env, node):
# """
# For each node in candidates, this function will be called to
# perform the actual optimization.
# """
# raise utils.AbstractFunctionError()
# def apply(self, env):
# """
# Calls self.apply_on_op(env, op) for each op in self.candidates(env).
# """
# for node in self.candidates(env):
# if node in env.nodes:
# self.apply_on_node(env, node)
...@@ -3,6 +3,10 @@ from functools import partial ...@@ -3,6 +3,10 @@ from functools import partial
import graph import graph
class AlreadyThere(Exception):
pass
class Bookkeeper: class Bookkeeper:
def on_attach(self, env): def on_attach(self, env):
...@@ -21,7 +25,7 @@ class History: ...@@ -21,7 +25,7 @@ class History:
def on_attach(self, env): def on_attach(self, env):
if hasattr(env, 'checkpoint') or hasattr(env, 'revert'): if hasattr(env, 'checkpoint') or hasattr(env, 'revert'):
raise Exception("History feature is already present or in conflict with another plugin.") raise AlreadyThere("History feature is already present or in conflict with another plugin.")
self.history[env] = [] self.history[env] = []
env.checkpoint = lambda: len(self.history[env]) env.checkpoint = lambda: len(self.history[env])
env.revert = partial(self.revert, env) env.revert = partial(self.revert, env)
...@@ -55,7 +59,7 @@ class Validator: ...@@ -55,7 +59,7 @@ class Validator:
def on_attach(self, env): def on_attach(self, env):
if hasattr(env, 'validate'): if hasattr(env, 'validate'):
raise Exception("Validator feature is already present or in conflict with another plugin.") raise AlreadyThere("Validator feature is already present or in conflict with another plugin.")
env.validate = lambda: env.execute_callbacks('validate') env.validate = lambda: env.execute_callbacks('validate')
def consistent(): def consistent():
try: try:
...@@ -77,7 +81,7 @@ class ReplaceValidate(History, Validator): ...@@ -77,7 +81,7 @@ class ReplaceValidate(History, Validator):
Validator.on_attach(self, env) Validator.on_attach(self, env)
for attr in ('replace_validate', 'replace_all_validate'): for attr in ('replace_validate', 'replace_all_validate'):
if hasattr(env, attr): if hasattr(env, attr):
raise Exception("ReplaceValidate feature is already present or in conflict with another plugin.") raise AlreadyThere("ReplaceValidate feature is already present or in conflict with another plugin.")
env.replace_validate = partial(self.replace_validate, env) env.replace_validate = partial(self.replace_validate, env)
env.replace_all_validate = partial(self.replace_all_validate, env) env.replace_all_validate = partial(self.replace_all_validate, env)
...@@ -110,7 +114,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -110,7 +114,7 @@ class NodeFinder(dict, Bookkeeper):
if self.env is not None: if self.env is not None:
raise Exception("A NodeFinder instance can only serve one Env.") raise Exception("A NodeFinder instance can only serve one Env.")
if hasattr(env, 'get_nodes'): if hasattr(env, 'get_nodes'):
raise Exception("NodeFinder is already present or in conflict with another plugin.") raise AlreadyThere("NodeFinder is already present or in conflict with another plugin.")
self.env = env self.env = env
env.get_nodes = partial(self.query, env) env.get_nodes = partial(self.query, env)
Bookkeeper.on_attach(self, env) Bookkeeper.on_attach(self, env)
...@@ -143,10 +147,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -143,10 +147,7 @@ class NodeFinder(dict, Bookkeeper):
except TypeError: except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op) raise TypeError("%s in unhashable and cannot be queried by the optimizer" % op)
all = list(all) all = list(all)
while all: return all
next = all.pop()
if next in env.nodes:
yield next
class PrintListener(object): class PrintListener(object):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论