提交 9bc72d95 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

documented opt.py and added a callback to some Optimizers so the user can know…

documented opt.py and added a callback to some Optimizers so the user can know about failed attempts to optimize
上级 6008b0ec
......@@ -9,25 +9,51 @@ import ext
class Optimizer:
    """
    Base class for transformations that can be applied to an env.

    An Optimizer represents an optimization or, more generally, any
    kind of transformation one could apply to an env.
    """

    def apply(self, env):
        """
        Apply the transformation to the given env.

        Subclasses may use any method the env defines. If a tool such
        as an InstanceFinder is required, the subclass should list it
        in the __env_require__ field so the env can be set up first.
        """
        pass

    def optimize(self, env):
        """
        Shortcut for::

            env.satisfy(opt)
            opt.apply(env)

        i.e. make sure the env provides what this optimizer requires,
        then apply it.
        """
        env.satisfy(self)
        self.apply(env)

    def __call__(self, env):
        """Calling an optimizer is the same as calling optimize(env)."""
        self.optimize(env)
# Sentinel optimizer that performs no transformation (Optimizer.apply is a
# no-op). The docstring is assigned after construction because DummyOpt is
# an instance, not a class.
DummyOpt = Optimizer()
DummyOpt.__doc__ = "Does nothing."
class SeqOptimizer(Optimizer, list):
    """
    Takes a list of Optimizer instances and applies them
    sequentially.
    """

    def __init__(self, *opts):
        """
        Accepts either a single list/tuple of optimizers or the
        optimizers passed as separate positional arguments.
        """
        if len(opts) == 1 and isinstance(opts[0], (list, tuple)):
            opts = opts[0]
        # Bug fix: the original called list.__init__(opts) without self,
        # so the underlying list was never populated and apply() iterated
        # over an empty sequence.
        list.__init__(self, opts)

    def apply(self, env):
        """
        Applies each optimizer in self in turn.
        """
        for optimizer in self:
            optimizer.optimize(env)
......@@ -40,14 +66,34 @@ class SeqOptimizer(Optimizer, list):
class LocalOptimizer(Optimizer):
    """
    Generic Optimizer class that considers local parts of
    the env. It must be subclassed and should override the
    following two methods:
    * candidates(env) -> returns a set of ops that can be
      optimized
    * apply_on_op(env, op) -> for each op in candidates,
      this function will be called to perform the actual
      optimization.
    """

    def candidates(self, env):
        """
        Must return a set of ops that can be optimized.
        """
        # Abstract: the diff artifact also contained a stale
        # 'return env.ops()' default body; the intended behavior is to
        # force subclasses to override.
        raise utils.AbstractFunctionError()

    def apply_on_op(self, env, op):
        """
        For each op in candidates, this function will be called to
        perform the actual optimization.
        """
        raise utils.AbstractFunctionError()

    def apply(self, env):
        """
        Calls self.apply_on_op(env, op) for each op in self.candidates(env).
        """
        for op in self.candidates(env):
            # An earlier replacement may have removed op from the env;
            # only process ops that are still present.
            if env.has_op(op):
                self.apply_on_op(env, op)
......@@ -55,50 +101,95 @@ class LocalOptimizer(Optimizer):
class OpSpecificOptimizer(LocalOptimizer):
    """
    Generic optimizer that applies only to ops of a certain
    type. The type in question is accessed through self.opclass.
    opclass can also be a class variable of the subclass.
    """
    # Requires an InstanceFinder to be registered with the env so that
    # env.get_instances_of is available.
    __env_require__ = toolbox.InstanceFinder
    # Default op class to search for; subclasses typically override this.
    opclass = Op

    def candidates(self, env):
        """
        Returns all instances of self.opclass.
        """
        return env.get_instances_of(self.opclass)
