提交 2ced2dcc authored 作者: Olivier Breuleux's avatar Olivier Breuleux

EquilibriumDB

上级 b4ec4265
...@@ -22,8 +22,13 @@ from opt import \ ...@@ -22,8 +22,13 @@ from opt import \
MergeOptimizer, MergeOptMerge, \ MergeOptimizer, MergeOptMerge, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \ LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \ OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, \ NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \
PureThenInplaceOptimizer keep_going, \
InplaceOptimizer, PureThenInplaceOptimizer
from optdb import \
DB, Query, \
EquilibriumDB, SequenceDB
from toolbox import \ from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
......
...@@ -11,11 +11,12 @@ import unify ...@@ -11,11 +11,12 @@ import unify
import toolbox import toolbox
import op import op
from copy import copy from copy import copy
from collections import deque from collections import deque, defaultdict
import destroyhandler as dh import destroyhandler as dh
import sys
class Optimizer: class Optimizer(object):
"""WRITEME """WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it. An L{Optimizer} can be applied to an L{Env} to transform it.
It can represent an optimization or in general any kind It can represent an optimization or in general any kind
...@@ -76,18 +77,35 @@ class SeqOptimizer(Optimizer, list): ...@@ -76,18 +77,35 @@ class SeqOptimizer(Optimizer, list):
sequentially. sequentially.
""" """
def __init__(self, *opts): def __init__(self, *opts, **kw):
"""WRITEME""" """WRITEME"""
if len(opts) == 1 and isinstance(opts[0], (list, tuple)): if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0] opts = opts[0]
self[:] = opts self[:] = opts
self.failure_callback = kw.pop('failure_callback', None)
def apply(self, env): def apply(self, env):
"""WRITEME """WRITEME
Applies each L{Optimizer} in self in turn. Applies each L{Optimizer} in self in turn.
""" """
for optimizer in self: for optimizer in self:
optimizer.optimize(env) try:
optimizer.optimize(env)
except Exception, e:
if self.failure_callback:
self.failure_callback(e, self, optimizer)
continue
else:
raise
def __eq__(self, other):
return id(self) == id(other)
def __neq__(self, other):
return id(self) != id(other)
def __hash__(self):
return hash(id(self))
def __str__(self): def __str__(self):
return "SeqOpt(%s)" % list.__str__(self) return "SeqOpt(%s)" % list.__str__(self)
...@@ -212,14 +230,21 @@ class LocalOptimizer(utils.object2): ...@@ -212,14 +230,21 @@ class LocalOptimizer(utils.object2):
class FromFunctionLocalOptimizer(LocalOptimizer): class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, fn): def __init__(self, fn, tracks = []):
self.transform = fn self.transform = fn
self._tracks = tracks
def tracks(self):
return self._tracks
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
def local_optimizer(f): def local_optimizer(*tracks):
"""WRITEME""" def decorator(f):
return FromFunctionLocalOptimizer(f) """WRITEME"""
rval = FromFunctionLocalOptimizer(f, tracks)
rval.__name__ = f.__name__
return rval
return decorator
class LocalOptGroup(LocalOptimizer): class LocalOptGroup(LocalOptimizer):
...@@ -272,6 +297,9 @@ class OpSub(LocalOptimizer): ...@@ -272,6 +297,9 @@ class OpSub(LocalOptimizer):
def op_key(self): def op_key(self):
return self.op1 return self.op1
def tracks(self):
return [[self.op1]]
def transform(self, node): def transform(self, node):
if node.op != self.op1: if node.op != self.op1:
return False return False
...@@ -304,6 +332,9 @@ class OpRemove(LocalOptimizer): ...@@ -304,6 +332,9 @@ class OpRemove(LocalOptimizer):
def op_key(self): def op_key(self):
return self.op return self.op
def tracks(self):
return [[self.op]]
def transform(self, node): def transform(self, node):
if node.op != self.op: if node.op != self.op:
return False return False
...@@ -380,6 +411,19 @@ class PatternSub(LocalOptimizer): ...@@ -380,6 +411,19 @@ class PatternSub(LocalOptimizer):
def op_key(self): def op_key(self):
return self.op return self.op
def tracks(self):
def helper(pattern, sofar):
if isinstance(pattern, (list, tuple)):
sofar = sofar + (pattern[0],)
return reduce(tuple.__add__,
tuple(helper(p, sofar) for p in pattern[1:]),
())
elif isinstance(pattern, dict):
return helper(pattern['pattern'], sofar)
else:
return (sofar,)
return set(helper(self.in_pattern, ()))
def transform(self, node): def transform(self, node):
""" """
Checks if the graph from node corresponds to in_pattern. If it does, Checks if the graph from node corresponds to in_pattern. If it does,
...@@ -490,23 +534,26 @@ class NavigatorOptimizer(Optimizer): ...@@ -490,23 +534,26 @@ class NavigatorOptimizer(Optimizer):
if u is not None: if u is not None:
env.remove_feature(u) env.remove_feature(u)
def process_node(self, env, node): def process_node(self, env, node, lopt = None):
lopt = lopt or self.local_opt
try: try:
replacements = self.local_opt.transform(node) replacements = lopt.transform(node)
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs]) self.failure_callback(e, self, [(x, None) for x in node.outputs])
return return False
else: else:
raise raise
if replacements is False or replacements is None: if replacements is False or replacements is None:
return return False
repl_pairs = zip(node.outputs, replacements) repl_pairs = zip(node.outputs, replacements)
try: try:
env.replace_all_validate(repl_pairs) env.replace_all_validate(repl_pairs)
return True
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs) self.failure_callback(e, self, repl_pairs)
return False
else: else:
raise raise
...@@ -589,6 +636,174 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -589,6 +636,174 @@ class OpKeyOptimizer(NavigatorOptimizer):
env.extend(toolbox.NodeFinder()) env.extend(toolbox.NodeFinder())
# class EquilibriumOptimizer(NavigatorOptimizer):
# """WRITEME"""
# def __init__(self, local_optimizers, failure_callback = None):
# NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
# def apply(self, env):
# op = self.local_opt.op_key()
# if isinstance(op, (list, tuple)):
# q = reduce(list.__iadd__, map(env.get_nodes, op))
# else:
# q = list(env.get_nodes(op))
# def importer(node):
# if node.op == op: q.append(node)
# def pruner(node):
# if node is not current_node and node.op == op:
# try: q.remove(node)
# except ValueError: pass
# u = self.attach_updater(env, importer, pruner)
# try:
# while q:
# node = q.pop()
# current_node = node
# self.process_node(env, node)
# except:
# self.detach_updater(env, u)
# raise
from utils import D
class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = False,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.nodes:
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner)
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
def keep_going(exc, nav, repl_pairs): def keep_going(exc, nav, repl_pairs):
"""WRITEME""" """WRITEME"""
pass pass
...@@ -635,6 +850,18 @@ def check_chain(r, *chain): ...@@ -635,6 +850,18 @@ def check_chain(r, *chain):
### Misc ### ### Misc ###
############ ############
class InplaceOptimizer(Optimizer):
def __init__(self, inplace):
self.inplace = inplace
def apply(self, env):
self.inplace(env)
def add_requirements(self, env):
env.extend(dh.DestroyHandler())
class PureThenInplaceOptimizer(Optimizer): class PureThenInplaceOptimizer(Optimizer):
def __init__(self, pure, inplace): def __init__(self, pure, inplace):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论