提交 ff91a554 authored 作者: abergeron's avatar abergeron

Merge pull request #2074 from nouiz/opt_order

Make optimization more deterministic
import sys import sys
from theano.gof.python25 import DefaultOrderedDict
import numpy import numpy
from theano.gof.python25 import DefaultOrderedDict
from theano.misc.ordered_set import OrderedSet
from theano.compat.six import StringIO from theano.compat.six import StringIO
from theano.gof import opt from theano.gof import opt
from theano.configparser import AddConfigVar, FloatParam from theano.configparser import AddConfigVar, FloatParam
...@@ -26,7 +27,7 @@ class DB(object): ...@@ -26,7 +27,7 @@ class DB(object):
return self._optimizer_idx return self._optimizer_idx
def __init__(self): def __init__(self):
self.__db__ = DefaultOrderedDict(set) self.__db__ = DefaultOrderedDict(OrderedSet)
self._names = set() self._names = set()
self.name = None # will be reset by register self.name = None # will be reset by register
#(via obj.name by the thing doing the registering) #(via obj.name by the thing doing the registering)
...@@ -51,7 +52,7 @@ class DB(object): ...@@ -51,7 +52,7 @@ class DB(object):
raise ValueError('''You can\'t register the same optimization raise ValueError('''You can\'t register the same optimization
multiple time in a DB. Tryed to register "%s" again under the new name "%s". multiple time in a DB. Tryed to register "%s" again under the new name "%s".
Use theano.gof.ProxyDB to work around that''' % (obj.name, name)) Use theano.gof.ProxyDB to work around that''' % (obj.name, name))
self.__db__[name] = set([obj]) self.__db__[name] = OrderedSet([obj])
self._names.add(name) self._names.add(name)
self.__db__[obj.__class__.__name__].add(obj) self.__db__[obj.__class__.__name__].add(obj)
self.add_tags(name, *tags) self.add_tags(name, *tags)
...@@ -79,15 +80,16 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s". ...@@ -79,15 +80,16 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
def __query__(self, q): def __query__(self, q):
if not isinstance(q, Query): if not isinstance(q, Query):
raise TypeError('Expected a Query.', q) raise TypeError('Expected a Query.', q)
variables = set() # The ordered set is needed for deterministic optimization.
variables = OrderedSet()
for tag in q.include: for tag in q.include:
variables.update(self.__db__[tag]) variables.update(self.__db__[tag])
for tag in q.require: for tag in q.require:
variables.intersection_update(self.__db__[tag]) variables.intersection_update(self.__db__[tag])
for tag in q.exclude: for tag in q.exclude:
variables.difference_update(self.__db__[tag]) variables.difference_update(self.__db__[tag])
remove = set() remove = OrderedSet()
add = set() add = OrderedSet()
for obj in variables: for obj in variables:
if isinstance(obj, DB): if isinstance(obj, DB):
sq = q.subquery.get(obj.name, q) sq = q.subquery.get(obj.name, q)
...@@ -143,15 +145,15 @@ class Query(object): ...@@ -143,15 +145,15 @@ class Query(object):
:param position_cutoff: Used by SequenceDB to keep only optimizer that :param position_cutoff: Used by SequenceDB to keep only optimizer that
are positioned before the cut_off point. are positioned before the cut_off point.
""" """
self.include = set(include) self.include = OrderedSet(include)
self.require = require or set() self.require = require or OrderedSet()
self.exclude = exclude or set() self.exclude = exclude or OrderedSet()
self.subquery = subquery or {} self.subquery = subquery or {}
self.position_cutoff = position_cutoff self.position_cutoff = position_cutoff
if isinstance(self.require, (list, tuple)): if isinstance(self.require, (list, tuple)):
self.require = set(self.require) self.require = OrderedSet(self.require)
if isinstance(self.exclude, (list, tuple)): if isinstance(self.exclude, (list, tuple)):
self.exclude = set(self.exclude) self.exclude = OrderedSet(self.exclude)
#add all opt with this tag #add all opt with this tag
def including(self, *tags): def including(self, *tags):
......
...@@ -7,6 +7,7 @@ except ImportError: ...@@ -7,6 +7,7 @@ except ImportError:
from theano.gof.python25 import OrderedDict from theano.gof.python25 import OrderedDict
import types import types
def check_deterministic(iterable): def check_deterministic(iterable):
# Most places where OrderedSet is used, theano interprets any exception # Most places where OrderedSet is used, theano interprets any exception
# whatsoever as a problem that an optimization introduced into the graph. # whatsoever as a problem that an optimization introduced into the graph.
...@@ -40,11 +41,28 @@ if MutableSet is not None: ...@@ -40,11 +41,28 @@ if MutableSet is not None:
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
## {{{ http://code.activestate.com/recipes/576696/ (r5) ## {{{ http://code.activestate.com/recipes/576696/ (r5)
import collections import collections
from weakref import proxy import weakref
class Link(object):
    """One node of the doubly linked list used by OrderedSet.

    ``prev`` and ``next`` store ``weakref.ref`` objects that point at the
    neighbouring links; call them to obtain the actual link.  ``key`` is
    the element carried by the link and may be left unset (the sentinel
    root created by OrderedSet never assigns it).
    """
    __slots__ = 'prev', 'next', 'key', '__weakref__'

    def __getstate__(self):
        """Return the picklable state: ``[prev, next]`` or ``[prev, next, key]``.

        weakref.proxy does not pickle well, so weakref.ref is used
        manually and the weak references themselves are never pickled:
        the neighbours are dereferenced here and wrapped again in
        ``__setstate__``.
        """
        state = [self.prev(), self.next()]
        # ``key`` is a slot, so it is simply absent until it is assigned.
        if hasattr(self, 'key'):
            state.append(self.key)
        return state

    def __setstate__(self, state):
        """Rebuild the link from a state list produced by ``__getstate__``."""
        # Restore the weak references that were dropped when pickling.
        self.prev = weakref.ref(state[0])
        self.next = weakref.ref(state[1])
        carries_key = len(state) == 3
        if carries_key:
            self.key = state[2]
