提交 2e3f17cb authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to theano.gof

上级 beb93a44
......@@ -6,15 +6,15 @@ import six.moves.cPickle as pickle
_logger = logging.getLogger("theano.gof.callcache")
class CallCache(object):
class CallCache:
def __init__(self, filename=None):
self.filename = filename
try:
if filename is None:
raise IOError("bad filename") # just goes to except
with open(filename, "r") as f:
raise OSError("bad filename") # just goes to except
with open(filename) as f:
self.cache = pickle.load(f)
except IOError:
except OSError:
self.cache = {}
def persist(self, filename=None):
......
......@@ -8,7 +8,6 @@ import sys
from copy import copy
import numpy as np
from six import reraise, string_types
from six.moves import StringIO
import theano
......@@ -233,7 +232,7 @@ def struct_gen(args, struct_builders, blocks, sub):
# declares the storage
storage_decl = "\n".join(["PyObject* %s;" % arg for arg in args])
# in the constructor, sets the storage to the arguments
storage_set = "\n".join(["this->%s = %s;" % (arg, arg) for arg in args])
storage_set = "\n".join(["this->{} = {};".format(arg, arg) for arg in args])
# increments the storage's refcount in the constructor
storage_incref = "\n".join(["Py_XINCREF(%s);" % arg for arg in args])
# decrements the storage's refcount in the destructor
......@@ -359,7 +358,7 @@ def get_c_declare(r, name, sub):
[
getattr(c.op, "check_input", config.check_input)
for (c, _) in r.clients
if not isinstance(c, string_types)
if not isinstance(c, str)
]
) or (r.owner and getattr(r.owner.op, "check_input", config.check_input)):
c_declare = r.type.c_declare(name, sub, True)
......@@ -405,7 +404,7 @@ def get_c_extract(r, name, sub):
[
getattr(c.op, "check_input", config.check_input)
for (c, _) in r.clients
if not isinstance(c, string_types)
if not isinstance(c, str)
]
):
# check_broadcast is just an hack to easily remove just the
......@@ -415,7 +414,7 @@ def get_c_extract(r, name, sub):
[
getattr(c.op, "check_broadcast", True)
for (c, _) in r.clients
if not isinstance(c, string_types)
if not isinstance(c, str)
]
):
c_extract = r.type.c_extract(name, sub, True)
......@@ -849,7 +848,7 @@ class CLinker(link.Linker):
pass
else:
# The following will be executed if the "try" block succeeds
assert isinstance(c_support_code_apply[-1], string_types), (
assert isinstance(c_support_code_apply[-1], str), (
str(node.op) + " didn't return a string for c_support_code_apply"
)
......@@ -858,13 +857,13 @@ class CLinker(link.Linker):
except utils.MethodNotDefined:
pass
else:
assert isinstance(c_init_code_apply[-1], string_types), (
assert isinstance(c_init_code_apply[-1], str), (
str(node.op) + " didn't return a string for c_init_code_apply"
)
try:
struct_init = op.c_init_code_struct(node, name, sub_struct)
assert isinstance(struct_init, string_types), (
assert isinstance(struct_init, str), (
str(node.op) + " didn't return a string for c_init_code_struct"
)
except utils.MethodNotDefined:
......@@ -872,7 +871,7 @@ class CLinker(link.Linker):
try:
struct_support = op.c_support_code_struct(node, name)
assert isinstance(struct_support, string_types), (
assert isinstance(struct_support, str), (
str(node.op) + " didn't return a string for c_support_code_struct"
)
except utils.MethodNotDefined:
......@@ -880,7 +879,7 @@ class CLinker(link.Linker):
try:
struct_cleanup = op.c_cleanup_code_struct(node, name)
assert isinstance(struct_cleanup, string_types), (
assert isinstance(struct_cleanup, str), (
str(node.op) + " didn't return a string for c_cleanup_code_struct"
)
except utils.MethodNotDefined:
......@@ -891,7 +890,7 @@ class CLinker(link.Linker):
behavior = op.c_code(node, name, isyms, osyms, sub)
except utils.MethodNotDefined:
raise NotImplementedError("%s cannot produce C code" % op)
assert isinstance(behavior, string_types), (
assert isinstance(behavior, str), (
str(node.op) + " didn't return a string for c_code"
)
# To help understand what is following. It help read the c code.
......@@ -1429,7 +1428,7 @@ class CLinker(link.Linker):
# set of variables that have been computed by nodes we have
# seen 'so far' in the loop below
fgraph_computed_set = set()
fgraph_inputs_dict = dict((i, (-1, pos)) for pos, i in enumerate(fgraph.inputs))
fgraph_inputs_dict = {i: (-1, pos) for pos, i in enumerate(fgraph.inputs)}
constant_ids = dict()
op_pos = {} # Apply -> topological position
......@@ -1794,7 +1793,7 @@ class CLinker(link.Linker):
return code.getvalue()
class _CThunk(object):
class _CThunk:
"""
A thunk with a C implementation.
......@@ -1864,7 +1863,7 @@ class _CThunk(object):
)
print(self.error_storage, file=sys.stderr)
raise
reraise(exc_type, exc_value, exc_trace)
raise exc_value.with_traceback(exc_trace)
class OpWiseCLinker(link.LocalLinker):
......@@ -2130,7 +2129,7 @@ class DualLinker(link.Linker):
return f, i1, o1
class HideC(object):
class HideC:
def __hide(*args):
raise utils.MethodNotDefined()
......
......@@ -19,7 +19,7 @@ import warnings
import numpy.distutils
import six.moves.cPickle as pickle
from six import BytesIO, StringIO, b, string_types
from six import BytesIO, StringIO, b
import theano
from theano import config
......@@ -53,8 +53,6 @@ class MissingGXX(Exception):
"""
pass
def debug_counter(name, every=1):
"""
......@@ -70,10 +68,10 @@ def debug_counter(name, every=1):
setattr(debug_counter, name, getattr(debug_counter, name, 0) + 1)
n = getattr(debug_counter, name)
if n % every == 0:
print("debug_counter [%s]: %s" % (name, n), file=sys.stderr)
print("debug_counter [{}]: {}".format(name, n), file=sys.stderr)
class ExtFunction(object):
class ExtFunction:
"""
A C function to put into a DynamicModule.
......@@ -118,10 +116,12 @@ class ExtFunction(object):
It goes into the DynamicModule's method table.
"""
return '\t{"%s", %s, %s, "%s"}' % (self.name, self.name, self.method, self.doc)
return '\t{{"{}", {}, {}, "{}"}}'.format(
self.name, self.name, self.method, self.doc
)
class DynamicModule(object):
class DynamicModule:
def __init__(self, name=None):
assert name is None, (
"The 'name' parameter of DynamicModule"
......@@ -436,7 +436,7 @@ def get_module_hash(src_code, key):
# This should be the C++ compilation command line parameters or the
# libraries to link against.
to_hash += list(key_element)
elif isinstance(key_element, string_types):
elif isinstance(key_element, str):
if key_element.startswith("md5:") or key_element.startswith("hash:"):
# This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old Theano.
......@@ -481,7 +481,7 @@ def get_safe_part(key):
# rest of the cache mechanism will just skip that key.
hash = None
for key_element in c_link_key[1:]:
if isinstance(key_element, string_types):
if isinstance(key_element, str):
if key_element.startswith("md5:"):
hash = key_element[4:]
break
......@@ -492,7 +492,7 @@ def get_safe_part(key):
return key[0] + (hash,)
class KeyData(object):
class KeyData:
"""
Used to store the key information in the cache.
......@@ -594,7 +594,7 @@ class KeyData(object):
pass
class ModuleCache(object):
class ModuleCache:
"""
Interface to the cache of dynamically compiled modules on disk.
......@@ -1011,7 +1011,7 @@ class ModuleCache(object):
# Test to see that the file is [present and] readable.
open(entry).close()
gone = False
except IOError:
except OSError:
gone = True
if gone:
......@@ -1140,7 +1140,7 @@ class ModuleCache(object):
key_pkl = os.path.join(location, "key.pkl")
assert not os.path.exists(key_pkl)
key_data = KeyData(
keys=set([key]), module_hash=module_hash, key_pkl=key_pkl, entry=name
keys={key}, module_hash=module_hash, key_pkl=key_pkl, entry=name
)
key_broken = False
......@@ -1518,7 +1518,7 @@ class ModuleCache(object):
fname = os.path.join(self.dirname, filename, "key.pkl")
open(fname).close()
has_key = True
except IOError:
except OSError:
has_key = False
if not has_key:
# Use the compiled file by default
......@@ -1724,12 +1724,10 @@ def std_lib_dirs_and_libs():
for f, lib in [("libpython27.a", "libpython 1.2")]:
if not os.path.exists(os.path.join(libdir, f)):
print(
(
"Your Python version is from Canopy. "
+ "You need to install the package '"
+ lib
+ "' from Canopy package manager."
)
"Your Python version is from Canopy. "
+ "You need to install the package '"
+ lib
+ "' from Canopy package manager."
)
libdirs = [
# Used in older Canopy
......@@ -1747,12 +1745,10 @@ def std_lib_dirs_and_libs():
]
):
print(
(
"Your Python version is from Canopy. "
+ "You need to install the package '"
+ lib
+ "' from Canopy package manager."
)
"Your Python version is from Canopy. "
+ "You need to install the package '"
+ lib
+ "' from Canopy package manager."
)
python_lib_dirs.insert(0, libdir)
std_lib_dirs_and_libs.data = [libname], python_lib_dirs
......@@ -1833,15 +1829,15 @@ def gcc_llvm():
# Normally this should not happen as we should not try to
# compile when g++ is not available. If this happen, it
# will crash later so supposing it is not llvm is "safe".
output = b("")
gcc_llvm.is_llvm = b("llvm") in output
output = b""
gcc_llvm.is_llvm = b"llvm" in output
return gcc_llvm.is_llvm
gcc_llvm.is_llvm = None
class Compiler(object):
class Compiler:
"""
Meta compiler that offer some generic function.
......@@ -2077,7 +2073,7 @@ class GCC_compiler(Compiler):
# as stdin (which is the default) results in the process
# waiting forever without returning. For that reason,
# we use a pipe, and use the empty string as input.
(stdout, stderr) = p.communicate(input=b(""))
(stdout, stderr) = p.communicate(input=b"")
if p.returncode != 0:
return None
......@@ -2355,7 +2351,7 @@ class GCC_compiler(Compiler):
line.startswith("#define hypot _hypot") for line in config_h
):
cxxflags.append("-D_hypot=hypot")
except IOError:
except OSError:
pass
return cxxflags
......@@ -2472,9 +2468,9 @@ class GCC_compiler(Compiler):
if dist_suffix is not None and dist_suffix != "":
suffix = dist_suffix
filepath = "%s%s" % (module_name, suffix)
filepath = "{}{}".format(module_name, suffix)
else:
filepath = "%s.%s" % (module_name, get_lib_extension())
filepath = "{}.{}".format(module_name, get_lib_extension())
lib_filename = os.path.join(location, filepath)
......@@ -2488,10 +2484,13 @@ class GCC_compiler(Compiler):
# to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper = '"' if os.name == "nt" else ""
cmd.extend(
["-I%s%s%s" % (path_wrapper, idir, path_wrapper) for idir in include_dirs]
[
"-I{}{}{}".format(path_wrapper, idir, path_wrapper)
for idir in include_dirs
]
)
cmd.extend(
["-L%s%s%s" % (path_wrapper, ldir, path_wrapper) for ldir in lib_dirs]
["-L{}{}{}".format(path_wrapper, ldir, path_wrapper) for ldir in lib_dirs]
)
if hide_symbols and sys.platform != "win32":
# This has been available since gcc 4.0 so we suppose it
......@@ -2501,8 +2500,8 @@ class GCC_compiler(Compiler):
# improved loading times on most platforms (win32 is
# different, as usual).
cmd.append("-fvisibility=hidden")
cmd.extend(["-o", "%s%s%s" % (path_wrapper, lib_filename, path_wrapper)])
cmd.append("%s%s%s" % (path_wrapper, cppfilename, path_wrapper))
cmd.extend(["-o", "{}{}{}".format(path_wrapper, lib_filename, path_wrapper)])
cmd.append("{}{}{}".format(path_wrapper, cppfilename, path_wrapper))
cmd.extend(["-l%s" % l for l in libs])
# print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug("Running cmd: %s", " ".join(cmd))
......
......@@ -4,7 +4,6 @@ import shutil
import numpy as np
import six.moves.cPickle as pickle
from six import string_types
import theano
from theano import config
......@@ -46,7 +45,7 @@ def cleanup():
# force the removing of key
have_npy_abi_version = False
break
elif isinstance(obj, string_types):
elif isinstance(obj, str):
if obj.startswith("NPY_ABI_VERSION=0x"):
have_npy_abi_version = True
elif obj.startswith("c_compiler_str="):
......@@ -67,7 +66,7 @@ def cleanup():
if keydata.key_pkl != filename:
keydata.key_pkl = filename
keydata.remove_key(key)
except IOError:
except OSError:
_logger.error(
"Could not remove file '%s'. To complete "
"the clean-up, please remove manually "
......@@ -84,7 +83,7 @@ def cleanup():
"the directory containing it.",
filename,
)
except IOError:
except OSError:
_logger.error(
"Could not clean up this directory: '%s'. To complete "
"the clean-up, please remove it manually.",
......@@ -126,29 +125,21 @@ def print_compiledir_content():
try:
keydata = pickle.load(file)
ops = list(
set(
[
x
for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Op)
]
)
{x for x in flatten(keydata.keys) if isinstance(x, theano.gof.Op)}
)
# Whatever the case, we count compilations for OP classes.
for op_class in set([op.__class__ for op in ops]):
for op_class in {op.__class__ for op in ops}:
table_op_class.setdefault(op_class, 0)
table_op_class[op_class] += 1
if len(ops) == 0:
zeros_op += 1
else:
types = list(
set(
[
x
for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Type)
]
)
{
x
for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Type)
}
)
compile_start = compile_end = float("nan")
for fn in os.listdir(os.path.join(compiledir, dir)):
......@@ -177,7 +168,7 @@ def print_compiledir_content():
nb_keys.setdefault(len(keydata.keys), 0)
nb_keys[len(keydata.keys)] += 1
except IOError:
except OSError:
pass
except AttributeError:
_logger.error("Could not read key file '%s'.", filename)
......@@ -221,16 +212,12 @@ def print_compiledir_content():
big_key_files = sorted(big_key_files, key=lambda t: str(t[1]))
big_total_size = sum([sz for _, sz, _ in big_key_files])
print(
(
"There are directories with key files bigger than %d bytes "
"(they probably contain big tensor constants)" % max_key_file_size
)
"There are directories with key files bigger than %d bytes "
"(they probably contain big tensor constants)" % max_key_file_size
)
print(
(
"They use %d bytes out of %d (total size used by all key files)"
"" % (big_total_size, total_key_sizes)
)
"They use %d bytes out of %d (total size used by all key files)"
"" % (big_total_size, total_key_sizes)
)
for dir, size, ops in big_key_files:
......@@ -246,10 +233,8 @@ def print_compiledir_content():
print(n_k, n_m)
print()
print(
(
"Skipped %d files that contained 0 op "
"(are they always theano.scalar ops?)" % zeros_op
)
"Skipped %d files that contained 0 op "
"(are they always theano.scalar ops?)" % zeros_op
)
......
......@@ -10,7 +10,6 @@ import time
from contextlib import contextmanager
import numpy as np
from six import PY3
from theano import config
......@@ -283,14 +282,9 @@ def lock(tmp_dir, timeout=notset, min_wait=None, max_wait=None, verbosity=1):
nb_wait += 1
time.sleep(random.uniform(min_wait, max_wait))
if PY3:
exception = FileExistsError # noqa
else:
exception = OSError
try:
os.mkdir(tmp_dir)
except exception:
except FileExistsError:
# Error while creating the directory: someone else
# must have tried at the exact same time.
nb_error += 1
......@@ -332,7 +326,7 @@ def refresh_lock(lock_file):
unique id, using a new (randomly generated) id, which is also returned.
"""
unique_id = "%s_%s_%s" % (
unique_id = "{}_{}_{}".format(
os.getpid(),
"".join([str(random.randint(0, 9)) for i in range(10)]),
hostname,
......@@ -355,7 +349,7 @@ def refresh_lock(lock_file):
return unique_id
class Unlocker(object):
class Unlocker:
"""
Class wrapper around release mechanism so that the lock is automatically
released when the program exits (even when crashing or being interrupted),
......
......@@ -22,8 +22,6 @@ class ProtocolError(Exception):
"""
pass
def _contains_cycle(fgraph, orderings):
"""
......@@ -762,17 +760,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", [])
assert isinstance(tolerate_same, list)
tolerated = set(
tolerated = {
idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx
)
}
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", []
)
assert isinstance(tolerate_aliased, list)
ignored = set(
ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
)
}
for i, input in enumerate(app.inputs):
if i in ignored:
continue
......
......@@ -26,8 +26,6 @@ class InconsistencyError(Exception):
"""
pass
class MissingInputError(Exception):
"""
......
......@@ -7,8 +7,6 @@ from collections import deque
from copy import copy
from itertools import count
from six import integer_types, string_types
import theano
from theano import config
from theano.gof.utils import (
......@@ -174,7 +172,7 @@ class Apply(Node):
raise ValueError(
"%s.default_output should be an output index." % self.op
)
elif not isinstance(do, integer_types):
elif not isinstance(do, int):
raise ValueError("%s.default_output should be an int or long" % self.op)
elif do < 0 or do >= len(self.outputs):
raise ValueError("%s.default_output is out of range." % self.op)
......@@ -395,11 +393,11 @@ class Variable(Node):
raise TypeError("owner must be an Apply instance", owner)
self.owner = owner
if index is not None and not isinstance(index, integer_types):
if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index)
self.index = index
if name is not None and not isinstance(name, string_types):
if name is not None and not isinstance(name, str):
raise TypeError("name must be a string", name)
self.name = name
......@@ -1156,7 +1154,7 @@ default_leaf_formatter = str
def default_node_formatter(op, argstrings):
return "%s(%s)" % (op.op, ", ".join(argstrings))
return "{}({})".format(op.op, ", ".join(argstrings))
def io_connection_pattern(inputs, outputs):
......@@ -1331,7 +1329,7 @@ def view_roots(r):
if owner is not None:
try:
view_map = owner.op.view_map
view_map = dict((owner.outputs[o], i) for o, i in view_map.items())
view_map = {owner.outputs[o]: i for o, i in view_map.items()}
except AttributeError:
return [r]
if r in view_map:
......
......@@ -59,11 +59,11 @@ try:
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
except IOError as e:
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
else:
e.args += ("%s exist? %s" % (location, os.path.exists(location)),)
e.args += ("{} exist? {}".format(location, os.path.exists(location)),)
raise
_need_reload = False
......
......@@ -4,7 +4,6 @@ from copy import copy, deepcopy
from sys import getsizeof
import numpy as np
from six import reraise
from six.moves import StringIO
import theano
......@@ -123,7 +122,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
exc_type, exc_value, exc_trace = exc_info
if exc_type == KeyboardInterrupt:
# print a simple traceback from KeyboardInterrupt
reraise(exc_type, exc_value, exc_trace)
raise exc_value.with_traceback(exc_trace)
try:
trace = node.outputs[0].tag.trace
except AttributeError:
......@@ -315,11 +314,11 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg += ", TotalSize: %s Byte(s)\n" % item[3]
else:
detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % (
detailed_err_msg += " TotalSize: {} Byte(s) {:.3f} GB\n".format(
total_size,
total_size / 1024.0 / 1024 / 1024,
)
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f GB\n" % (
detailed_err_msg += " TotalSize inputs: {} Byte(s) {:.3f} GB\n".format(
total_size_inputs,
total_size_inputs / 1024.0 / 1024 / 1024,
)
......@@ -341,11 +340,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
)
# Some exception need extra parameter in inputs. So forget the
# extra long error message in that case.
pass
reraise(exc_type, exc_value, exc_trace)
raise exc_value.with_traceback(exc_trace)
class Linker(object):
class Linker:
"""
WRITEME
......@@ -434,7 +432,7 @@ class Linker(object):
# TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object):
class Container:
"""
This class joins a variable with its computed value.
......
......@@ -122,7 +122,7 @@ def compute_test_value(node):
output.tag.test_value = storage_map[output][0]
class CLinkerObject(object):
class CLinkerObject:
"""
Standard elements of an Op or Type used with the CLinker.
......@@ -550,7 +550,7 @@ class CLinkerOp(CLinkerObject):
)
class PureOp(object):
class PureOp:
"""A class that models and constructs operations in a graph.
A `PureOp` instance has several responsibilities:
......@@ -842,7 +842,6 @@ class Op(object2, PureOp, CLinkerOp):
good to do so.
"""
pass
def make_c_thunk(self, node, storage_map, compute_map, no_recycling):
"""Like make_thunk, but will only try to make a C thunk."""
......@@ -1263,19 +1262,17 @@ class COp(Op):
section_re = re.compile(r"^#section ([a-zA-Z0-9_]+)$", re.MULTILINE)
backward_re = re.compile(r"^THEANO_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE)
# This is the set of allowed markers
SECTIONS = set(
[
"init_code",
"init_code_apply",
"init_code_struct",
"support_code",
"support_code_apply",
"support_code_struct",
"cleanup_code_struct",
"code",
"code_cleanup",
]
)
SECTIONS = {
"init_code",
"init_code_apply",
"init_code_struct",
"support_code",
"support_code_apply",
"support_code_struct",
"cleanup_code_struct",
"code",
"code_cleanup",
}
@classmethod
def get_path(cls, f):
......@@ -1535,10 +1532,10 @@ class COp(Op):
def get_sub_macros(self, sub):
define_macros = []
undef_macros = []
define_macros.append("#define FAIL %s" % (self._lquote_macro(sub["fail"]),))
define_macros.append("#define FAIL {}".format(self._lquote_macro(sub["fail"])))
undef_macros.append("#undef FAIL")
if "params" in sub:
define_macros.append("#define PARAMS %s" % (sub["params"],))
define_macros.append("#define PARAMS {}".format(sub["params"]))
undef_macros.append("#undef PARAMS")
return "\n".join(define_macros), "\n".join(undef_macros)
......@@ -1584,7 +1581,7 @@ class COp(Op):
params = ""
if "params" in sub:
params = ", %s" % (sub["params"],)
params = ", {}".format(sub["params"])
# Generate the C code
return """
......
差异被折叠。
......@@ -2,7 +2,7 @@ import copy
import math
import sys
from six import StringIO, integer_types
from six import StringIO
from theano import config
from theano.compat import DefaultOrderedDict
......@@ -10,7 +10,7 @@ from theano.gof import opt
from theano.misc.ordered_set import OrderedSet
class DB(object):
class DB:
def __hash__(self):
if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0]
......@@ -169,7 +169,7 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
print(" db", self.__db__, file=stream)
class Query(object):
class Query:
"""
Parameters
......@@ -296,7 +296,7 @@ class EquilibriumDB(DB):
"""
def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False):
super(EquilibriumDB, self).__init__()
super().__init__()
self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__ = {}
......@@ -307,12 +307,12 @@ class EquilibriumDB(DB):
cleanup = kwtags.pop("cleanup", False)
# An opt should not be final and clean up
assert not (final_opt and cleanup)
super(EquilibriumDB, self).register(name, obj, *tags, **kwtags)
super().register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup
def query(self, *tags, **kwtags):
_opts = super(EquilibriumDB, self).query(*tags, **kwtags)
_opts = super().query(*tags, **kwtags)
final_opts = [o for o in _opts if self.__final__.get(o.name, False)]
cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name, False)]
opts = [o for o in _opts if o not in final_opts and o not in cleanup_opts]
......@@ -349,19 +349,19 @@ class SequenceDB(DB):
seq_opt = opt.SeqOptimizer
def __init__(self, failure_callback=opt.SeqOptimizer.warn):
super(SequenceDB, self).__init__()
super().__init__()
self.__position__ = {}
self.failure_callback = failure_callback
def register(self, name, obj, position, *tags):
super(SequenceDB, self).register(name, obj, *tags)
super().register(name, obj, *tags)
if position == "last":
if len(self.__position__) == 0:
self.__position__[name] = 0
else:
self.__position__[name] = max(self.__position__.values()) + 1
else:
assert isinstance(position, (integer_types, float))
assert isinstance(position, ((int,), float))
self.__position__[name] = position
def query(self, *tags, **kwtags):
......@@ -373,7 +373,7 @@ class SequenceDB(DB):
Only optimizations with position less than the cutoff are returned.
"""
opts = super(SequenceDB, self).query(*tags, **kwtags)
opts = super().query(*tags, **kwtags)
position_cutoff = kwtags.pop("position_cutoff", config.optdb.position_cutoff)
position_dict = self.__position__
......@@ -442,7 +442,7 @@ class LocalGroupDB(DB):
def __init__(
self, apply_all_opts=False, profile=False, local_opt=opt.LocalOptGroup
):
super(LocalGroupDB, self).__init__()
super().__init__()
self.failure_callback = None
self.apply_all_opts = apply_all_opts
self.profile = profile
......@@ -450,7 +450,7 @@ class LocalGroupDB(DB):
self.local_opt = local_opt
def register(self, name, obj, *tags, **kwargs):
super(LocalGroupDB, self).register(name, obj, *tags)
super().register(name, obj, *tags)
position = kwargs.pop("position", "last")
if position == "last":
if len(self.__position__) == 0:
......@@ -458,12 +458,12 @@ class LocalGroupDB(DB):
else:
self.__position__[name] = max(self.__position__.values()) + 1
else:
assert isinstance(position, (integer_types, float))
assert isinstance(position, ((int,), float))
self.__position__[name] = position
def query(self, *tags, **kwtags):
# For the new `useless` optimizer
opts = list(super(LocalGroupDB, self).query(*tags, **kwtags))
opts = list(super().query(*tags, **kwtags))
opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name))
ret = self.local_opt(
......@@ -482,7 +482,7 @@ class TopoDB(DB):
def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
):
super(TopoDB, self).__init__()
super().__init__()
self.db = db
self.order = order
self.ignore_newtrees = ignore_newtrees
......
......@@ -153,13 +153,13 @@ class Params(dict):
raise TypeError(
'Params: ParamsType attribute "%s" not in Params args.' % field
)
super(Params, self).__init__(**kwargs)
super().__init__(**kwargs)
self.__dict__.update(__params_type__=params_type, __signatures__=None)
def __repr__(self):
return "Params(%s)" % ", ".join(
[
("%s:%s:%s" % (k, type(self[k]).__name__, self[k]))
("{}:{}:{}".format(k, type(self[k]).__name__, self[k]))
for k in sorted(self.keys())
]
)
......@@ -270,14 +270,14 @@ class ParamsType(Type):
if enum_types:
# We don't want same enum names in different enum types.
if sum(len(t) for t in enum_types) != len(
set(k for t in enum_types for k in t)
{k for t in enum_types for k in t}
):
raise AttributeError(
"ParamsType: found different enum types with common constants names."
)
# We don't want same aliases in different enum types.
if sum(len(t.aliases) for t in enum_types) != len(
set(alias for t in enum_types for alias in t.aliases)
{alias for t in enum_types for alias in t.aliases}
):
raise AttributeError(
"ParamsType: found different enum types with common constants aliases."
......@@ -319,11 +319,14 @@ class ParamsType(Type):
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
if key in self.__const_to_enum:
return self.__const_to_enum[key][key]
return super(ParamsType, self).__getattr__(self, key)
return super().__getattr__(self, key)
def __repr__(self):
return "ParamsType<%s>" % ", ".join(
[("%s:%s" % (self.fields[i], self.types[i])) for i in range(self.length)]
[
("{}:{}".format(self.fields[i], self.types[i]))
for i in range(self.length)
]
)
def __eq__(self, other):
......@@ -345,7 +348,7 @@ class ParamsType(Type):
types_string = ",".join(str(t) for t in self.types).encode("utf-8")
fields_hex = hashlib.sha256(fields_string).hexdigest()
types_hex = hashlib.sha256(types_string).hexdigest()
return "_Params_%s_%s" % (fields_hex, types_hex)
return "_Params_{}_{}".format(fields_hex, types_hex)
def has_type(self, theano_type):
"""
......
......@@ -139,8 +139,8 @@ def _toposort(edges):
"""
incoming_edges = reverse_dict(edges)
incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items())
S = set((v for v in edges if v not in incoming_edges))
incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = {v for v in edges if v not in incoming_edges}
L = []
while S:
......@@ -189,8 +189,8 @@ def posort(nodes, *cmps):
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
"""
comes_before = dict((a, set()) for a in nodes)
comes_after = dict((a, set()) for a in nodes)
comes_before = {a: set() for a in nodes}
comes_after = {a: set() for a in nodes}
def add_links(a, b): # b depends on a
comes_after[a].add(b)
......
......@@ -21,8 +21,6 @@ class AlreadyThere(Exception):
"""
pass
class ReplacementDidntRemovedError(Exception):
"""
......@@ -32,8 +30,6 @@ class ReplacementDidntRemovedError(Exception):
"""
pass
class BadOptimization(Exception):
"""
......@@ -103,7 +99,7 @@ class BadOptimization(Exception):
old_graph=None,
new_graph=None,
):
super(BadOptimization, self).__init__()
super().__init__()
self.old_r = old_r
self.new_r = new_r
self.old_r_val = old_r_val
......@@ -139,7 +135,7 @@ class BadOptimization(Exception):
return self.full_err
sio = StringIO()
val_str_len_limit = 800
print("BadOptimization Error", super(BadOptimization, self).__str__(), file=sio)
print("BadOptimization Error", super().__str__(), file=sio)
print(" Variable: id", id(self.new_r), self.new_r, file=sio)
print(" Op", self.new_r.owner, file=sio)
print(" Value Type:", type(self.new_r_val), file=sio)
......@@ -225,7 +221,7 @@ class BadOptimization(Exception):
return sio.getvalue()
class Feature(object):
class Feature:
"""
Base class for FunctionGraph extensions.
......@@ -466,7 +462,9 @@ class Validator(Feature):
r = uf.f_locals.get("r", "")
reason = uf_info.function
print(
"validate failed on node %s.\n Reason: %s, %s" % (r, reason, e)
"validate failed on node {}.\n Reason: {}, {}".format(
r, reason, e
)
)
raise
t1 = time.time()
......@@ -578,7 +576,9 @@ class ReplaceValidate(History, Validator):
except Exception as e:
fgraph.revert(chk)
if verbose:
print("validate failed on node %s.\n Reason: %s, %s" % (r, reason, e))
print(
"validate failed on node {}.\n Reason: {}, {}".format(r, reason, e)
)
raise
if config.scan.debug:
scans2 = [
......@@ -731,15 +731,15 @@ class PrintListener(Feature):
def on_import(self, fgraph, node, reason):
if self.active:
print("-- importing: %s, reason: %s" % (node, reason))
print("-- importing: {}, reason: {}".format(node, reason))
def on_prune(self, fgraph, node, reason):
if self.active:
print("-- pruning: %s, reason: %s" % (node, reason))
print("-- pruning: {}, reason: {}".format(node, reason))
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
print("-- changing (%s.inputs[%s]) from %s to %s" % (node, i, r, new_r))
print("-- changing ({}.inputs[{}]) from {} to {}".format(node, i, r, new_r))
class PreserveNames(Feature):
......@@ -918,9 +918,9 @@ def is_same_graph(var1, var2, givens=None):
for to_replace, replace_by in givens.items():
# Map a substitution variable to the computational graphs it
# belongs to.
inside = dict(
(v, [in_var(v, k) for k in (1, 2)]) for v in (to_replace, replace_by)
)
inside = {
v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by)
}
if (
inside[to_replace][0]
and not inside[to_replace][1]
......
......@@ -10,8 +10,6 @@ import ctypes
import platform
import re
from six import string_types
import theano
from theano import change_flags
from theano.gof import graph, utils
......@@ -272,7 +270,7 @@ class CLinkerType(CLinkerObject):
return ()
class PureType(object):
class PureType:
"""
Interface specification for variable type instances.
......@@ -707,10 +705,10 @@ class CDataType(Type):
extra_support_code="",
version=None,
):
assert isinstance(ctype, string_types)
assert isinstance(ctype, str)
self.ctype = ctype
if freefunc is not None:
assert isinstance(freefunc, string_types)
assert isinstance(freefunc, str)
self.freefunc = freefunc
self.headers = tuple(headers)
self.header_dirs = tuple(header_dirs)
......@@ -848,7 +846,7 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
return v
def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.ctype)
return "{}{{{}}}".format(self.__class__.__name__, self.ctype)
def __setstate__(self, dct):
self.__dict__.update(dct)
......@@ -1034,7 +1032,7 @@ class EnumType(Type, dict):
raise TypeError(
"%s: some aliases have same names as constants." % type(self).__name__
)
super(EnumType, self).__init__(**kwargs)
super().__init__(**kwargs)
def fromalias(self, alias):
"""
......@@ -1060,11 +1058,11 @@ class EnumType(Type, dict):
names_to_aliases = {constant_name: "" for constant_name in self}
for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = "(%s)" % alias
return "%s<%s>(%s)" % (
return "{}<{}>({})".format(
type(self).__name__,
self.ctype,
", ".join(
"%s%s:%s" % (k, names_to_aliases[k], self[k])
"{}{}:{}".format(k, names_to_aliases[k], self[k])
for k in sorted(self.keys())
),
)
......@@ -1298,7 +1296,7 @@ class EnumList(EnumType):
kwargs.update(ctype=ctype)
if cname is not None:
kwargs.update(cname=cname)
super(EnumList, self).__init__(**kwargs)
super().__init__(**kwargs)
class CEnumType(EnumList):
......@@ -1336,7 +1334,7 @@ class CEnumType(EnumList):
return self.pyint_compat_code + self.c_to_string()
def c_extract(self, name, sub, check_input=True):
swapped_dict = dict((v, k) for (k, v) in self.items())
swapped_dict = {v: k for (k, v) in self.items()}
# swapped_dict's keys are integers.
return """
......@@ -1360,4 +1358,4 @@ class CEnumType(EnumList):
)
def c_code_cache_version(self):
return (1, super(CEnumType, self).c_code_cache_version())
return (1, super().c_code_cache_version())
......@@ -19,7 +19,7 @@ from theano.gof.utils import ANY_TYPE, FALL_THROUGH, comm_guard
################################
class Variable(object):
class Variable:
"""
Serves as a base class of variables for the purpose of unification.
"Unification" here basically means matching two patterns, see the
......@@ -46,7 +46,9 @@ class Variable(object):
return (
self.__class__.__name__
+ "("
+ ", ".join("%s=%s" % (key, value) for key, value in self.__dict__.items())
+ ", ".join(
"{}={}".format(key, value) for key, value in self.__dict__.items()
)
+ ")"
)
......@@ -60,8 +62,6 @@ class FreeVariable(Variable):
"""
pass
class BoundVariable(Variable):
"""
......@@ -70,7 +70,7 @@ class BoundVariable(Variable):
"""
def __init__(self, name, value):
super(BoundVariable, self).__init__(name=name)
super().__init__(name=name)
self.value = value
......@@ -82,7 +82,7 @@ class OrVariable(Variable):
"""
def __init__(self, name, options):
super(OrVariable, self).__init__(name=name)
super().__init__(name=name)
self.options = options
......@@ -94,7 +94,7 @@ class NotVariable(Variable):
"""
def __init__(self, name, not_options):
super(NotVariable, self).__init__(name=name)
super().__init__(name=name)
self.not_options = not_options
......
......@@ -3,7 +3,6 @@ import sys
import traceback
import numpy as np
from six import integer_types, string_types, with_metaclass
from six.moves import StringIO
from theano import config
......@@ -161,8 +160,6 @@ undef = object()
class TestValueError(Exception):
"""Base exception class for all test value errors."""
pass
class MethodNotDefined(Exception):
"""
......@@ -173,8 +170,6 @@ class MethodNotDefined(Exception):
"""
pass
class MetaObject(type):
def __new__(cls, name, bases, dct):
......@@ -182,7 +177,7 @@ class MetaObject(type):
if props is not None:
if not isinstance(props, tuple):
raise TypeError("__props__ has to be a tuple")
if not all(isinstance(p, string_types) for p in props):
if not all(isinstance(p, str) for p in props):
raise TypeError("elements of __props__ have to be strings")
def _props(self):
......@@ -201,7 +196,7 @@ class MetaObject(type):
least all the original props.
"""
return dict([(a, getattr(self, a)) for a in props])
return {a: getattr(self, a) for a in props}
dct["_props_dict"] = _props_dict
......@@ -225,14 +220,16 @@ class MetaObject(type):
if len(props) == 0:
def __str__(self):
return "%s" % (self.__class__.__name__,)
return "{}".format(self.__class__.__name__)
else:
def __str__(self):
return "%s{%s}" % (
return "{}{{{}}}".format(
self.__class__.__name__,
", ".join("%s=%r" % (p, getattr(self, p)) for p in props),
", ".join(
"{}={!r}".format(p, getattr(self, p)) for p in props
),
)
dct["__str__"] = __str__
......@@ -240,14 +237,14 @@ class MetaObject(type):
return type.__new__(cls, name, bases, dct)
class object2(with_metaclass(MetaObject, object)):
class object2(metaclass=MetaObject):
__slots__ = []
def __ne__(self, other):
return not self == other
class Scratchpad(object):
class Scratchpad:
def clear(self):
self.__dict__.clear()
......@@ -264,7 +261,7 @@ class Scratchpad(object):
def info(self):
print("<theano.gof.utils.scratchpad instance at %i>" % id(self))
for k, v in self.__dict__.items():
print(" %s: %s" % (k, v))
print(" {}: {}".format(k, v))
class ValidatingScratchpad(Scratchpad):
......@@ -330,7 +327,7 @@ def deprecated(filename, msg=""):
def g(*args, **kwargs):
if printme[0]:
print("WARNING: %s.%s deprecated. %s" % (filename, f.__name__, msg))
print("WARNING: {}.{} deprecated. {}".format(filename, f.__name__, msg))
printme[0] = False
return f(*args, **kwargs)
......@@ -404,7 +401,7 @@ def toposort(prereqs_d):
for x, prereqs in prereqs_d.items():
for prereq in prereqs:
postreqs_d.setdefault(prereq, set()).add(x)
next = set([k for k in prereqs_d if not prereqs_d[k]])
next = {k for k in prereqs_d if not prereqs_d[k]}
while next:
bases = next
next = set()
......@@ -449,7 +446,7 @@ RETRY = Keyword("RETRY", False)
FAILURE = Keyword("FAILURE", False)
simple_types = integer_types + string_types + (float, bool, None.__class__, Keyword)
simple_types = (int, str, float, bool, type(None), Keyword)
ANY_TYPE = Keyword("ANY_TYPE")
......
......@@ -30,8 +30,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for var in fgraph.variables:
viewed_by[var] = []
view_of = {}
pre_allocated = set([])
allocated = set([])
pre_allocated = set()
allocated = set()
for idx in range(len(order)):
node = order[idx]
......@@ -120,7 +120,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
return reallocated_info
class VM(object):
class VM:
"""
A VM object's __call__ method evaluates a Theano program.
......@@ -273,7 +273,7 @@ class LoopGC(VM):
"""
def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear):
super(LoopGC, self).__init__(nodes, thunks, pre_call_clear)
super().__init__(nodes, thunks, pre_call_clear)
self.post_thunk_clear = post_thunk_clear
# Some other part of Theano query that information
self.allow_gc = True
......@@ -353,7 +353,7 @@ class Stack(VM):
callback=None,
callback_input=None,
):
super(Stack, self).__init__(nodes, thunks, pre_call_clear)
super().__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc
self.message = ""
......@@ -712,7 +712,6 @@ except (OSError, theano.gof.cmodule.MissingGXX) as e:
assert not [x for x in _config_var_list if x.fullname == "linker"][
0
].default.startswith("cvm"), e
pass
class VM_Linker(link.LocalLinker):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论