提交 1ff33711 authored 作者: nouiz's avatar nouiz

Merge pull request #367 from delallea/win_py24

Fixes for Windows and Python 2.4
import sys
if sys.version_info[:2] >= (2,5):
from collections import deque
else:
class deque(object):
def __init__(self, iterable=(), maxsize=-1):
if not hasattr(self, 'data'):
self.left = self.right = 0
self.data = {}
self.maxsize = maxsize
self.extend(iterable)
def append(self, x):
self.data[self.right] = x
self.right += 1
if self.maxsize != -1 and len(self) > self.maxsize:
self.popleft()
def remove(self, x):
if self.left == self.right:
raise ValueError('cannot remove from empty deque')
for i in xrange(self.left, self.right-1):
elem = self.data[i]
if elem==x:
self.__delitem__(i)
break
def appendleft(self, x):
self.left -= 1
self.data[self.left] = x
if self.maxsize != -1 and len(self) > self.maxsize:
self.pop()
def pop(self):
if self.left == self.right:
raise IndexError('cannot pop from empty deque')
self.right -= 1
elem = self.data[self.right]
del self.data[self.right]
return elem
def popleft(self):
if self.left == self.right:
raise IndexError('cannot pop from empty deque')
elem = self.data[self.left]
del self.data[self.left]
self.left += 1
return elem
def clear(self):
self.data.clear()
self.left = self.right = 0
def extend(self, iterable):
for elem in iterable:
self.append(elem)
def extendleft(self, iterable):
for elem in iterable:
self.appendleft(elem)
def rotate(self, n=1):
if self:
n %= len(self)
for i in xrange(n):
self.appendleft(self.pop())
def __getitem__(self, i):
if i < 0:
i += len(self)
try:
return self.data[i + self.left]
except KeyError:
raise IndexError
def __setitem__(self, i, value):
if i < 0:
i += len(self)
try:
self.data[i + self.left] = value
except KeyError:
raise IndexError
def __delitem__(self, i):
size = len(self)
data = self.data
if i < 0:
i += size
if not data.has_key(i):
raise IndexError
for j in xrange(self.left+i, self.right-1):
data[j] = data[j+1]
self.pop()
def __len__(self):
return self.right - self.left
def __cmp__(self, other):
if type(self) != type(other):
return cmp(type(self), type(other))
return cmp(list(self), list(other))
def __repr__(self, _track=[]):
if id(self) in _track:
return '...'
_track.append(id(self))
r = 'deque(%r)' % (list(self),)
_track.remove(id(self))
return r
def __getstate__(self):
return (tuple(self),)
def __setstate__(self, s):
self.__init__(s[0])
def __hash__(self):
raise TypeError
def __copy__(self):
return self.__class__(self)
def __deepcopy__(self, memo={}):
from copy import deepcopy
result = self.__class__()
memo[id(self)] = result
result.__init__(deepcopy(tuple(self), memo))
return result
from cc import \
CLinker, OpWiseCLinker, DualLinker
......
......@@ -125,7 +125,7 @@ def print_compiledir_content():
file = None
try:
try:
file = open(os.path.join(compiledir, dir, "key.pkl"))
file = open(os.path.join(compiledir, dir, "key.pkl"), 'rb')
keydata = cPickle.load(file)
ops = list(set([x for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Op)]))
......
......@@ -8,7 +8,7 @@ if sys.version_info[:2] >= (2,5):
import theano
import toolbox
import graph
from theano.gof import deque
from theano.gof.python25 import deque
from env import InconsistencyError
......
......@@ -12,7 +12,8 @@ __docformat__ = "restructuredtext en"
from copy import copy
import theano
from theano.gof import deque, utils
from theano.gof import utils
from theano.gof.python25 import deque
# Lazy imports to avoid circular dependencies.
is_same_graph_with_merge = None
......
......@@ -16,7 +16,7 @@ import toolbox
import op
import theano
from theano import config
from theano.gof.python25 import any, all
from theano.gof.python25 import any, all, deque
from theano.configparser import AddConfigVar, BoolParam, config
#if sys.version_info[:2] >= (2,5):
......@@ -29,7 +29,6 @@ AddConfigVar('time_seq_optimizer',
BoolParam(False),
in_c_key=False)
from theano.gof import deque
import destroyhandler as dh
import traceback
......@@ -989,8 +988,10 @@ class TopoOptimizer(NavigatorOptimizer):
q.append(node)
def pruner(node):
if node is not current_node:
try: q.remove(node)
except ValueError: pass
try:
q.remove(node)
except ValueError:
pass
u = self.attach_updater(env, importer, pruner)
try:
......@@ -1027,8 +1028,10 @@ class OpKeyOptimizer(NavigatorOptimizer):
if node.op == op: q.append(node)
def pruner(node):
if node is not current_node and node.op == op:
try: q.remove(node)
except ValueError: pass
try:
q.remove(node)
except ValueError:
pass
u = self.attach_updater(env, importer, pruner)
try:
while q:
......@@ -1133,8 +1136,10 @@ class EquilibriumOptimizer(NavigatorOptimizer):
q.append(node)
def pruner(node):
if node is not current_node:
try: q.remove(node)
except ValueError: pass
try:
q.remove(node)
except ValueError:
pass
u = self.attach_updater(env, importer, pruner)
try:
......@@ -1153,7 +1158,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
process_count[lopt] += 1
changed = True
if node not in env.nodes:
break# go to next node
break # go to next node
finally:
self.detach_updater(env, u)
self.detach_updater(env, u) #TODO: erase this line, it's redundant at best
......
......@@ -2,18 +2,24 @@
Helper functions to make gof backwards compatible (tested on python 2.4 and 2.5)
"""
import collections
import sys
if sys.version_info[:2] < (2,5):
def all(iterable):
for element in iterable:
if not element:
return False
return True
def any(iterable):
for element in iterable:
if element:
return True
return False
def partial(func, *args, **keywords):
def newfunc(*fargs, **fkeywords):
newkeywords = keywords.copy()
......@@ -23,6 +29,25 @@ if sys.version_info[:2] < (2,5):
newfunc.args = args
newfunc.keywords = keywords
return newfunc
class deque(collections.deque):
"""
Custom deque class to implement the `remove` method.
"""
def remove(self, item):
found = None
for i, x in enumerate(self):
if x == item:
found = i
break
if found is None:
raise ValueError('item not found in deque')
# To remove an item, we rotate the queue until it is the first item
# in the queue, we pop it, and finally we rotate back the queue.
self.rotate(-found)
self.popleft()
self.rotate(found)
class defaultdict(dict):
def __init__(self, default_factory=None, *a, **kw):
if (default_factory is not None and
......@@ -60,14 +85,16 @@ if sys.version_info[:2] < (2,5):
dict.__repr__(self))
else:
# Only bother with this else clause and the __all__ line if you are putting
# this in a separate file.
import __builtin__
all = __builtin__.all
any = __builtin__.any
import functools, collections
partial = functools.partial
defaultdict = collections.defaultdict
# Only bother with this else clause and the __all__ line if you are putting
# this in a separate file.
import __builtin__
all = __builtin__.all
any = __builtin__.any
import functools, collections
partial = functools.partial
defaultdict = collections.defaultdict
deque = collections.deque
__all__ = ['all', 'any']
if sys.version_info[:2] < (2,6):
......
import unittest
from collections import deque
from theano import tensor
from theano.gof.graph import (
......
......@@ -6,14 +6,13 @@ if sys.version_info[:2] >= (2,5):
else:
from theano.gof.python25 import partial
from collections import deque
import numpy
from copy import copy
from theano.compile import (SymbolicInputKit, SymbolicInput,
Module, module, Method, Member, In, Component)
from theano.gof import Container
from theano.gof.python25 import deque
from theano.tensor import raw_random
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论