提交 9bc72d95 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documented opt.py and added a callback to some Optimizers so the user can know about failed attempts to optimize

documented opt.py and added a callback to some Optimizers so the user can know about failed attempts to optimize
上级 6008b0ec
...@@ -9,25 +9,51 @@ import ext ...@@ -9,25 +9,51 @@ import ext
class Optimizer:
    """
    An Optimizer can be applied to an env to transform it.
    It can represent an optimization or, in general, any kind
    of transformation that can be applied to an env.
    """

    def apply(self, env):
        """
        Applies the optimization to the provided env. It may
        use all the methods defined by the env. If the optimizer
        needs a certain tool, such as an InstanceFinder, it should
        set the __env_require__ field to a list of what needs to
        be registered with the Env.
        """
        pass

    def optimize(self, env):
        """
        Shortcut for:
            env.satisfy(opt)
            opt.apply(env)
        """
        env.satisfy(self)
        self.apply(env)

    def __call__(self, env):
        """Calling an Optimizer on an env is the same as optimize(env)."""
        self.optimize(env)
# A ready-made Optimizer instance that performs no transformation.
DummyOpt = Optimizer()
DummyOpt.__doc__ = "Does nothing."
class SeqOptimizer(Optimizer, list):
    """
    Takes a list of Optimizer instances and applies them
    sequentially.
    """

    def __init__(self, *opts):
        """
        Accepts either SeqOptimizer(opt1, opt2, ...) or a single
        list/tuple of optimizers: SeqOptimizer([opt1, opt2, ...]).
        """
        if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
            opts = opts[0]
        # Bug fix: the original called list.__init__(opts), which treats
        # `opts` as `self` and leaves this instance empty; `self` must be
        # passed explicitly for the list contents to be stored.
        list.__init__(self, opts)

    def apply(self, env):
        """
        Applies each optimizer in self in turn (via optimize, so each
        optimizer's env requirements are satisfied first).
        """
        for optimizer in self:
            optimizer.optimize(env)
...@@ -40,14 +66,34 @@ class SeqOptimizer(Optimizer, list): ...@@ -40,14 +66,34 @@ class SeqOptimizer(Optimizer, list):
class LocalOptimizer(Optimizer):
    """
    Generic Optimizer class that considers local parts of the env.
    It must be subclassed and should override the following two
    methods:
     * candidates(env) -> returns a set of ops that can be optimized
     * apply_on_op(env, op) -> called for each op in candidates to
       perform the actual optimization.
    """

    def candidates(self, env):
        """
        Must return a set of ops that can be optimized.
        """
        raise utils.AbstractFunctionError()

    def apply_on_op(self, env, op):
        """
        Called for each op in candidates to perform the actual
        optimization.
        """
        raise utils.AbstractFunctionError()

    def apply(self, env):
        """
        Calls self.apply_on_op(env, op) for each op returned by
        self.candidates(env) that is still present in the env.
        """
        for op in self.candidates(env):
            # a previous apply_on_op may have removed op from the env
            if env.has_op(op):
                self.apply_on_op(env, op)
...@@ -55,50 +101,95 @@ class LocalOptimizer(Optimizer): ...@@ -55,50 +101,95 @@ class LocalOptimizer(Optimizer):
class OpSpecificOptimizer(LocalOptimizer):
    """
    Generic optimizer that applies only to ops of a certain type.
    The type in question is accessed through self.opclass, which may
    also be set as a class variable of the subclass.
    """

    __env_require__ = toolbox.InstanceFinder

    # default op type; subclasses typically override this
    opclass = Op

    def candidates(self, env):
        """
        Returns all instances of self.opclass found in the env.
        """
        return env.get_instances_of(self.opclass)
