提交 a6fead02 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Pep8 / Pyflakes

上级 872c8166
......@@ -79,4 +79,3 @@ from theano.gof.type import \
from theano.gof.utils import \
object2, MethodNotDefined
......@@ -8,7 +8,6 @@ import os
import shutil
import stat
import StringIO
import struct
import subprocess
import sys
import tempfile
......@@ -153,14 +152,14 @@ static struct PyModuleDef moduledef = {{
}};
""".format(name=self.name)
print >> stream, "PyMODINIT_FUNC PyInit_%s(void) {" % self.name
for b in self.init_blocks:
print >> stream, ' ', b
for block in self.init_blocks:
print >> stream, ' ', block
print >> stream, " PyObject *m = PyModule_Create(&moduledef);"
print >> stream, " return m;"
else:
print >> stream, "PyMODINIT_FUNC init%s(void){" % self.name
for b in self.init_blocks:
print >> stream, ' ', b
for block in self.init_blocks:
print >> stream, ' ', block
print >> stream, ' ', ('(void) Py_InitModule("%s", MyMethods);'
% self.name)
print >> stream, "}"
......@@ -1541,7 +1540,8 @@ class GCC_compiler(object):
lines = stdout + stderr
return lines
# The '-' at the end is needed. Otherwise, g++ do not output enough information.
# The '-' at the end is needed. Otherwise, g++ do not output
# enough information.
native_lines = get_lines("g++ -march=native -E -v -")
_logger.info("g++ -march=native selected lines: %s", native_lines)
if len(native_lines) != 1:
......
# import op
# import variable
import re
import traceback
from theano import config
import re, traceback
def add_tag_trace(thing):
"""Add tag.trace to an node or variable.
......@@ -11,15 +10,18 @@ def add_tag_trace(thing):
The argument is returned after being affected (inplace).
"""
limit = config.traceback.limit
if limit == -1: limit = None
if limit == -1:
limit = None
thing.tag.trace = traceback.extract_stack(limit=limit)[:-1]
return thing
def hashgen():
hashgen.next += 1
return hashgen.next
hashgen.next = 0
class MethodNotDefined(Exception):
"""
To be raised by functions defined as part of an interface.
......@@ -28,6 +30,7 @@ class MethodNotDefined(Exception):
function has been left out of an implementation class.
"""
class object2(object):
__slots__ = []
if 0:
......@@ -36,23 +39,30 @@ class object2(object):
if hasattr(self, '__eq__') or hasattr(self, '__cmp__'):
raise TypeError("unhashable object: %s" % self)
return id(self)
def __ne__(self, other):
return not self == other
class scratchpad:
def clear(self):
self.__dict__.clear()
def __update__(self, other):
self.__dict__.update(other.__dict__)
return self
def __str__(self):
return "scratchpad" + str(self.__dict__)
def __repr__(self):
return "scratchpad" + str(self.__dict__)
def info(self):
print "<theano.gof.utils.scratchpad instance at %i>"%id(self)
for k,v in self.__dict__.items():
print " %s: %s" % (k,v)
print "<theano.gof.utils.scratchpad instance at %i>" % id(self)
for k, v in self.__dict__.items():
print " %s: %s" % (k, v)
class D:
def __init__(self, **d):
......@@ -63,6 +73,7 @@ def memoize(f):
"""Cache the return value for each tuple of arguments
(which must be hashable) """
cache = {}
def rval(*args, **kwargs):
kwtup = tuple(kwargs.items())
key = (args, kwtup)
......@@ -72,8 +83,8 @@ def memoize(f):
else:
val = cache[key]
return val
return rval
return rval
def deprecated(filename, msg=''):
......@@ -92,6 +103,7 @@ def deprecated(filename, msg=''):
"""
def _deprecated(f):
printme = [True]
def g(*args, **kwargs):
if printme[0]:
print 'WARNING: %s.%s deprecated. %s'\
......@@ -99,19 +111,23 @@ def deprecated(filename, msg=''):
printme[0] = False
return f(*args, **kwargs)
return g
return _deprecated
def uniq(seq):
#TODO: consider building a set out of seq so that the if condition is constant time -JB
#TODO: consider building a set out of seq so that the if condition
#is constant time -JB
return [x for i, x in enumerate(seq) if seq.index(x) == i]
def difference(seq1, seq2):
"""
Returns all elements in seq1 which are not in seq2: i.e seq1\seq2
"""
try:
# try to use O(const * len(seq1)) algo
if len(seq2) < 4: # I'm guessing this threshold -JB
if len(seq2) < 4: # I'm guessing this threshold -JB
raise Exception('not worth it')
set2 = set(seq2)
return [x for x in seq1 if x not in set2]
......@@ -132,13 +148,16 @@ def partition(f, seq):
seqf.append(elem)
return seqt, seqf
def attr_checker(*attrs):
def f(candidate):
for attr in attrs:
if not hasattr(candidate, attr):
return False
return True
f.__doc__ = "Checks that the candidate has the following attributes: %s" % ", ".join(["'%s'"%attr for attr in attrs])
f.__doc__ = ("Checks that the candidate has the following attributes: %s"
% ", ".join(["'%s'" % attr for attr in attrs]))
return f
......@@ -149,11 +168,10 @@ def all_bases(cls, accept):
return [cls for cls in rval if accept(cls)]
def all_bases_collect(cls, raw_name):
rval = set()
name = "__%s__" % raw_name
if name in cls.__dict__: # don't use hasattr
if name in cls.__dict__: # don't use hasattr
rval.add(getattr(cls, name))
cut = "__%s_override__" % raw_name
if not cls.__dict__.get(cut, False):
......@@ -162,7 +180,7 @@ def all_bases_collect(cls, raw_name):
return rval
def camelcase_to_separated(string, sep = "_"):
def camelcase_to_separated(string, sep="_"):
return re.sub('(.)([A-Z])', '\\1%s\\2' % sep, string).lower()
......@@ -172,6 +190,7 @@ def to_return_values(values):
else:
return values
def from_return_values(values):
if isinstance(values, (list, tuple)):
return values
......@@ -186,7 +205,8 @@ class ClsInit(type):
Validate and initialize the L{Op} subclass 'cls'
This function:
- changes class attributes input_names and output_names to be lists if they are single strings.
- changes class attributes input_names and output_names to be lists
if they are single strings.
"""
type.__init__(cls, name, bases, dct)
......@@ -195,8 +215,10 @@ class ClsInit(type):
def toposort(prereqs_d):
"""
Sorts prereqs_d.keys() topologically. prereqs_d[x] contains all the elements
that must come before x in the ordering.
Sorts prereqs_d.keys() topologically.
prereqs_d[x] contains all the elements that must come before x
in the ordering.
"""
# all1 = set(prereqs_d.keys())
......@@ -223,19 +245,26 @@ def toposort(prereqs_d):
if not prereqs_d[postreq].difference(done):
next.add(postreq)
if len(prereqs_d) != len(seq):
raise Exception("Cannot sort topologically: there might be cycles, " + \
"prereqs_d does not have a key for each element or " + \
raise Exception("Cannot sort topologically: there might be cycles, "
"prereqs_d does not have a key for each element or "
"some orderings contain invalid elements.")
return seq
def print_for_dot(self):
#TODO: popen2("dot -Tpng | display") and actually make the graph window pop up
print "digraph unix { size = '6,6'; node [color = lightblue2; style = filled];"
for op in self.order:
for input in op.inputs:
if input.owner:
print input.owner.__class__.__name__ + str(abs(id(input.owner))), " -> ", op.__class__.__name__ + str(abs(id(op))), ";"
#TODO: popen2("dot -Tpng | display") and actually make the graph window
#pop up
print ("digraph unix { size = '6,6'; node [color = lightblue2;"
"style = filled];")
for op in self.order:
for input in op.inputs:
if input.owner:
print ' '.join((
input.owner.__class__.__name__ + str(abs(id(input.owner))),
" -> ",
op.__class__.__name__ + str(abs(id(op))),
";"))
class Keyword:
......@@ -263,9 +292,11 @@ simple_types = (int, float, str, bool, None.__class__, Keyword)
ANY_TYPE = Keyword("ANY_TYPE")
FALL_THROUGH = Keyword("FALL_THROUGH")
def comm_guard(type1, type2):
def wrap(f):
old_f = f.func_globals[f.__name__]
def new_f(arg1, arg2, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)) \
and (type2 is ANY_TYPE or isinstance(arg2, type2)):
......@@ -283,6 +314,7 @@ def comm_guard(type1, type2):
return variable
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
......@@ -290,14 +322,19 @@ def comm_guard(type1, type2):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = str(old_f.__doc__) + "\n" + ", ".join([typename(type) for type in (type1, type2)]) + "\n" + str(f.__doc__ or "")
new_f.__doc__ = (str(old_f.__doc__) + "\n" +
", ".join([typename(type) for type in (type1, type2)]) +
"\n" + str(f.__doc__ or ""))
return new_f
return wrap
def type_guard(type1):
def wrap(f):
old_f = f.func_globals[f.__name__]
def new_f(arg1, *rest):
if (type1 is ANY_TYPE or isinstance(arg1, type1)):
variable = f(arg1, *rest)
......@@ -308,8 +345,8 @@ def type_guard(type1):
else:
return old_f(arg1, *rest)
new_f.__name__ = f.__name__
def typename(type):
if isinstance(type, Keyword):
return str(type)
......@@ -317,8 +354,12 @@ def type_guard(type1):
return "(" + ", ".join([x.__name__ for x in type]) + ")"
else:
return type.__name__
new_f.__doc__ = str(old_f.__doc__) + "\n" + ", ".join([typename(type) for type in (type1,)]) + "\n" + str(f.__doc__ or "")
new_f.__doc__ = (str(old_f.__doc__) + "\n" +
", ".join([typename(type) for type in (type1,)]) +
"\n" + str(f.__doc__ or ""))
return new_f
return wrap
......@@ -331,15 +372,18 @@ def flatten(a):
else:
return [a]
def unique(x):
return len(set(x)) == len(x)
def hist(coll):
counts = {}
for elem in coll:
counts[elem] = counts.get(elem, 0) + 1
return counts
def give_variables_names(variables):
""" Gives unique names to an iterable of variables. Modifies input.
......@@ -349,10 +393,10 @@ def give_variables_names(variables):
bad_var = lambda var: not var.name or h[var.name] > 1
for i, var in enumerate(filter(bad_var, variables)):
var.name = (var.name or "") + "_%d"%i
var.name = (var.name or "") + "_%d" % i
if not unique(map(str, variables)):
raise ValueError("Not all variables have unique names."
"Maybe you've named some of the variables identically")
"Maybe you've named some of the variables identically")
return variables
......@@ -9,7 +9,6 @@ import warnings
import numpy
import theano
from theano.gof import local_bitwidth
from theano.gof.cc import hash_from_file
from theano.gof.cmodule import (std_libs, std_lib_dirs,
......@@ -246,8 +245,6 @@ class NVCC_compiler(object):
cppfile = file(cppfilename, 'w')
_logger.debug('Writing module C++ code to %s', cppfilename)
ofiles = []
rval = None
cppfile.write(src_code)
cppfile.close()
......
......@@ -259,7 +259,7 @@ class RepeatOp(theano.Op):
% numpy_unsupported_dtypes), repeats.dtype)
if self.axis is None:
broadcastable=[False]
broadcastable = [False]
else:
try:
const_reps = basic.get_scalar_constant_value(repeats)
......
......@@ -13,6 +13,7 @@ from theano import config, tensor, function
numpy_ver = [int(n) for n in numpy.__version__.split('.')[:2]]
numpy_16 = bool(numpy_ver >= [1, 6])
class TestBinCountOp(utt.InferShapeTester):
def setUp(self):
super(TestBinCountOp, self).setUp()
......@@ -188,7 +189,6 @@ class SqueezeTester(utt.InferShapeTester):
def test_grad(self):
for shape, broadcast in zip(self.shape_list, self.broadcast_list):
data = numpy.random.random(size=shape).astype(theano.config.floatX)
variable = tensor.TensorType(theano.config.floatX, broadcast)()
utt.verify_grad(self.op, [data])
......@@ -287,11 +287,12 @@ class TestRepeatOp(utt.InferShapeTester):
x = T.TensorType(config.floatX, [False, True, False])()
r = RepeatOp(axis=1)(x, 2)
self.assertEqual(r.broadcastable, (False, False, False))
r = RepeatOp(axis=1)(x, 1)
r = RepeatOp(axis=1)(x, 1)
self.assertEqual(r.broadcastable, (False, True, False))
r = RepeatOp(axis=0)(x, 2)
r = RepeatOp(axis=0)(x, 2)
self.assertEqual(r.broadcastable, (False, True, False))
class TestBartlett(utt.InferShapeTester):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论