class OrderedSet(collections.MutableSet): class OrderedSet(collections.MutableSet):
'Set the remembers the order elements were added' 'Set the remembers the order elements were added'
# Big-O running times for all methods are the same as for regular sets. # Big-O running times for all methods are the same as for regular sets.
...@@ -65,7 +83,7 @@ if MutableSet is not None: ...@@ -65,7 +83,7 @@ if MutableSet is not None:
# Checks added by IG # Checks added by IG
check_deterministic(iterable) check_deterministic(iterable)
self.__root = root = Link() # sentinel node for doubly linked list self.__root = root = Link() # sentinel node for doubly linked list
root.prev = root.next = root root.prev = root.next = weakref.ref(root)
self.__map = {} # key --> link self.__map = {} # key --> link
if iterable is not None: if iterable is not None:
self |= iterable self |= iterable
...@@ -82,32 +100,61 @@ if MutableSet is not None: ...@@ -82,32 +100,61 @@ if MutableSet is not None:
self.__map[key] = link = Link() self.__map[key] = link = Link()
root = self.__root root = self.__root
last = root.prev last = root.prev
link.prev, link.next, link.key = last, root, key link.prev, link.next, link.key = last, weakref.ref(root), key
last.next = root.prev = proxy(link) last().next = root.prev = weakref.ref(link)
def union(self, s):
    """Return a new set with this set's elements plus those of ``s``.

    ``s`` is first passed to ``check_deterministic`` (defined elsewhere
    in this module).  The result starts as a copy of this set; elements
    of ``s`` that the copy does not contain yet are then added in the
    order ``s`` yields them.  Neither ``self`` nor ``s`` is modified.
    """
    check_deterministic(s)
    merged = self.copy()
    for candidate in s:
        if candidate in merged:
            # Already present: nothing to add.
            continue
        merged.add(candidate)
    return merged
def intersection_update(self, s):
    """Remove, in place, every element that is not contained in ``s``.

    ``s`` is only queried with ``in``; it is never iterated here.
    Returns ``self``.
    """
    # Decide what to drop before removing anything, so the set is never
    # mutated while it is being iterated.
    stale = [item for item in self if item not in s]
    for item in stale:
        self.remove(item)
    return self
def difference_update(self, s):
    """Remove, in place, every element of ``s`` that this set contains.

    ``s`` is first passed to ``check_deterministic`` (defined elsewhere
    in this module) and is then iterated once.  Returns ``self``.
    """
    check_deterministic(s)
    for candidate in s:
        if candidate not in self:
            # Not one of ours: nothing to remove.
            continue
        self.remove(candidate)
    return self
def copy(self):
    """Return a new OrderedSet filled from this set via ``update``.

    The result is always a plain ``OrderedSet``; the elements themselves
    are not copied.
    """
    duplicate = OrderedSet()
    duplicate.update(self)
    return duplicate
def discard(self, key): def discard(self, key):
# Remove an existing item using self.__map to find the link which is # Remove an existing item using self.__map to find the link which is
# then removed by updating the links in the predecessor and successors. # then removed by updating the links in the predecessor and successors.
if key in self.__map: if key in self.__map:
link = self.__map.pop(key) link = self.__map.pop(key)
link.prev.next = link.next link.prev().next = link.next
link.next.prev = link.prev link.next().prev = link.prev
def __iter__(self): def __iter__(self):
# Traverse the linked list in order. # Traverse the linked list in order.
root = self.__root root = self.__root
curr = root.next curr = root.next()
while curr is not root: while curr is not root:
yield curr.key yield curr.key
curr = curr.next curr = curr.next()
def __reversed__(self): def __reversed__(self):
# Traverse the linked list in reverse order. # Traverse the linked list in reverse order.
root = self.__root root = self.__root
curr = root.prev curr = root.prev()
while curr is not root: while curr is not root:
yield curr.key yield curr.key
curr = curr.prev curr = curr.prev()
def pop(self, last=True): def pop(self, last=True):
if not self: if not self:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论