class OpSubOptimizer(Optimizer):
    """
    Replaces all ops of a certain type by ops of another type that
    take the same inputs as what they are replacing.
    e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
    """

    __env_require__ = toolbox.InstanceFinder

    def __init__(self, op1, op2, failure_callback = None):
        """
        op1 and op2 must both be Op subclasses; they must take the
        same number of inputs and they must have the same number of
        outputs.
        failure_callback, if not None, is called as
        failure_callback(op1_instance, replacement, exception)
        whenever a replacement fails (replacement is None if
        constructing the op2 instance itself raised).
        """
        self.op1 = op1
        self.op2 = op2
        self.failure_callback = failure_callback

    def apply(self, env):
        """
        Replaces all occurrences of self.op1 by instances of self.op2
        with the same inputs.
        If failure_callback is not None, it will be called whenever
        the Optimizer fails to do a replacement in the graph, with
        arguments (op1_instance, replacement, exception).
        """
        candidates = env.get_instances_of(self.op1)
        for op in candidates:
            # Bug fix: repl must be bound before the try block, else a
            # failure inside self.op2(*op.inputs) made the callback
            # raise NameError on `repl`.
            repl = None
            try:
                repl = self.op2(*op.inputs)
                assert len(op.outputs) == len(repl.outputs)
                for old, new in zip(op.outputs, repl.outputs):
                    env.replace(old, new)
            except Exception as e:
                if self.failure_callback is not None:
                    self.failure_callback(op, repl, e)

    def str(self):
        """Short description of the substitution, e.g. 'Add -> Sub'."""
        return "%s -> %s" % (self.op1.__name__, self.op2.__name__)

    # expose the description through the standard str() protocol too
    __str__ = str
class OpRemover(Optimizer):
    """
    Removes all ops of a certain type by transferring each of its
    outputs to the corresponding input.
    """

    __env_require__ = toolbox.InstanceFinder

    def __init__(self, opclass, failure_callback = None):
        """
        opclass is the class of the ops to remove; its ops must take
        as many inputs as outputs.
        failure_callback, if not None, is called as
        failure_callback(opclass_instance, exception) whenever a
        removal fails.
        """
        self.opclass = opclass
        self.failure_callback = failure_callback

    def apply(self, env):
        """
        Removes all occurrences of self.opclass by replacing each of
        their outputs with the corresponding input.
        If self.failure_callback is not None, it will be called
        whenever the Optimizer fails to remove an operation, with
        arguments (opclass_instance, exception).
        """
        candidates = env.get_instances_of(self.opclass)
        for op in candidates:
            try:
                assert len(op.inputs) == len(op.outputs)
                for inp, out in zip(op.inputs, op.outputs):
                    env.replace(out, inp)
            except Exception as e:
                if self.failure_callback is not None:
                    self.failure_callback(op, e)

    def str(self):
        """Short description of the removal performed."""
        return "f(%s(x)) -> f(x)" % self.opclass

    # expose the description through the standard str() protocol too
    __str__ = str
