提交 32cce9b4 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

optdb

上级 2ced2dcc
......@@ -250,7 +250,7 @@ class Env(utils.object2):
### replace ###
def replace(self, r, new_r):
""" WRITEME
This is the main interface to manipulate the subgraph in Env.
......@@ -259,7 +259,7 @@ class Env(utils.object2):
if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r)
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r)
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r, r.type, new_r.type)
if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops
......
from collections import defaultdict
import opt
class DB(object):
def __init__(self):
self.__db__ = defaultdict(set)
def register(self, name, obj, *tags):
obj.name = name
if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
self.__db__[name] = set([obj])
for tag in tags:
self.__db__[tag].add(obj)
def __query__(self, q):
if not isinstance(q, Query):
raise TypeError('Expected a Query.', q)
results = set()
for tag in q.include:
results.update(self.__db__[tag])
for tag in q.require:
results.intersection_update(self.__db__[tag])
for tag in q.exclude:
results.difference_update(self.__db__[tag])
remove = set()
add = set()
for obj in results:
if isinstance(obj, DB):
sq = q.subquery.get(obj.name, q)
if sq:
replacement = obj.query(sq)
replacement.name = obj.name
remove.add(obj)
add.add(replacement)
results.difference_update(remove)
results.update(add)
return results
def query(self, *tags, **kwtags):
if len(tags) >= 1 and isinstance(tags[0], Query):
if len(tags) > 1 or kwtags:
raise TypeError('If the first argument to query is a Query, there should be no other arguments.', tags, kwtags)
return self.__query__(tags[0])
include = [tag[1:] for tag in tags if tag.startswith('+')]
require = [tag[1:] for tag in tags if tag.startswith('&')]
exclude = [tag[1:] for tag in tags if tag.startswith('-')]
if len(include) + len(require) + len(exclude) < len(tags):
raise ValueError("All tags must start with one of the following characters: '+', '&' or '-'", tags)
return self.__query__(Query(include = include,
require = require,
exclude = exclude,
subquery = kwtags))
def __getitem__(self, name):
results = self.__db__[name]
if not results:
raise KeyError("Nothing registered for '%s'" % name)
elif len(results) > 1:
raise ValueError('More than one match for %s (please use query)' % name)
for result in results:
return result
class Query(object):
def __init__(self, include, require = None, exclude = None, subquery = None):
self.include = include
self.require = require or set()
self.exclude = exclude or set()
self.subquery = subquery or {}
def including(self, *tags):
return Query(self.include.union(tags),
self.require,
self.exclude,
self.subquery)
def excluding(self, *tags):
return Query(self.include,
self.require,
self.exclude.union(tags),
self.subquery)
def requiring(self, *tags):
return Query(self.include,
self.require.union(tags),
self.exclude,
self.subquery)
class EquilibriumDB(DB):
def query(self, *tags, **kwtags):
opts = super(EquilibriumDB, self).query(*tags, **kwtags)
return opt.EquilibriumOptimizer(opts, max_depth = 5, max_use_ratio = 10, failure_callback = opt.keep_going)
class SequenceDB(DB):
def __init__(self):
super(SequenceDB, self).__init__()
self.__priority__ = {}
def register(self, name, obj, priority, *tags):
super(SequenceDB, self).register(name, obj, *tags)
self.__priority__[name] = priority
def query(self, *tags, **kwtags):
opts = super(SequenceDB, self).query(*tags, **kwtags)
opts = list(opts)
opts.sort(key = lambda obj: self.__priority__[obj.name])
return opt.SeqOptimizer(opts, failure_callback = opt.keep_going)
......@@ -331,3 +331,46 @@ class TestMergeOptimizer:
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
class TestEquilibrium(object):
def test_1(self):
x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y))
g = Env([x, y, z], [e])
print g
opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
PatternSub((op4, 'x', 'y'), (op1, 'x', 'y')),
PatternSub((op3, (op2, 'x', 'y')), (op4, 'x', 'y'))
],
max_use_ratio = 10)
opt.optimize(g)
print g
assert str(g) == '[Op2(x, y)]'
def test_low_use_ratio(self):
x, y, z = map(MyResult, 'xyz')
e = op3(op4(x, y))
g = Env([x, y, z], [e])
print g
sys.stderr = sys.stdout # display pesky warnings along with stdout
opt = EquilibriumOptimizer(
[PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
PatternSub((op4, 'x', 'y'), (op1, 'x', 'y')),
PatternSub((op3, (op2, 'x', 'y')), (op4, 'x', 'y'))
],
max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once
opt.optimize(g)
print g
assert str(g) == '[Op4(x, y)]'
......@@ -101,8 +101,9 @@ class ReplaceValidate(History, Validator):
try:
env.replace(r, new_r)
except Exception, e:
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail; env.replace should never raise an exception (it kinda needs better internal error handling)
if not 'The type of the replacement must be the same' in str(e) or not 'does not belong to this Env' in str(e):
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise
try:
env.validate()
......
......@@ -38,6 +38,10 @@ class scratchpad:
def __str__(self):
return "scratch" + str(self.__dict__)
class D:
def __init__(self, **d):
self.__dict__.update(d)
def memoize(f):
"""Cache the return value for each tuple of arguments
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论