提交 e94955a7 authored 作者: Frederic's avatar Frederic

pep8 fix.

上级 43cb38b8
import sys, StringIO import StringIO
import sys
if sys.version_info[:2] >= (2,5): if sys.version_info[:2] >= (2, 5):
from collections import defaultdict from collections import defaultdict
else: else:
from python25 import defaultdict from python25 import defaultdict
import numpy import numpy
import opt import opt
from theano.configparser import TheanoConfigParser, AddConfigVar, FloatParam from theano.configparser import AddConfigVar, FloatParam
from theano import config from theano import config
AddConfigVar('optdb.position_cutoff', AddConfigVar('optdb.position_cutoff',
'Where to stop eariler during optimization. It represent the position of the optimizer where to stop.', 'Where to stop eariler during optimization. It represent the'
' position of the optimizer where to stop.',
FloatParam(numpy.inf), FloatParam(numpy.inf),
in_c_key=False) in_c_key=False)
#upgraded to 20 to avoid EquibriumOptimizer error #upgraded to 20 to avoid EquibriumOptimizer error
...@@ -22,6 +24,7 @@ AddConfigVar('optdb.max_use_ratio', ...@@ -22,6 +24,7 @@ AddConfigVar('optdb.max_use_ratio',
FloatParam(20), FloatParam(20),
in_c_key=False) in_c_key=False)
class DB(object): class DB(object):
def __hash__(self): def __hash__(self):
if not hasattr(self, '_optimizer_idx'): if not hasattr(self, '_optimizer_idx'):
...@@ -32,7 +35,7 @@ class DB(object): ...@@ -32,7 +35,7 @@ class DB(object):
def __init__(self): def __init__(self):
self.__db__ = defaultdict(set) self.__db__ = defaultdict(set)
self._names = set() self._names = set()
self.name = None #will be reset by register self.name = None # will be reset by register
#(via obj.name by the thing doing the registering) #(via obj.name by the thing doing the registering)
def register(self, name, obj, *tags): def register(self, name, obj, *tags):
...@@ -42,13 +45,15 @@ class DB(object): ...@@ -42,13 +45,15 @@ class DB(object):
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)): if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise TypeError('Object cannot be registered in OptDB', obj) raise TypeError('Object cannot be registered in OptDB', obj)
if name in self.__db__: if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name) raise ValueError('The name of the object cannot be an existing'
' tag or the name of another existing object.',
obj, name)
# This restriction is there because in many place we suppose that # This restriction is there because in many place we suppose that
# something in the DB is there only once. # something in the DB is there only once.
if getattr(obj, 'name', "") in self.__db__: if getattr(obj, 'name', "") in self.__db__:
raise ValueError('''You can\'t register the same optimization raise ValueError('''You can\'t register the same optimization
multiple time in a DB. Tryed to register "%s" again under the new name "%s". multiple time in a DB. Tryed to register "%s" again under the new name "%s".
Use theano.gof.ProxyDB to work around that'''%(obj.name, name)) Use theano.gof.ProxyDB to work around that''' % (obj.name, name))
if self.name is not None: if self.name is not None:
tags = tags + (self.name,) tags = tags + (self.name,)
...@@ -60,11 +65,12 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s". ...@@ -60,11 +65,12 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def add_tags(self, name, *tags): def add_tags(self, name, *tags):
obj = self.__db__[name] obj = self.__db__[name]
assert len(obj)==1 assert len(obj) == 1
obj = obj.copy().pop() obj = obj.copy().pop()
for tag in tags: for tag in tags:
if tag in self._names: if tag in self._names:
raise ValueError('The tag of the object collides with a name.', obj, tag) raise ValueError('The tag of the object collides with a name.',
obj, tag)
self.__db__[tag].add(obj) self.__db__[tag].add(obj)
def __query__(self, q): def __query__(self, q):
...@@ -94,36 +100,41 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s". ...@@ -94,36 +100,41 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
if len(tags) >= 1 and isinstance(tags[0], Query): if len(tags) >= 1 and isinstance(tags[0], Query):
if len(tags) > 1 or kwtags: if len(tags) > 1 or kwtags:
raise TypeError('If the first argument to query is a Query, there should be no other arguments.', tags, kwtags) raise TypeError('If the first argument to query is a Query,'
' there should be no other arguments.',
tags, kwtags)
return self.__query__(tags[0]) return self.__query__(tags[0])
include = [tag[1:] for tag in tags if tag.startswith('+')] include = [tag[1:] for tag in tags if tag.startswith('+')]
require = [tag[1:] for tag in tags if tag.startswith('&')] require = [tag[1:] for tag in tags if tag.startswith('&')]
exclude = [tag[1:] for tag in tags if tag.startswith('-')] exclude = [tag[1:] for tag in tags if tag.startswith('-')]
if len(include) + len(require) + len(exclude) < len(tags): if len(include) + len(require) + len(exclude) < len(tags):
raise ValueError("All tags must start with one of the following characters: '+', '&' or '-'", tags) raise ValueError("All tags must start with one of the following"
return self.__query__(Query(include = include, " characters: '+', '&' or '-'", tags)
require = require, return self.__query__(Query(include=include,
exclude = exclude, require=require,
subquery = kwtags)) exclude=exclude,
subquery=kwtags))
def __getitem__(self, name): def __getitem__(self, name):
variables = self.__db__[name] variables = self.__db__[name]
if not variables: if not variables:
raise KeyError("Nothing registered for '%s'" % name) raise KeyError("Nothing registered for '%s'" % name)
elif len(variables) > 1: elif len(variables) > 1:
raise ValueError('More than one match for %s (please use query)' % name) raise ValueError('More than one match for %s (please use query)' %
name)
for variable in variables: for variable in variables:
return variable return variable
def print_summary(self, stream=sys.stdout): def print_summary(self, stream=sys.stdout):
print >> stream, "%s (id %i)"%(self.__class__.__name__, id(self)) print >> stream, "%s (id %i)" % (self.__class__.__name__, id(self))
print >> stream, " names", self._names print >> stream, " names", self._names
print >> stream, " db", self.__db__ print >> stream, " db", self.__db__
class Query(object): class Query(object):
def __init__(self, include, require = None, exclude = None, subquery = None, position_cutoff = None): def __init__(self, include, require=None, exclude=None,
subquery=None, position_cutoff=None):
""" """
:type position_cutoff: float :type position_cutoff: float
:param position_cutoff: Used by SequenceDB to keep only optimizer that :param position_cutoff: Used by SequenceDB to keep only optimizer that
...@@ -142,6 +153,7 @@ class Query(object): ...@@ -142,6 +153,7 @@ class Query(object):
self.exclude, self.exclude,
self.subquery, self.subquery,
self.position_cutoff) self.position_cutoff)
#remove all opt with this tag #remove all opt with this tag
def excluding(self, *tags): def excluding(self, *tags):
return Query(self.include, return Query(self.include,
...@@ -149,6 +161,7 @@ class Query(object): ...@@ -149,6 +161,7 @@ class Query(object):
self.exclude.union(tags), self.exclude.union(tags),
self.subquery, self.subquery,
self.position_cutoff) self.position_cutoff)
#keep only opt with this tag. #keep only opt with this tag.
def requiring(self, *tags): def requiring(self, *tags):
return Query(self.include, return Query(self.include,
...@@ -158,17 +171,16 @@ class Query(object): ...@@ -158,17 +171,16 @@ class Query(object):
self.position_cutoff) self.position_cutoff)
class EquilibriumDB(DB): class EquilibriumDB(DB):
"""A set of potential optimizations which should be applied in an arbitrary order until """ A set of potential optimizations which should be applied in an
equilibrium is reached. arbitrary order until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations. Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
.. note:: .. note::
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer suppor both. We can put LocalOptimizer and Optimizer as EquilibriumOptimizer
suppor both.
""" """
...@@ -186,15 +198,15 @@ class SequenceDB(DB): ...@@ -186,15 +198,15 @@ class SequenceDB(DB):
Retrieve a sequence of optimizations (a SeqOptimizer) by calling query(). Retrieve a sequence of optimizations (a SeqOptimizer) by calling query().
Each potential optimization is registered with a floating-point position. Each potential optimization is registered with a floating-point position.
No matter which optimizations are selected by a query, they are carried out in order of No matter which optimizations are selected by a query, they are carried
increasing position. out in order of increasing position.
The optdb itself (`theano.compile.mode.optdb`), from which (among many other tags) fast_run The optdb itself (`theano.compile.mode.optdb`), from which (among many
and fast_compile optimizers are drawn is a SequenceDB. other tags) fast_run and fast_compile optimizers are drawn is a SequenceDB.
""" """
def __init__(self, failure_callback = opt.SeqOptimizer.warn): def __init__(self, failure_callback=opt.SeqOptimizer.warn):
super(SequenceDB, self).__init__() super(SequenceDB, self).__init__()
self.__position__ = {} self.__position__ = {}
self.failure_callback = failure_callback self.failure_callback = failure_callback
...@@ -206,26 +218,29 @@ class SequenceDB(DB): ...@@ -206,26 +218,29 @@ class SequenceDB(DB):
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
""" """
:type position_cutoff: float or int :type position_cutoff: float or int
:param position_cutoff: only optimizations with position less than the cutoff are returned. :param position_cutoff: only optimizations with position less than
the cutoff are returned.
""" """
opts = super(SequenceDB, self).query(*tags, **kwtags) opts = super(SequenceDB, self).query(*tags, **kwtags)
position_cutoff = kwtags.pop('position_cutoff', config.optdb.position_cutoff) position_cutoff = kwtags.pop('position_cutoff',
if len(tags)>=1 and isinstance(tags[0],Query): config.optdb.position_cutoff)
if len(tags) >= 1 and isinstance(tags[0], Query):
#the call to super should have raise an error with a good message #the call to super should have raise an error with a good message
assert len(tags)==1 assert len(tags) == 1
if getattr(tags[0],'position_cutoff', None): if getattr(tags[0], 'position_cutoff', None):
position_cutoff = tags[0].position_cutoff position_cutoff = tags[0].position_cutoff
opts = [o for o in opts if self.__position__[o.name] < position_cutoff] opts = [o for o in opts if self.__position__[o.name] < position_cutoff]
opts.sort(key = lambda obj: self.__position__[obj.name]) opts.sort(key=lambda obj: self.__position__[obj.name])
return opt.SeqOptimizer(opts, failure_callback = self.failure_callback) return opt.SeqOptimizer(opts, failure_callback=self.failure_callback)
def print_summary(self, stream=sys.stdout): def print_summary(self, stream=sys.stdout):
print >> stream, "SequenceDB (id %i)"%id(self) print >> stream, "SequenceDB (id %i)" % id(self)
positions = self.__position__.items() positions = self.__position__.items()
def c(a,b):
return cmp(a[1],b[1]) def c(a, b):
return cmp(a[1], b[1])
positions.sort(c) positions.sort(c)
print >> stream, " position", positions print >> stream, " position", positions
...@@ -240,8 +255,10 @@ class SequenceDB(DB): ...@@ -240,8 +255,10 @@ class SequenceDB(DB):
class ProxyDB(DB): class ProxyDB(DB):
""" """
This is needed as we can't register the same DB mutiple time in different position Wrap an existing proxy.
in a SequentialDB
This is needed as we can't register the same DB mutiple time in
different position in a SequentialDB
""" """
def __init__(self, db): def __init__(self, db):
assert isinstance(db, DB), "" assert isinstance(db, DB), ""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论