提交 2ced2dcc authored 作者: Olivier Breuleux's avatar Olivier Breuleux

EquilibriumDB

上级 b4ec4265
......@@ -22,8 +22,13 @@ from opt import \
MergeOptimizer, MergeOptMerge, \
LocalOptimizer, local_optimizer, LocalOptGroup, LocalOpKeyOptGroup, \
OpSub, OpRemove, PatternSub, \
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, \
PureThenInplaceOptimizer
NavigatorOptimizer, TopoOptimizer, OpKeyOptimizer, EquilibriumOptimizer, \
keep_going, \
InplaceOptimizer, PureThenInplaceOptimizer
from optdb import \
DB, Query, \
EquilibriumDB, SequenceDB
from toolbox import \
Bookkeeper, History, Validator, ReplaceValidate, NodeFinder, PrintListener
......
......@@ -11,11 +11,12 @@ import unify
import toolbox
import op
from copy import copy
from collections import deque
from collections import deque, defaultdict
import destroyhandler as dh
import sys
class Optimizer:
class Optimizer(object):
"""WRITEME
An L{Optimizer} can be applied to an L{Env} to transform it.
It can represent an optimization or in general any kind
......@@ -76,18 +77,35 @@ class SeqOptimizer(Optimizer, list):
sequentially.
"""
def __init__(self, *opts):
def __init__(self, *opts, **kw):
"""WRITEME"""
if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0]
self[:] = opts
self.failure_callback = kw.pop('failure_callback', None)
def apply(self, env):
"""WRITEME
Applies each L{Optimizer} in self in turn.
"""
for optimizer in self:
optimizer.optimize(env)
try:
optimizer.optimize(env)
except Exception, e:
if self.failure_callback:
self.failure_callback(e, self, optimizer)
continue
else:
raise
def __eq__(self, other):
return id(self) == id(other)
def __neq__(self, other):
return id(self) != id(other)
def __hash__(self):
return hash(id(self))
def __str__(self):
return "SeqOpt(%s)" % list.__str__(self)
......@@ -212,14 +230,21 @@ class LocalOptimizer(utils.object2):
class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
def __init__(self, fn):
def __init__(self, fn, tracks = []):
self.transform = fn
self._tracks = tracks
def tracks(self):
return self._tracks
def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate())
def local_optimizer(f):
"""WRITEME"""
return FromFunctionLocalOptimizer(f)
def local_optimizer(*tracks):
def decorator(f):
"""WRITEME"""
rval = FromFunctionLocalOptimizer(f, tracks)
rval.__name__ = f.__name__
return rval
return decorator
class LocalOptGroup(LocalOptimizer):
......@@ -272,6 +297,9 @@ class OpSub(LocalOptimizer):
def op_key(self):
return self.op1
def tracks(self):
return [[self.op1]]
def transform(self, node):
if node.op != self.op1:
return False
......@@ -304,6 +332,9 @@ class OpRemove(LocalOptimizer):
def op_key(self):
return self.op
def tracks(self):
return [[self.op]]
def transform(self, node):
if node.op != self.op:
return False
......@@ -380,6 +411,19 @@ class PatternSub(LocalOptimizer):
def op_key(self):
return self.op
def tracks(self):
def helper(pattern, sofar):
if isinstance(pattern, (list, tuple)):
sofar = sofar + (pattern[0],)
return reduce(tuple.__add__,
tuple(helper(p, sofar) for p in pattern[1:]),
())
elif isinstance(pattern, dict):
return helper(pattern['pattern'], sofar)
else:
return (sofar,)
return set(helper(self.in_pattern, ()))
def transform(self, node):
"""
Checks if the graph from node corresponds to in_pattern. If it does,
......@@ -490,23 +534,26 @@ class NavigatorOptimizer(Optimizer):
if u is not None:
env.remove_feature(u)
def process_node(self, env, node):
def process_node(self, env, node, lopt = None):
lopt = lopt or self.local_opt
try:
replacements = self.local_opt.transform(node)
replacements = lopt.transform(node)
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(e, self, [(x, None) for x in node.outputs])
return
return False
else:
raise
if replacements is False or replacements is None:
return
return False
repl_pairs = zip(node.outputs, replacements)
try:
env.replace_all_validate(repl_pairs)
return True
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(e, self, repl_pairs)
return False
else:
raise
......@@ -589,6 +636,174 @@ class OpKeyOptimizer(NavigatorOptimizer):
env.extend(toolbox.NodeFinder())
# class EquilibriumOptimizer(NavigatorOptimizer):
# """WRITEME"""
# def __init__(self, local_optimizers, failure_callback = None):
# NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, failure_callback)
# def apply(self, env):
# op = self.local_opt.op_key()
# if isinstance(op, (list, tuple)):
# q = reduce(list.__iadd__, map(env.get_nodes, op))
# else:
# q = list(env.get_nodes(op))
# def importer(node):
# if node.op == op: q.append(node)
# def pruner(node):
# if node is not current_node and node.op == op:
# try: q.remove(node)
# except ValueError: pass
# u = self.attach_updater(env, importer, pruner)
# try:
# while q:
# node = q.pop()
# current_node = node
# self.process_node(env, node)
# except:
# self.detach_updater(env, u)
# raise
from utils import D
class EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback = None,
max_depth = None,
max_use_ratio = None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees = False,
failure_callback = failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = filter(node, depth)
depth += 1
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, env):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(env.nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
# # == NOT IDEAL == #
# for node in env.nodes:
# importer(node)
for node in env.nodes:
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(env, importer, pruner)
while tasks:
for node in tasks.iterkeys():
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print >>sys.stderr, 'Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses)
continue
success = self.process_node(env, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(env, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, str):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
def keep_going(exc, nav, repl_pairs):
"""WRITEME"""
pass
......@@ -635,6 +850,18 @@ def check_chain(r, *chain):
### Misc ###
############
class InplaceOptimizer(Optimizer):
def __init__(self, inplace):
self.inplace = inplace
def apply(self, env):
self.inplace(env)
def add_requirements(self, env):
env.extend(dh.DestroyHandler())
class PureThenInplaceOptimizer(Optimizer):
def __init__(self, pure, inplace):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论