提交 cb67df1c authored 作者: --global's avatar --global

Add register() method to compilation mode

上级 952b05e5
...@@ -328,6 +328,32 @@ class Mode(object): ...@@ -328,6 +328,32 @@ class Mode(object):
# string? Optimizer? OptDB? who knows??? # string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags)) return self.__class__(linker=link, optimizer=opt.including(*tags))
def register(self, *optimizations):
"""Adds new optimization instances to a mode.
This method adds new optimization instances to a compilation mode. It
works like the `including()` method but takes as inputs optimization
instances to add instead of tags.
Parameters
----------
optimizations :
Every element of `optimizations` is a tuple containing an
optimization instance and a floating point value indicating the
position at which to insert the optimization in the mode.
Returns
-------
Mode
Copy of the current Mode which includes the provided
optimizations.
"""
link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer)
return self.__class__(linker=link,
optimizer=opt.register(*optimizations))
def excluding(self, *tags): def excluding(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, link, opt = self.get_linker_optimizer(self.provided_linker,
self.provided_optimizer) self.provided_optimizer)
......
...@@ -173,12 +173,13 @@ class Query(object): ...@@ -173,12 +173,13 @@ class Query(object):
""" """
def __init__(self, include, require=None, exclude=None, def __init__(self, include, require=None, exclude=None,
subquery=None, position_cutoff=None): subquery=None, position_cutoff=None, extra_optimizations=[]):
self.include = OrderedSet(include) self.include = OrderedSet(include)
self.require = require or OrderedSet() self.require = require or OrderedSet()
self.exclude = exclude or OrderedSet() self.exclude = exclude or OrderedSet()
self.subquery = subquery or {} self.subquery = subquery or {}
self.position_cutoff = position_cutoff self.position_cutoff = position_cutoff
self.extra_optimizations = extra_optimizations
if isinstance(self.require, (list, tuple)): if isinstance(self.require, (list, tuple)):
self.require = OrderedSet(self.require) self.require = OrderedSet(self.require)
if isinstance(self.exclude, (list, tuple)): if isinstance(self.exclude, (list, tuple)):
...@@ -186,9 +187,9 @@ class Query(object): ...@@ -186,9 +187,9 @@ class Query(object):
def __str__(self): def __str__(self):
return ("Query{inc=%s,ex=%s,require=%s,subquery=%s," return ("Query{inc=%s,ex=%s,require=%s,subquery=%s,"
"position_cutoff=%d}" % "position_cutoff=%d,extra_opts=%d}" %
(self.include, self.exclude, self.require, self.subquery, (self.include, self.exclude, self.require, self.subquery,
self.position_cutoff)) self.position_cutoff, self.extra_optimizations))
# add all opt with this tag # add all opt with this tag
def including(self, *tags): def including(self, *tags):
...@@ -196,7 +197,8 @@ class Query(object): ...@@ -196,7 +197,8 @@ class Query(object):
self.require, self.require,
self.exclude, self.exclude,
self.subquery, self.subquery,
self.position_cutoff) self.position_cutoff,
self.extra_optimizations)
# remove all opt with this tag # remove all opt with this tag
def excluding(self, *tags): def excluding(self, *tags):
...@@ -204,7 +206,8 @@ class Query(object): ...@@ -204,7 +206,8 @@ class Query(object):
self.require, self.require,
self.exclude.union(tags), self.exclude.union(tags),
self.subquery, self.subquery,
self.position_cutoff) self.position_cutoff,
self.extra_optimizations)
# keep only opt with this tag. # keep only opt with this tag.
def requiring(self, *tags): def requiring(self, *tags):
...@@ -212,7 +215,16 @@ class Query(object): ...@@ -212,7 +215,16 @@ class Query(object):
self.require.union(tags), self.require.union(tags),
self.exclude, self.exclude,
self.subquery, self.subquery,
self.position_cutoff) self.position_cutoff,
self.extra_optimizations)
def register(self, *optimizations):
return Query(self.include,
self.require,
self.exclude,
self.subquery,
self.position_cutoff,
self.extra_optimizations + list(optimizations))
class EquilibriumDB(DB): class EquilibriumDB(DB):
...@@ -312,6 +324,17 @@ class SequenceDB(DB): ...@@ -312,6 +324,17 @@ class SequenceDB(DB):
if getattr(tags[0], 'position_cutoff', None): if getattr(tags[0], 'position_cutoff', None):
position_cutoff = tags[0].position_cutoff position_cutoff = tags[0].position_cutoff
# The Query instance might contain extra optimizations which need
# to be added the the sequence of optimizations
for extra_opt in tags[0].extra_optimizations:
# Give a name to the extra optimization
opt, position = extra_opt
opt.name = str(opt.__class__)
# Add the extra optimization to the optimization sequence
opts.add(opt)
self.__position__[opt.name] = position
opts = [o for o in opts if self.__position__[o.name] < position_cutoff] opts = [o for o in opts if self.__position__[o.name] < position_cutoff]
# We want to sort by position and then if collision by name # We want to sort by position and then if collision by name
# for deterministic optimization. Since Python 2.2, sort is # for deterministic optimization. Since Python 2.2, sort is
......
...@@ -814,12 +814,14 @@ class Scan(PureOp): ...@@ -814,12 +814,14 @@ class Scan(PureOp):
# Add an optimization to the compilation mode to prevent mitsot, # Add an optimization to the compilation mode to prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow # sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation) # their preallocation). This optimization is added such that it
# will run just before the inplace optimizations
mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs) mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs)
nitsot_end = (mitsot_start + self.n_mit_sot + self.n_sit_sot + nitsot_end = (mitsot_start + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
no_inplace_opt = AddNoOutputFromInplace(0, 4) no_inplace_opt = AddNoOutputFromInplace(mitsot_start, nitsot_end)
compilation_mode = self.mode_instance.register(no_inplace_opt) compilation_mode = self.mode_instance.register((no_inplace_opt,
0.599))
else: else:
# Output preallocation is not activated. Mark every mitmot output # Output preallocation is not activated. Mark every mitmot output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论