提交 96c4b88d authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed bug in deque implementation for Python 2.4

The 'remove' method was bugged (it would not remove an item if it was the last item in the queue). This commit actually replaces the full custom deque implementation by a subclass of the original Python 2.4 implementation of deque, as it seems safer than re-writing everything. The implementation is also moved to the python25.py file since it is the place for Python 2.4 compatibility code.
上级 e5ffef54
import sys import sys
if sys.version_info[:2] >= (2,5):
from collections import deque
else:
class deque(object):
def __init__(self, iterable=(), maxsize=-1):
if not hasattr(self, 'data'):
self.left = self.right = 0
self.data = {}
self.maxsize = maxsize
self.extend(iterable)
def append(self, x):
self.data[self.right] = x
self.right += 1
if self.maxsize != -1 and len(self) > self.maxsize:
self.popleft()
def remove(self, x):
if self.left == self.right:
raise ValueError('cannot remove from empty deque')
for i in xrange(self.left, self.right-1):
elem = self.data[i]
if elem==x:
self.__delitem__(i)
break
def appendleft(self, x):
self.left -= 1
self.data[self.left] = x
if self.maxsize != -1 and len(self) > self.maxsize:
self.pop()
def pop(self):
if self.left == self.right:
raise IndexError('cannot pop from empty deque')
self.right -= 1
elem = self.data[self.right]
del self.data[self.right]
return elem
def popleft(self):
if self.left == self.right:
raise IndexError('cannot pop from empty deque')
elem = self.data[self.left]
del self.data[self.left]
self.left += 1
return elem
def clear(self):
self.data.clear()
self.left = self.right = 0
def extend(self, iterable):
for elem in iterable:
self.append(elem)
def extendleft(self, iterable):
for elem in iterable:
self.appendleft(elem)
def rotate(self, n=1):
if self:
n %= len(self)
for i in xrange(n):
self.appendleft(self.pop())
def __getitem__(self, i):
if i < 0:
i += len(self)
try:
return self.data[i + self.left]
except KeyError:
raise IndexError
def __setitem__(self, i, value):
if i < 0:
i += len(self)
try:
self.data[i + self.left] = value
except KeyError:
raise IndexError
def __delitem__(self, i):
size = len(self)
data = self.data
if i < 0:
i += size
if not data.has_key(i):
raise IndexError
for j in xrange(self.left+i, self.right-1):
data[j] = data[j+1]
self.pop()
def __len__(self):
return self.right - self.left
def __cmp__(self, other):
if type(self) != type(other):
return cmp(type(self), type(other))
return cmp(list(self), list(other))
def __repr__(self, _track=[]):
if id(self) in _track:
return '...'
_track.append(id(self))
r = 'deque(%r)' % (list(self),)
_track.remove(id(self))
return r
def __getstate__(self):
return (tuple(self),)
def __setstate__(self, s):
self.__init__(s[0])
def __hash__(self):
raise TypeError
def __copy__(self):
return self.__class__(self)
def __deepcopy__(self, memo={}):
from copy import deepcopy
result = self.__class__()
memo[id(self)] = result
result.__init__(deepcopy(tuple(self), memo))
return result
from cc import \ from cc import \
CLinker, OpWiseCLinker, DualLinker CLinker, OpWiseCLinker, DualLinker
......
...@@ -8,7 +8,7 @@ if sys.version_info[:2] >= (2,5): ...@@ -8,7 +8,7 @@ if sys.version_info[:2] >= (2,5):
import theano import theano
import toolbox import toolbox
import graph import graph
from theano.gof import deque from theano.gof.python25 import deque
from env import InconsistencyError from env import InconsistencyError
......
...@@ -12,7 +12,8 @@ __docformat__ = "restructuredtext en" ...@@ -12,7 +12,8 @@ __docformat__ = "restructuredtext en"
from copy import copy from copy import copy
import theano import theano
from theano.gof import deque, utils from theano.gof import utils
from theano.gof.python25 import deque
# Lazy imports to avoid circular dependencies. # Lazy imports to avoid circular dependencies.
is_same_graph_with_merge = None is_same_graph_with_merge = None
......
...@@ -16,7 +16,7 @@ import toolbox ...@@ -16,7 +16,7 @@ import toolbox
import op import op
import theano import theano
from theano import config from theano import config
from theano.gof.python25 import any, all from theano.gof.python25 import any, all, deque
from theano.configparser import AddConfigVar, BoolParam, config from theano.configparser import AddConfigVar, BoolParam, config
#if sys.version_info[:2] >= (2,5): #if sys.version_info[:2] >= (2,5):
...@@ -29,7 +29,6 @@ AddConfigVar('time_seq_optimizer', ...@@ -29,7 +29,6 @@ AddConfigVar('time_seq_optimizer',
BoolParam(False), BoolParam(False),
in_c_key=False) in_c_key=False)
from theano.gof import deque
import destroyhandler as dh import destroyhandler as dh
import traceback import traceback
...@@ -989,8 +988,10 @@ class TopoOptimizer(NavigatorOptimizer): ...@@ -989,8 +988,10 @@ class TopoOptimizer(NavigatorOptimizer):
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
try: q.remove(node) try:
except ValueError: pass q.remove(node)
except ValueError:
pass
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner)
try: try:
...@@ -1027,8 +1028,10 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1027,8 +1028,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
if node.op == op: q.append(node) if node.op == op: q.append(node)
def pruner(node): def pruner(node):
if node is not current_node and node.op == op: if node is not current_node and node.op == op:
try: q.remove(node) try:
except ValueError: pass q.remove(node)
except ValueError:
pass
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner)
try: try:
while q: while q:
...@@ -1133,8 +1136,10 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1133,8 +1136,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
q.append(node) q.append(node)
def pruner(node): def pruner(node):
if node is not current_node: if node is not current_node:
try: q.remove(node) try:
except ValueError: pass q.remove(node)
except ValueError:
pass
u = self.attach_updater(env, importer, pruner) u = self.attach_updater(env, importer, pruner)
try: try:
...@@ -1153,7 +1158,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1153,7 +1158,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count[lopt] += 1 process_count[lopt] += 1
changed = True changed = True
if node not in env.nodes: if node not in env.nodes:
break# go to next node break # go to next node
finally: finally:
self.detach_updater(env, u) self.detach_updater(env, u)
self.detach_updater(env, u) #TODO: erase this line, it's redundant at best self.detach_updater(env, u) #TODO: erase this line, it's redundant at best
......
...@@ -2,18 +2,24 @@ ...@@ -2,18 +2,24 @@
Helper functions to make gof backwards compatible (tested on python 2.4 and 2.5) Helper functions to make gof backwards compatible (tested on python 2.4 and 2.5)
""" """
import collections
import sys import sys
if sys.version_info[:2] < (2,5): if sys.version_info[:2] < (2,5):
def all(iterable): def all(iterable):
for element in iterable: for element in iterable:
if not element: if not element:
return False return False
return True return True
def any(iterable): def any(iterable):
for element in iterable: for element in iterable:
if element: if element:
return True return True
return False return False
def partial(func, *args, **keywords): def partial(func, *args, **keywords):
def newfunc(*fargs, **fkeywords): def newfunc(*fargs, **fkeywords):
newkeywords = keywords.copy() newkeywords = keywords.copy()
...@@ -23,6 +29,25 @@ if sys.version_info[:2] < (2,5): ...@@ -23,6 +29,25 @@ if sys.version_info[:2] < (2,5):
newfunc.args = args newfunc.args = args
newfunc.keywords = keywords newfunc.keywords = keywords
return newfunc return newfunc
class deque(collections.deque):
"""
Custom deque class to implement the `remove` method.
"""
def remove(self, item):
found = None
for i, x in enumerate(self):
if x == item:
found = i
break
if found is None:
raise ValueError('item not found in deque')
# To remove an item, we rotate the queue until it is the first item
# in the queue, we pop it, and finally we rotate back the queue.
self.rotate(-found)
self.popleft()
self.rotate(found)
class defaultdict(dict): class defaultdict(dict):
def __init__(self, default_factory=None, *a, **kw): def __init__(self, default_factory=None, *a, **kw):
if (default_factory is not None and if (default_factory is not None and
...@@ -60,14 +85,16 @@ if sys.version_info[:2] < (2,5): ...@@ -60,14 +85,16 @@ if sys.version_info[:2] < (2,5):
dict.__repr__(self)) dict.__repr__(self))
else: else:
# Only bother with this else clause and the __all__ line if you are putting # Only bother with this else clause and the __all__ line if you are putting
# this in a separate file. # this in a separate file.
import __builtin__ import __builtin__
all = __builtin__.all all = __builtin__.all
any = __builtin__.any any = __builtin__.any
import functools, collections import functools, collections
partial = functools.partial partial = functools.partial
defaultdict = collections.defaultdict defaultdict = collections.defaultdict
deque = collections.deque
__all__ = ['all', 'any'] __all__ = ['all', 'any']
if sys.version_info[:2] < (2,6): if sys.version_info[:2] < (2,6):
......
import unittest import unittest
from collections import deque
from theano import tensor from theano import tensor
from theano.gof.graph import ( from theano.gof.graph import (
......
...@@ -6,14 +6,13 @@ if sys.version_info[:2] >= (2,5): ...@@ -6,14 +6,13 @@ if sys.version_info[:2] >= (2,5):
else: else:
from theano.gof.python25 import partial from theano.gof.python25 import partial
from collections import deque
import numpy import numpy
from copy import copy from copy import copy
from theano.compile import (SymbolicInputKit, SymbolicInput, from theano.compile import (SymbolicInputKit, SymbolicInput,
Module, module, Method, Member, In, Component) Module, module, Method, Member, In, Component)
from theano.gof import Container from theano.gof import Container
from theano.gof.python25 import deque
from theano.tensor import raw_random from theano.tensor import raw_random
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论