class OpSubOptimizer(Optimizer):
    """
    Replaces all ops of a certain type by ops of another type that
    take the same inputs as what they are replacing.
    e.g. OpSubOptimizer(add, sub) ==> add(div(x, y), add(y, x)) -> sub(div(x, y), sub(y, x))
    """

    __env_require__ = toolbox.InstanceFinder

    def __init__(self, op1, op2, failure_callback=None):
        """
        op1 and op2 must both be Op subclasses, they must both take
        the same number of inputs and they must both have the same
        number of outputs.
        """
        self.op1 = op1
        self.op2 = op2
        self.failure_callback = failure_callback

    def apply(self, env):
        """
        Replaces all occurrences of self.op1 by instances of self.op2
        with the same inputs.
        If failure_callback is not None, it will be called whenever
        the Optimizer fails to do a replacement in the graph. The
        arguments to the callback are: (op1_instance, replacement, exception).
        The replacement argument is None when instantiating op2 itself
        failed.
        """
        candidates = env.get_instances_of(self.op1)
        for op in candidates:
            # Bug fix: repl must be pre-initialized; in the original,
            # a failure inside self.op2(*op.inputs) left repl unbound
            # and the failure callback raised NameError.
            repl = None
            try:
                repl = self.op2(*op.inputs)
                assert len(op.outputs) == len(repl.outputs)
                for old, new in zip(op.outputs, repl.outputs):
                    env.replace(old, new)
            except Exception as e:
                if self.failure_callback is not None:
                    self.failure_callback(op, repl, e)

    def str(self):
        return "%s -> %s" % (self.op1.__name__, self.op2.__name__)
class OpRemover(Optimizer):
    """
    Removes all ops of a certain type by transferring each of its
    outputs to the corresponding input.
    """

    __env_require__ = toolbox.InstanceFinder

    def __init__(self, opclass, failure_callback=None):
        """
        opclass is the class of the ops to remove. It must take as
        many inputs as outputs.
        """
        self.opclass = opclass
        self.failure_callback = failure_callback

    def apply(self, env):
        """
        Removes all occurrences of self.opclass.
        If self.failure_callback is not None, it will be called whenever
        the Optimizer fails to remove an operation in the graph. The
        arguments to the callback are: (opclass_instance, exception).
        """
        candidates = env.get_instances_of(self.opclass)
        for op in candidates:
            try:
                # An op is removable this way only if each output can be
                # rerouted to the matching input.
                assert len(op.inputs) == len(op.outputs)
                for input, output in zip(op.inputs, op.outputs):
                    env.replace(output, input)
            except Exception as e:
                if self.failure_callback is not None:
                    self.failure_callback(op, e)

    def str(self):
        return "f(%s(x)) -> f(x)" % self.opclass