class PatternOptimizer(OpSpecificOptimizer): class PatternOptimizer(OpSpecificOptimizer):
...@@ -117,13 +212,26 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -117,13 +212,26 @@ class PatternOptimizer(OpSpecificOptimizer):
Replaces all occurrences of the first pattern by the second pattern. Replaces all occurrences of the first pattern by the second pattern.
""" """
def __init__(self, in_pattern, out_pattern): def __init__(self, in_pattern, out_pattern, failure_callback = None):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self.in_pattern = in_pattern self.in_pattern = in_pattern
self.out_pattern = out_pattern self.out_pattern = out_pattern
self.opclass = self.in_pattern[0] self.opclass = self.in_pattern[0]
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n" self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.failure_callback = failure_callback
def apply_on_op(self, env, op): def apply_on_op(self, env, op):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
"""
def match(pattern, expr, u, first = False): def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
...@@ -168,8 +276,9 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -168,8 +276,9 @@ class PatternOptimizer(OpSpecificOptimizer):
if not isinstance(p, str): if not isinstance(p, str):
new = new.out new = new.out
env.replace(op.out, new) env.replace(op.out, new)
except InconsistencyError, e: except Exception, e:
# print "Warning: '%s' failed to apply on %s: %s" % (self, op, str(e)) # warning is for debug if self.failure_callback is not None:
self.failure_callback(op.out, new, e)
pass pass
...@@ -183,6 +292,11 @@ class PatternOptimizer(OpSpecificOptimizer): ...@@ -183,6 +292,11 @@ class PatternOptimizer(OpSpecificOptimizer):
class ConstantFinder(Optimizer): class ConstantFinder(Optimizer):
"""
Sets as constant every orphan that is not destroyed
and sets as indestructible every input that is not
destroyed.
"""
def apply(self, env): def apply(self, env):
if env.has_feature(ext.DestroyHandler): if env.has_feature(ext.DestroyHandler):
...@@ -202,6 +316,12 @@ class ConstantFinder(Optimizer): ...@@ -202,6 +316,12 @@ class ConstantFinder(Optimizer):
class MergeOptimizer(Optimizer): class MergeOptimizer(Optimizer):
"""
Merges parts of the graph that are identical, i.e. parts that
take the same inputs and carry out the asme computations so we
can avoid doing them more than once. Also merges results that
are constant.
"""
def apply(self, env): def apply(self, env):
cid = {} cid = {}
...@@ -220,6 +340,9 @@ class MergeOptimizer(Optimizer): ...@@ -220,6 +340,9 @@ class MergeOptimizer(Optimizer):
inv_cid[i] = r inv_cid[i] = r
for op in env.io_toposort(): for op in env.io_toposort():
# this could be made more robust by having an op.hash() that
# doesn't depend on the inputs but can depend on additional properties
# of the op.
op_cid = (op.__class__, tuple([cid[input] for input in op.inputs])) op_cid = (op.__class__, tuple([cid[input] for input in op.inputs]))
dup = inv_cid.get(op_cid, None) dup = inv_cid.get(op_cid, None)
if dup is None: if dup is None:
...@@ -237,124 +360,131 @@ class MergeOptimizer(Optimizer): ...@@ -237,124 +360,131 @@ class MergeOptimizer(Optimizer):
def MergeOptMerge(opt):
    """
    Returns an Optimizer that merges the graph, then applies opt,
    then merges the graph again in case opt introduced additional
    similarities.
    """
    m = MergeOptimizer()
    return SeqOptimizer([m, opt, m])
class MultiOptimizer(Optimizer): ### THE FOLLOWING OPTIMIZERS ARE NEITHER USED NOR TESTED BUT PROBABLY WORK AND COULD BE USEFUL ###
def __init__(self, **opts): # class MultiOptimizer(Optimizer):
self._opts = []
self.ord = {}
self.name_to_opt = {}
self.up_to_date = True
for name, opt in opts:
self.register(name, opt, after = [], before = [])
def register(self, name, opt, **relative): # def __init__(self, **opts):
self.name_to_opt[name] = opt # self._opts = []
# self.ord = {}
# self.name_to_opt = {}
# self.up_to_date = True
# for name, opt in opts:
# self.register(name, opt, after = [], before = [])
after = relative.get('after', []) # def register(self, name, opt, **relative):
if not isinstance(after, (list, tuple)): # self.name_to_opt[name] = opt
after = [after]
before = relative.get('before', []) # after = relative.get('after', [])
if not isinstance(before, (list, tuple)): # if not isinstance(after, (list, tuple)):
before = [before] # after = [after]
self.up_to_date = False # before = relative.get('before', [])
# if not isinstance(before, (list, tuple)):
# before = [before]
if name in self.ord: # self.up_to_date = False
raise Exception("Cannot redefine optimization: '%s'" % name)
self.ord[name] = set(after) # if name in self.ord:
# raise Exception("Cannot redefine optimization: '%s'" % name)
for postreq in before: # self.ord[name] = set(after)
self.ord.setdefault(postreq, set()).add(name)
def get_opts(self): # for postreq in before:
if not self.up_to_date: # self.ord.setdefault(postreq, set()).add(name)
self.refresh()
return self._opts
def refresh(self): # def get_opts(self):
self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)] # if not self.up_to_date:
self.up_to_date = True # self.refresh()
# return self._opts
def apply(self, env): # def refresh(self):
for opt in self.opts: # self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
opt.apply(env) # self.up_to_date = True
opts = property(get_opts) # def apply(self, env):
# for opt in self.opts:
# opt.apply(env)
# opts = property(get_opts)
class TaggedMultiOptimizer(MultiOptimizer):
def __init__(self, **opts): # class TaggedMultiOptimizer(MultiOptimizer):
self.tags = {}
MultiOptimizer.__init__(self, **opts)
def register(self, name, opt, tags = [], **relative): # def __init__(self, **opts):
tags = set(tags) # self.tags = {}
tags.add(name) # MultiOptimizer.__init__(self, **opts)
self.tags[opt] = tags
MultiOptimizer.register(self, name, opt, **relative)
def filter(self, whitelist, blacklist): # def register(self, name, opt, tags = [], **relative):
return [opt for opt in self.opts # tags = set(tags)
if self.tags[opt].intersection(whitelist) # tags.add(name)
and not self.tags[opt].intersection(blacklist)] # self.tags[opt] = tags
# MultiOptimizer.register(self, name, opt, **relative)
def whitelist(self, *tags): # def filter(self, whitelist, blacklist):
return [opt for opt in self.opts if self.tags[opt].intersection(tags)] # return [opt for opt in self.opts
# if self.tags[opt].intersection(whitelist)
# and not self.tags[opt].intersection(blacklist)]
def blacklist(self, *tags): # def whitelist(self, *tags):
return [opt for opt in self.opts if not self.tags[opt].intersection(tags)] # return [opt for opt in self.opts if self.tags[opt].intersection(tags)]
# def blacklist(self, *tags):
# return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]
class TagFilterMultiOptimizer(Optimizer):
def __init__(self, all, whitelist = None, blacklist = None): # class TagFilterMultiOptimizer(Optimizer):
self.all = all
if whitelist is not None: # def __init__(self, all, whitelist = None, blacklist = None):
self.whitelist = set(whitelist) # self.all = all
else:
self.whitelist = None
if blacklist is not None: # if whitelist is not None:
self.blacklist = set(blacklist) # self.whitelist = set(whitelist)
else: # else:
self.blacklist = set() # self.whitelist = None
def use_whitelist(self, use = True):
if self.whitelist is None and use:
self.whitelist = set()
def allow(self, *tags):
if self.whitelist is not None:
self.whitelist.update(tags)
self.blacklist.difference_update(tags)
def deny(self, *tags):
if self.whitelist is not None:
self.whitelist.difference_update(tags)
self.blacklist.update(tags)
def dont_care(self, *tags):
if self.whitelist is not None:
self.whitelist.difference_update(tags)
self.blacklist.difference_update(tags)
def opts(self):
if self.whitelist is not None:
return self.all.filter(self.whitelist, self.blacklist)
else:
return self.all.blacklist(*[tag for tag in self.blacklist])
def apply(self, env): # if blacklist is not None:
for opt in self.opts(): # self.blacklist = set(blacklist)
opt.apply(env) # else:
# self.blacklist = set()
# def use_whitelist(self, use = True):
# if self.whitelist is None and use:
# self.whitelist = set()
# def allow(self, *tags):
# if self.whitelist is not None:
# self.whitelist.update(tags)
# self.blacklist.difference_update(tags)
# def deny(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.update(tags)
# def dont_care(self, *tags):
# if self.whitelist is not None:
# self.whitelist.difference_update(tags)
# self.blacklist.difference_update(tags)
# def opts(self):
# if self.whitelist is not None:
# return self.all.filter(self.whitelist, self.blacklist)
# else:
# return self.all.blacklist(*[tag for tag in self.blacklist])
# def apply(self, env):
# for opt in self.opts():
# opt.apply(env)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论