提交 32cce9b4 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

optdb

上级 2ced2dcc
...@@ -250,7 +250,7 @@ class Env(utils.object2): ...@@ -250,7 +250,7 @@ class Env(utils.object2):
### replace ### ### replace ###
def replace(self, r, new_r): def replace(self, r, new_r):
""" WRITEME """ WRITEME
This is the main interface to manipulate the subgraph in Env. This is the main interface to manipulate the subgraph in Env.
...@@ -259,7 +259,7 @@ class Env(utils.object2): ...@@ -259,7 +259,7 @@ class Env(utils.object2):
if r.env is not self: if r.env is not self:
raise Exception("Cannot replace %s because it does not belong to this Env" % r) raise Exception("Cannot replace %s because it does not belong to this Env" % r)
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r) raise TypeError("The type of the replacement must be the same as the type of the original Result.", r, new_r, r.type, new_r.type)
if r not in self.results: if r not in self.results:
# this result isn't in the graph... don't raise an exception here, just return silently # this result isn't in the graph... don't raise an exception here, just return silently
# because it makes it easier to implement some optimizations for multiple-output ops # because it makes it easier to implement some optimizations for multiple-output ops
......
from collections import defaultdict
import opt
class DB(object):
    """Registry that maps names and tags to sets of registered objects.

    Every object is stored under its unique name and, additionally, under
    each of the tags it was registered with. Objects can then be selected
    with tag queries (see ``query``) or fetched by name (``db[name]``).
    """

    def __init__(self):
        # name-or-tag -> set of objects registered under it
        self.__db__ = defaultdict(set)

    def register(self, name, obj, *tags):
        """Register ``obj`` under the unique ``name`` and under each tag.

        Sets ``obj.name = name`` on success.

        Raises ValueError if ``name`` is already in use as an object name or
        as a tag. In that case ``obj`` is left untouched.
        """
        # Only a non-empty entry counts as "taken": an empty set is not a
        # registration and must not block the name.
        if self.__db__.get(name):
            raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
        obj.name = name
        self.__db__[name] = set([obj])
        for tag in tags:
            self.__db__[tag].add(obj)

    def __query__(self, q):
        """Return the set of objects selected by the Query ``q``.

        The result is the union of all ``q.include`` tags, restricted to
        objects carrying every ``q.require`` tag, minus objects carrying any
        ``q.exclude`` tag. A selected object that is itself a DB is replaced
        by the result of querying it with ``q.subquery[obj.name]`` (or with
        ``q`` itself when no sub-query is given for that name).

        Raises TypeError if ``q`` is not a Query.
        """
        if not isinstance(q, Query):
            raise TypeError('Expected a Query.', q)
        # Use .get so that looking up an unknown tag does not insert an empty
        # entry into the defaultdict (which would later make register()
        # reject that name).
        lookup = self.__db__.get
        results = set()
        for tag in q.include:
            results.update(lookup(tag, ()))
        for tag in q.require:
            results.intersection_update(lookup(tag, ()))
        for tag in q.exclude:
            results.difference_update(lookup(tag, ()))
        # Collect substitutions separately: the set cannot be modified while
        # it is being iterated.
        remove = set()
        add = set()
        for obj in results:
            if isinstance(obj, DB):
                sq = q.subquery.get(obj.name, q)
                # A falsy sub-query (e.g. name=None) leaves the nested DB
                # object itself in the results.
                if sq:
                    replacement = obj.query(sq)
                    replacement.name = obj.name
                    remove.add(obj)
                    add.add(replacement)
        results.difference_update(remove)
        results.update(add)
        return results

    def query(self, *tags, **kwtags):
        """Select objects either with a Query or with prefixed tag strings.

        Either pass a single Query instance, or any number of strings each
        starting with '+' (include), '&' (require) or '-' (exclude). Keyword
        arguments map the name of a nested DB to the sub-query to run on it.

        Raises TypeError if a Query is combined with other arguments, and
        ValueError if a tag string lacks one of the three prefixes.
        """
        if len(tags) >= 1 and isinstance(tags[0], Query):
            if len(tags) > 1 or kwtags:
                raise TypeError('If the first argument to query is a Query, there should be no other arguments.', tags, kwtags)
            return self.__query__(tags[0])
        include = [tag[1:] for tag in tags if tag.startswith('+')]
        require = [tag[1:] for tag in tags if tag.startswith('&')]
        exclude = [tag[1:] for tag in tags if tag.startswith('-')]
        if len(include) + len(require) + len(exclude) < len(tags):
            raise ValueError("All tags must start with one of the following characters: '+', '&' or '-'", tags)
        return self.__query__(Query(include = include,
                                    require = require,
                                    exclude = exclude,
                                    subquery = kwtags))

    def __getitem__(self, name):
        """Return the single object registered under ``name``.

        Raises KeyError if nothing is registered under ``name`` and
        ValueError if ``name`` is a tag shared by several objects.
        """
        # .get keeps a failed lookup from creating an empty entry.
        results = self.__db__.get(name)
        if not results:
            raise KeyError("Nothing registered for '%s'" % name)
        elif len(results) > 1:
            raise ValueError('More than one match for %s (please use query)' % name)
        result, = results
        return result