class PatternOptimizer(OpSpecificOptimizer):
......@@ -117,13 +212,26 @@ class PatternOptimizer(OpSpecificOptimizer):
Replaces all occurrences of the first pattern by the second pattern.
"""
def __init__(self, in_pattern, out_pattern):
def __init__(self, in_pattern, out_pattern, failure_callback = None):
"""
Sets in_pattern for replacement by out_pattern.
self.opclass is set to in_pattern[0] to accelerate the search.
"""
self.in_pattern = in_pattern
self.out_pattern = out_pattern
self.opclass = self.in_pattern[0]
self.__doc__ = self.__class__.__doc__ + "\n\nThis instance does: " + str(self) + "\n"
self.failure_callback = failure_callback
def apply_on_op(self, env, op):
"""
Checks if the graph from op corresponds to in_pattern. If it does,
constructs out_pattern and performs the replacement.
If self.failure_callback is not None, if there is a match but a
replacement fails to occur, the callback will be called with
arguments (results_to_replace, replacement, exception).
"""
def match(pattern, expr, u, first = False):
if isinstance(pattern, (list, tuple)):
......@@ -168,8 +276,9 @@ class PatternOptimizer(OpSpecificOptimizer):
if not isinstance(p, str):
new = new.out
env.replace(op.out, new)
except InconsistencyError, e:
# print "Warning: '%s' failed to apply on %s: %s" % (self, op, str(e)) # warning is for debug
except Exception, e:
if self.failure_callback is not None:
self.failure_callback(op.out, new, e)
pass
......@@ -183,6 +292,11 @@ class PatternOptimizer(OpSpecificOptimizer):
class ConstantFinder(Optimizer):
"""
Sets as constant every orphan that is not destroyed
and sets as indestructible every input that is not
destroyed.
"""
def apply(self, env):
if env.has_feature(ext.DestroyHandler):
......@@ -202,6 +316,12 @@ class ConstantFinder(Optimizer):
class MergeOptimizer(Optimizer):
"""
Merges parts of the graph that are identical, i.e. parts that
take the same inputs and carry out the same computations so we
can avoid doing them more than once. Also merges results that
are constant.
"""
def apply(self, env):
cid = {}
......@@ -220,6 +340,9 @@ class MergeOptimizer(Optimizer):
inv_cid[i] = r
for op in env.io_toposort():
# this could be made more robust by having an op.hash() that
# doesn't depend on the inputs but can depend on additional properties
# of the op.
op_cid = (op.__class__, tuple([cid[input] for input in op.inputs]))
dup = inv_cid.get(op_cid, None)
if dup is None:
......@@ -237,124 +360,131 @@ class MergeOptimizer(Optimizer):
def MergeOptMerge(opt):
    """
    Return an Optimizer that merges the graph, then applies ``opt``,
    then merges the graph again in case ``opt`` introduced additional
    similarities.
    """
    # The same MergeOptimizer instance is reused for both merge passes.
    m = MergeOptimizer()
    sequence = [m, opt, m]
    return SeqOptimizer(sequence)
### THE FOLLOWING OPTIMIZERS ARE NEITHER USED NOR TESTED BUT PROBABLY WORK AND COULD BE USEFUL ###
# NOTE: the diff artifact left both the old live definitions and their new
# commented-out replacements interleaved here; only the commented-out
# versions (the commit's intent) are kept.

# class MultiOptimizer(Optimizer):

#     def __init__(self, **opts):
#         self._opts = []
#         self.ord = {}
#         self.name_to_opt = {}
#         self.up_to_date = True
#         for name, opt in opts:
#             self.register(name, opt, after = [], before = [])

#     def register(self, name, opt, **relative):
#         self.name_to_opt[name] = opt
#         after = relative.get('after', [])
#         if not isinstance(after, (list, tuple)):
#             after = [after]
#         before = relative.get('before', [])
#         if not isinstance(before, (list, tuple)):
#             before = [before]
#         self.up_to_date = False
#         if name in self.ord:
#             raise Exception("Cannot redefine optimization: '%s'" % name)
#         self.ord[name] = set(after)
#         for postreq in before:
#             self.ord.setdefault(postreq, set()).add(name)

#     def get_opts(self):
#         if not self.up_to_date:
#             self.refresh()
#         return self._opts

#     def refresh(self):
#         self._opts = [self.name_to_opt[name] for name in utils.toposort(self.ord)]
#         self.up_to_date = True

#     def apply(self, env):
#         for opt in self.opts:
#             opt.apply(env)

#     opts = property(get_opts)


# class TaggedMultiOptimizer(MultiOptimizer):

#     def __init__(self, **opts):
#         self.tags = {}
#         MultiOptimizer.__init__(self, **opts)

#     def register(self, name, opt, tags = [], **relative):
#         tags = set(tags)
#         tags.add(name)
#         self.tags[opt] = tags
#         MultiOptimizer.register(self, name, opt, **relative)

#     def filter(self, whitelist, blacklist):
#         return [opt for opt in self.opts
#                 if self.tags[opt].intersection(whitelist)
#                 and not self.tags[opt].intersection(blacklist)]

#     def whitelist(self, *tags):
#         return [opt for opt in self.opts if self.tags[opt].intersection(tags)]

#     def blacklist(self, *tags):
#         return [opt for opt in self.opts if not self.tags[opt].intersection(tags)]


# class TagFilterMultiOptimizer(Optimizer):

#     def __init__(self, all, whitelist = None, blacklist = None):
#         self.all = all
#         if whitelist is not None:
#             self.whitelist = set(whitelist)
#         else:
#             self.whitelist = None
#         if blacklist is not None:
#             self.blacklist = set(blacklist)
#         else:
#             self.blacklist = set()

#     def use_whitelist(self, use = True):
#         if self.whitelist is None and use:
#             self.whitelist = set()

#     def allow(self, *tags):
#         if self.whitelist is not None:
#             self.whitelist.update(tags)
#         self.blacklist.difference_update(tags)

#     def deny(self, *tags):
#         if self.whitelist is not None:
#             self.whitelist.difference_update(tags)
#         self.blacklist.update(tags)

#     def dont_care(self, *tags):
#         if self.whitelist is not None:
#             self.whitelist.difference_update(tags)
#         self.blacklist.difference_update(tags)

#     def opts(self):
#         if self.whitelist is not None:
#             return self.all.filter(self.whitelist, self.blacklist)
#         else:
#             return self.all.blacklist(*[tag for tag in self.blacklist])

#     def apply(self, env):
#         for opt in self.opts():
#             opt.apply(env)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论