提交 e94955a7 authored 作者: Frederic's avatar Frederic

pep8 fix.

上级 43cb38b8
import sys, StringIO
import StringIO
import sys
if sys.version_info[:2] >= (2,5):
if sys.version_info[:2] >= (2, 5):
from collections import defaultdict
else:
from python25 import defaultdict
import numpy
import opt
from theano.configparser import TheanoConfigParser, AddConfigVar, FloatParam
from theano.configparser import AddConfigVar, FloatParam
from theano import config
AddConfigVar('optdb.position_cutoff',
'Where to stop eariler during optimization. It represent the position of the optimizer where to stop.',
'Where to stop eariler during optimization. It represent the'
' position of the optimizer where to stop.',
FloatParam(numpy.inf),
in_c_key=False)
#upgraded to 20 to avoid EquibriumOptimizer error
......@@ -22,6 +24,7 @@ AddConfigVar('optdb.max_use_ratio',
FloatParam(20),
in_c_key=False)
class DB(object):
def __hash__(self):
if not hasattr(self, '_optimizer_idx'):
......@@ -32,7 +35,7 @@ class DB(object):
def __init__(self):
self.__db__ = defaultdict(set)
self._names = set()
self.name = None #will be reset by register
self.name = None # will be reset by register
#(via obj.name by the thing doing the registering)
def register(self, name, obj, *tags):
......@@ -42,13 +45,15 @@ class DB(object):
if not isinstance(obj, (DB, opt.Optimizer, opt.LocalOptimizer)):
raise TypeError('Object cannot be registered in OptDB', obj)
if name in self.__db__:
raise ValueError('The name of the object cannot be an existing tag or the name of another existing object.', obj, name)
raise ValueError('The name of the object cannot be an existing'
' tag or the name of another existing object.',
obj, name)
# This restriction is there because in many place we suppose that
# something in the DB is there only once.
if getattr(obj, 'name', "") in self.__db__:
raise ValueError('''You can\'t register the same optimization
multiple time in a DB. Tryed to register "%s" again under the new name "%s".
Use theano.gof.ProxyDB to work around that'''%(obj.name, name))
Use theano.gof.ProxyDB to work around that''' % (obj.name, name))
if self.name is not None:
tags = tags + (self.name,)
......@@ -60,11 +65,12 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def add_tags(self, name, *tags):
obj = self.__db__[name]
assert len(obj)==1
assert len(obj) == 1
obj = obj.copy().pop()
for tag in tags:
if tag in self._names:
raise ValueError('The tag of the object collides with a name.', obj, tag)
raise ValueError('The tag of the object collides with a name.',
obj, tag)
self.__db__[tag].add(obj)
def __query__(self, q):
......@@ -94,36 +100,41 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def query(self, *tags, **kwtags):
if len(tags) >= 1 and isinstance(tags[0], Query):
if len(tags) > 1 or kwtags:
raise TypeError('If the first argument to query is a Query, there should be no other arguments.', tags, kwtags)
raise TypeError('If the first argument to query is a Query,'
' there should be no other arguments.',
tags, kwtags)
return self.__query__(tags[0])
include = [tag[1:] for tag in tags if tag.startswith('+')]
require = [tag[1:] for tag in tags if tag.startswith('&')]
exclude = [tag[1:] for tag in tags if tag.startswith('-')]
if len(include) + len(require) + len(exclude) < len(tags):
raise ValueError("All tags must start with one of the following characters: '+', '&' or '-'", tags)
return self.__query__(Query(include = include,
require = require,
exclude = exclude,
subquery = kwtags))
raise ValueError("All tags must start with one of the following"
" characters: '+', '&' or '-'", tags)
return self.__query__(Query(include=include,
require=require,
exclude=exclude,
subquery=kwtags))
def __getitem__(self, name):
variables = self.__db__[name]
if not variables:
raise KeyError("Nothing registered for '%s'" % name)
elif len(variables) > 1:
raise ValueError('More than one match for %s (please use query)' % name)
raise ValueError('More than one match for %s (please use query)' %
name)
for variable in variables:
return variable
def print_summary(self, stream=sys.stdout):
print >> stream, "%s (id %i)"%(self.__class__.__name__, id(self))
print >> stream, "%s (id %i)" % (self.__class__.__name__, id(self))
print >> stream, " names", self._names
print >> stream, " db", self.__db__
class Query(object):
def __init__(self, include, require = None, exclude = None, subquery = None, position_cutoff = None):
def __init__(self, include, require=None, exclude=None,
subquery=None, position_cutoff=None):
"""
:type position_cutoff: float
:param position_cutoff: Used by SequenceDB to keep only optimizer that
......@@ -142,6 +153,7 @@ class Query(object):
self.exclude,
self.subquery,
self.position_cutoff)
#remove all opt with this tag
def excluding(self, *tags):
return Query(self.include,
......@@ -149,6 +161,7 @@ class Query(object):
self.exclude.union(tags),
self.subquery,
self.position_cutoff)
#keep only opt with this tag.
def requiring(self, *tags):
return Query(self.include,
......@@ -158,17 +171,16 @@ class Query(object):
self.position_cutoff)
class EquilibriumDB(DB):
"""A set of potential optimizations which should be applied in an arbitrary order until
equilibrium is reached.
""" A set of potential optimizations which should be applied in an
arbitrary order until equilibrium is reached.
Canonicalize, Stabilize, and Specialize are all equilibrium optimizations.
.. note::
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer suppor both.
We can put LocalOptimizer and Optimizer as EquilibriumOptimizer
suppor both.
"""
......@@ -186,15 +198,15 @@ class SequenceDB(DB):
Retrieve a sequence of optimizations (a SeqOptimizer) by calling query().
Each potential optimization is registered with a floating-point position.
No matter which optimizations are selected by a query, they are carried out in order of
increasing position.
No matter which optimizations are selected by a query, they are carried
out in order of increasing position.
The optdb itself (`theano.compile.mode.optdb`), from which (among many other tags) fast_run
and fast_compile optimizers are drawn is a SequenceDB.
The optdb itself (`theano.compile.mode.optdb`), from which (among many
other tags) fast_run and fast_compile optimizers are drawn is a SequenceDB.
"""
def __init__(self, failure_callback = opt.SeqOptimizer.warn):
def __init__(self, failure_callback=opt.SeqOptimizer.warn):
super(SequenceDB, self).__init__()
self.__position__ = {}
self.failure_callback = failure_callback
......@@ -206,26 +218,29 @@ class SequenceDB(DB):
def query(self, *tags, **kwtags):
"""
:type position_cutoff: float or int
:param position_cutoff: only optimizations with position less than the cutoff are returned.
:param position_cutoff: only optimizations with position less than
the cutoff are returned.
"""
opts = super(SequenceDB, self).query(*tags, **kwtags)
position_cutoff = kwtags.pop('position_cutoff', config.optdb.position_cutoff)
if len(tags)>=1 and isinstance(tags[0],Query):
position_cutoff = kwtags.pop('position_cutoff',
config.optdb.position_cutoff)
if len(tags) >= 1 and isinstance(tags[0], Query):
#the call to super should have raise an error with a good message
assert len(tags)==1
if getattr(tags[0],'position_cutoff', None):
assert len(tags) == 1
if getattr(tags[0], 'position_cutoff', None):
position_cutoff = tags[0].position_cutoff
opts = [o for o in opts if self.__position__[o.name] < position_cutoff]
opts.sort(key = lambda obj: self.__position__[obj.name])
return opt.SeqOptimizer(opts, failure_callback = self.failure_callback)
opts.sort(key=lambda obj: self.__position__[obj.name])
return opt.SeqOptimizer(opts, failure_callback=self.failure_callback)
def print_summary(self, stream=sys.stdout):
print >> stream, "SequenceDB (id %i)"%id(self)
print >> stream, "SequenceDB (id %i)" % id(self)
positions = self.__position__.items()
def c(a,b):
return cmp(a[1],b[1])
def c(a, b):
return cmp(a[1], b[1])
positions.sort(c)
print >> stream, " position", positions
......@@ -240,8 +255,10 @@ class SequenceDB(DB):
class ProxyDB(DB):
"""
This is needed as we can't register the same DB mutiple time in different position
in a SequentialDB
Wrap an existing proxy.
This is needed as we can't register the same DB mutiple time in
different position in a SequentialDB
"""
def __init__(self, db):
assert isinstance(db, DB), ""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论