提交 59bbe565 authored 作者: Frederic's avatar Frederic

Add LocalSequenceDB and LocalSeqOptimizer

上级 4829418f
......@@ -152,7 +152,7 @@ def inplace_optimizer(f):
class SeqOptimizer(Optimizer, list):
#inherit from Optimizer first to get Optimizer.__hash__
# inherit from Optimizer first to get Optimizer.__hash__
"""WRITEME
Takes a list of L{Optimizer} instances and applies them
sequentially.
......@@ -823,6 +823,68 @@ class LocalOptimizer(object):
(' ' * level), self.__class__.__name__, id(self))
class LocalSeqOptimizer(LocalOptimizer, list):
"""
This allow to try a group of local optimizer in sequence.
When one do something, we return without trying the following one.
"""
# inherit from Optimizer first to get Optimizer.__hash__
def __init__(self, *opts, **kw):
"""WRITEME"""
if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
opts = opts[0]
self[:] = opts
self.failure_callback = kw.pop('failure_callback', None)
def tracks(self):
t = []
for l in self:
tt = l.tracks()
if tt:
t.extend(tt)
return t
def transform(self, node):
"""Transform a subgraph whose output is `node`.
Subclasses should implement this function so that it returns one of two
kinds of things:
- False to indicate that no optimization can be applied to this `node`;
or
- <list of variables> to use in place of `node`'s outputs in the
greater graph.
- dict(old variables -> new variables). A dictionary that map
from old variables to new variables to replace.
:type node: an Apply instance
"""
for l in self:
ret = l.transform(node)
if ret:
return ret
def add_requirements(self, fgraph):
"""
If this local optimization wants to add some requirements to the
fgraph,
This is the place to do it.
"""
for l in self:
l.add_requirements(fgraph)
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None)
print >> stream, "%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self))
# This way, -1 will do all depth
if depth != 0:
depth -= 1
for opt in self:
opt.print_summary(stream, level=(level + 2), depth=depth)
class FromFunctionLocalOptimizer(LocalOptimizer):
"""WRITEME"""
def __init__(self, fn, tracks=None, requirements=()):
......
......@@ -223,6 +223,7 @@ class SequenceDB(DB):
other tags) fast_run and fast_compile optimizers are drawn is a SequenceDB.
"""
seq_opt = opt.SeqOptimizer
def __init__(self, failure_callback=opt.SeqOptimizer.warn):
super(SequenceDB, self).__init__()
......@@ -256,13 +257,13 @@ class SequenceDB(DB):
# the order we want.
opts.sort(key=lambda obj: obj.name)
opts.sort(key=lambda obj: self.__position__[obj.name])
ret = opt.SeqOptimizer(opts, failure_callback=self.failure_callback)
ret = self.seq_opt(opts, failure_callback=self.failure_callback)
if hasattr(tags[0], 'name'):
ret.name = tags[0].name
return ret
def print_summary(self, stream=sys.stdout):
print >> stream, "SequenceDB (id %i)" % id(self)
print >> stream, self.__class__.__name__ + " (id %i)" % id(self)
positions = self.__position__.items()
def c(a, b):
......@@ -279,6 +280,13 @@ class SequenceDB(DB):
return sio.getvalue()
class LocalSequenceDB(SequenceDB):
"""
This generate a local optimizer instead of a global optimizer.
"""
seq_opt = opt.LocalSeqOptimizer
class ProxyDB(DB):
"""
Wrap an existing proxy.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论