class Query(object):
    """Description of which tagged objects to select from a DB.

    Attributes:
      include  -- set of tags whose objects are gathered
      require  -- set of tags every selected object must also carry
      exclude  -- set of tags whose objects are dropped
      subquery -- dict mapping the name of a nested DB to the query used on it

    Queries are treated as values: ``including``, ``excluding`` and
    ``requiring`` return a new Query and leave the receiver unchanged.
    """

    def __init__(self, include, require = None, exclude = None, subquery = None):
        # Normalise to fresh sets: callers may pass lists or tuples, and the
        # derivation methods below rely on set.union.
        self.include = set(include or ())
        self.require = set(require or ())
        self.exclude = set(exclude or ())
        self.subquery = subquery or {}

    def including(self, *tags):
        """Return a copy of this query that additionally includes ``tags``."""
        return Query(self.include.union(tags),
                     self.require,
                     self.exclude,
                     self.subquery)

    def excluding(self, *tags):
        """Return a copy of this query that additionally excludes ``tags``."""
        return Query(self.include,
                     self.require,
                     self.exclude.union(tags),
                     self.subquery)

    def requiring(self, *tags):
        """Return a copy of this query that additionally requires ``tags``."""
        return Query(self.include,
                     self.require.union(tags),
                     self.exclude,
                     self.subquery)
class EquilibriumDB(DB):
    """DB whose query result is wrapped in an ``opt.EquilibriumOptimizer``."""

    def query(self, *tags, **kwtags):
        """Run the inherited query and build an EquilibriumOptimizer from it.

        Accepts the same arguments as ``DB.query``.
        """
        selected = super(EquilibriumDB, self).query(*tags, **kwtags)
        return opt.EquilibriumOptimizer(
            selected,
            max_depth = 5,
            max_use_ratio = 10,
            failure_callback = opt.keep_going)
class SequenceDB(DB):
    """DB that records a priority per name and returns an ordered optimizer.

    ``query`` sorts the selected objects by ascending priority and wraps them
    in an ``opt.SeqOptimizer``.
    """

    def __init__(self):
        super(SequenceDB, self).__init__()
        # object name -> priority given at registration time
        self.__priority__ = {}

    def register(self, name, obj, priority, *tags):
        """Register ``obj`` as in ``DB.register`` and remember its priority."""
        super(SequenceDB, self).register(name, obj, *tags)
        self.__priority__[name] = priority

    def query(self, *tags, **kwtags):
        """Run the inherited query and return a SeqOptimizer sorted by priority.

        Accepts the same arguments as ``DB.query``.
        """
        selected = super(SequenceDB, self).query(*tags, **kwtags)
        priority_of = self.__priority__
        ordered = sorted(selected, key = lambda item: priority_of[item.name])
        return opt.SeqOptimizer(ordered, failure_callback = opt.keep_going)
...@@ -331,3 +331,46 @@ class TestMergeOptimizer: ...@@ -331,3 +331,46 @@ class TestMergeOptimizer:
assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]' assert strg == '[Op1(y, y)]' or strg == '[Op1(z, z)]'
class TestEquilibrium(object):
    """Tests for EquilibriumOptimizer using three PatternSub rewrites:
    op1 -> op2, op4 -> op1 and op3(op2(...)) -> op4(...).
    """

    def test_1(self):
        """With max_use_ratio = 10 the graph op3(op4(x, y)) ends as Op2(x, y)."""
        x, y, z = map(MyResult, 'xyz')
        e = op3(op4(x, y))
        g = Env([x, y, z], [e])
        print(g)
        opt = EquilibriumOptimizer(
            [PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
             PatternSub((op4, 'x', 'y'), (op1, 'x', 'y')),
             PatternSub((op3, (op2, 'x', 'y')), (op4, 'x', 'y'))
             ],
            max_use_ratio = 10)
        opt.optimize(g)
        print(g)
        assert str(g) == '[Op2(x, y)]'

    def test_low_use_ratio(self):
        """With a use ratio of 1 / len(g.nodes) the graph ends as Op4(x, y)."""
        x, y, z = map(MyResult, 'xyz')
        e = op3(op4(x, y))
        g = Env([x, y, z], [e])
        print(g)
        # Redirect stderr only while optimizing, and always restore it, so
        # the rebinding does not leak into tests that run afterwards.
        saved_stderr = sys.stderr
        sys.stderr = sys.stdout # display pesky warnings along with stdout
        try:
            opt = EquilibriumOptimizer(
                [PatternSub((op1, 'x', 'y'), (op2, 'x', 'y')),
                 PatternSub((op4, 'x', 'y'), (op1, 'x', 'y')),
                 PatternSub((op3, (op2, 'x', 'y')), (op4, 'x', 'y'))
                 ],
                max_use_ratio = 1. / len(g.nodes)) # each opt can only be applied once
            opt.optimize(g)
        finally:
            sys.stderr = saved_stderr
        print(g)
        assert str(g) == '[Op4(x, y)]'
...@@ -101,8 +101,9 @@ class ReplaceValidate(History, Validator): ...@@ -101,8 +101,9 @@ class ReplaceValidate(History, Validator):
try: try:
env.replace(r, new_r) env.replace(r, new_r)
except Exception, e: except Exception, e:
print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e if not 'The type of the replacement must be the same' in str(e) or not 'does not belong to this Env' in str(e):
env.revert(chk) # this might fail; env.replace should never raise an exception (it kinda needs better internal error handling) print >>sys.stderr, "<<!! BUG IN ENV.REPLACE OR A LISTENER !!>>", type(e), e
env.revert(chk) # this might fail if the error is in a listener: (env.replace kinda needs better internal error handling)
raise raise
try: try:
env.validate() env.validate()
......
...@@ -38,6 +38,10 @@ class scratchpad: ...@@ -38,6 +38,10 @@ class scratchpad:
def __str__(self): def __str__(self):
return "scratch" + str(self.__dict__) return "scratch" + str(self.__dict__)
class D:
    """Simple attribute container: each keyword argument becomes an attribute."""

    def __init__(self, **attrs):
        # Store every keyword argument directly in the instance namespace.
        vars(self).update(attrs)
def memoize(f): def memoize(f):
"""Cache the return value for each tuple of arguments """Cache the return value for each tuple of arguments
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论