提交 2e3f17cb authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to theano.gof

上级 beb93a44
...@@ -6,15 +6,15 @@ import six.moves.cPickle as pickle ...@@ -6,15 +6,15 @@ import six.moves.cPickle as pickle
_logger = logging.getLogger("theano.gof.callcache") _logger = logging.getLogger("theano.gof.callcache")
class CallCache(object): class CallCache:
def __init__(self, filename=None): def __init__(self, filename=None):
self.filename = filename self.filename = filename
try: try:
if filename is None: if filename is None:
raise IOError("bad filename") # just goes to except raise OSError("bad filename") # just goes to except
with open(filename, "r") as f: with open(filename) as f:
self.cache = pickle.load(f) self.cache = pickle.load(f)
except IOError: except OSError:
self.cache = {} self.cache = {}
def persist(self, filename=None): def persist(self, filename=None):
......
...@@ -8,7 +8,6 @@ import sys ...@@ -8,7 +8,6 @@ import sys
from copy import copy from copy import copy
import numpy as np import numpy as np
from six import reraise, string_types
from six.moves import StringIO from six.moves import StringIO
import theano import theano
...@@ -233,7 +232,7 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -233,7 +232,7 @@ def struct_gen(args, struct_builders, blocks, sub):
# declares the storage # declares the storage
storage_decl = "\n".join(["PyObject* %s;" % arg for arg in args]) storage_decl = "\n".join(["PyObject* %s;" % arg for arg in args])
# in the constructor, sets the storage to the arguments # in the constructor, sets the storage to the arguments
storage_set = "\n".join(["this->%s = %s;" % (arg, arg) for arg in args]) storage_set = "\n".join(["this->{} = {};".format(arg, arg) for arg in args])
# increments the storage's refcount in the constructor # increments the storage's refcount in the constructor
storage_incref = "\n".join(["Py_XINCREF(%s);" % arg for arg in args]) storage_incref = "\n".join(["Py_XINCREF(%s);" % arg for arg in args])
# decrements the storage's refcount in the destructor # decrements the storage's refcount in the destructor
...@@ -359,7 +358,7 @@ def get_c_declare(r, name, sub): ...@@ -359,7 +358,7 @@ def get_c_declare(r, name, sub):
[ [
getattr(c.op, "check_input", config.check_input) getattr(c.op, "check_input", config.check_input)
for (c, _) in r.clients for (c, _) in r.clients
if not isinstance(c, string_types) if not isinstance(c, str)
] ]
) or (r.owner and getattr(r.owner.op, "check_input", config.check_input)): ) or (r.owner and getattr(r.owner.op, "check_input", config.check_input)):
c_declare = r.type.c_declare(name, sub, True) c_declare = r.type.c_declare(name, sub, True)
...@@ -405,7 +404,7 @@ def get_c_extract(r, name, sub): ...@@ -405,7 +404,7 @@ def get_c_extract(r, name, sub):
[ [
getattr(c.op, "check_input", config.check_input) getattr(c.op, "check_input", config.check_input)
for (c, _) in r.clients for (c, _) in r.clients
if not isinstance(c, string_types) if not isinstance(c, str)
] ]
): ):
# check_broadcast is just an hack to easily remove just the # check_broadcast is just an hack to easily remove just the
...@@ -415,7 +414,7 @@ def get_c_extract(r, name, sub): ...@@ -415,7 +414,7 @@ def get_c_extract(r, name, sub):
[ [
getattr(c.op, "check_broadcast", True) getattr(c.op, "check_broadcast", True)
for (c, _) in r.clients for (c, _) in r.clients
if not isinstance(c, string_types) if not isinstance(c, str)
] ]
): ):
c_extract = r.type.c_extract(name, sub, True) c_extract = r.type.c_extract(name, sub, True)
...@@ -849,7 +848,7 @@ class CLinker(link.Linker): ...@@ -849,7 +848,7 @@ class CLinker(link.Linker):
pass pass
else: else:
# The following will be executed if the "try" block succeeds # The following will be executed if the "try" block succeeds
assert isinstance(c_support_code_apply[-1], string_types), ( assert isinstance(c_support_code_apply[-1], str), (
str(node.op) + " didn't return a string for c_support_code_apply" str(node.op) + " didn't return a string for c_support_code_apply"
) )
...@@ -858,13 +857,13 @@ class CLinker(link.Linker): ...@@ -858,13 +857,13 @@ class CLinker(link.Linker):
except utils.MethodNotDefined: except utils.MethodNotDefined:
pass pass
else: else:
assert isinstance(c_init_code_apply[-1], string_types), ( assert isinstance(c_init_code_apply[-1], str), (
str(node.op) + " didn't return a string for c_init_code_apply" str(node.op) + " didn't return a string for c_init_code_apply"
) )
try: try:
struct_init = op.c_init_code_struct(node, name, sub_struct) struct_init = op.c_init_code_struct(node, name, sub_struct)
assert isinstance(struct_init, string_types), ( assert isinstance(struct_init, str), (
str(node.op) + " didn't return a string for c_init_code_struct" str(node.op) + " didn't return a string for c_init_code_struct"
) )
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -872,7 +871,7 @@ class CLinker(link.Linker): ...@@ -872,7 +871,7 @@ class CLinker(link.Linker):
try: try:
struct_support = op.c_support_code_struct(node, name) struct_support = op.c_support_code_struct(node, name)
assert isinstance(struct_support, string_types), ( assert isinstance(struct_support, str), (
str(node.op) + " didn't return a string for c_support_code_struct" str(node.op) + " didn't return a string for c_support_code_struct"
) )
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -880,7 +879,7 @@ class CLinker(link.Linker): ...@@ -880,7 +879,7 @@ class CLinker(link.Linker):
try: try:
struct_cleanup = op.c_cleanup_code_struct(node, name) struct_cleanup = op.c_cleanup_code_struct(node, name)
assert isinstance(struct_cleanup, string_types), ( assert isinstance(struct_cleanup, str), (
str(node.op) + " didn't return a string for c_cleanup_code_struct" str(node.op) + " didn't return a string for c_cleanup_code_struct"
) )
except utils.MethodNotDefined: except utils.MethodNotDefined:
...@@ -891,7 +890,7 @@ class CLinker(link.Linker): ...@@ -891,7 +890,7 @@ class CLinker(link.Linker):
behavior = op.c_code(node, name, isyms, osyms, sub) behavior = op.c_code(node, name, isyms, osyms, sub)
except utils.MethodNotDefined: except utils.MethodNotDefined:
raise NotImplementedError("%s cannot produce C code" % op) raise NotImplementedError("%s cannot produce C code" % op)
assert isinstance(behavior, string_types), ( assert isinstance(behavior, str), (
str(node.op) + " didn't return a string for c_code" str(node.op) + " didn't return a string for c_code"
) )
# To help understand what is following. It help read the c code. # To help understand what is following. It help read the c code.
...@@ -1429,7 +1428,7 @@ class CLinker(link.Linker): ...@@ -1429,7 +1428,7 @@ class CLinker(link.Linker):
# set of variables that have been computed by nodes we have # set of variables that have been computed by nodes we have
# seen 'so far' in the loop below # seen 'so far' in the loop below
fgraph_computed_set = set() fgraph_computed_set = set()
fgraph_inputs_dict = dict((i, (-1, pos)) for pos, i in enumerate(fgraph.inputs)) fgraph_inputs_dict = {i: (-1, pos) for pos, i in enumerate(fgraph.inputs)}
constant_ids = dict() constant_ids = dict()
op_pos = {} # Apply -> topological position op_pos = {} # Apply -> topological position
...@@ -1794,7 +1793,7 @@ class CLinker(link.Linker): ...@@ -1794,7 +1793,7 @@ class CLinker(link.Linker):
return code.getvalue() return code.getvalue()
class _CThunk(object): class _CThunk:
""" """
A thunk with a C implementation. A thunk with a C implementation.
...@@ -1864,7 +1863,7 @@ class _CThunk(object): ...@@ -1864,7 +1863,7 @@ class _CThunk(object):
) )
print(self.error_storage, file=sys.stderr) print(self.error_storage, file=sys.stderr)
raise raise
reraise(exc_type, exc_value, exc_trace) raise exc_value.with_traceback(exc_trace)
class OpWiseCLinker(link.LocalLinker): class OpWiseCLinker(link.LocalLinker):
...@@ -2130,7 +2129,7 @@ class DualLinker(link.Linker): ...@@ -2130,7 +2129,7 @@ class DualLinker(link.Linker):
return f, i1, o1 return f, i1, o1
class HideC(object): class HideC:
def __hide(*args): def __hide(*args):
raise utils.MethodNotDefined() raise utils.MethodNotDefined()
......
...@@ -19,7 +19,7 @@ import warnings ...@@ -19,7 +19,7 @@ import warnings
import numpy.distutils import numpy.distutils
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
from six import BytesIO, StringIO, b, string_types from six import BytesIO, StringIO, b
import theano import theano
from theano import config from theano import config
...@@ -53,8 +53,6 @@ class MissingGXX(Exception): ...@@ -53,8 +53,6 @@ class MissingGXX(Exception):
""" """
pass
def debug_counter(name, every=1): def debug_counter(name, every=1):
""" """
...@@ -70,10 +68,10 @@ def debug_counter(name, every=1): ...@@ -70,10 +68,10 @@ def debug_counter(name, every=1):
setattr(debug_counter, name, getattr(debug_counter, name, 0) + 1) setattr(debug_counter, name, getattr(debug_counter, name, 0) + 1)
n = getattr(debug_counter, name) n = getattr(debug_counter, name)
if n % every == 0: if n % every == 0:
print("debug_counter [%s]: %s" % (name, n), file=sys.stderr) print("debug_counter [{}]: {}".format(name, n), file=sys.stderr)
class ExtFunction(object): class ExtFunction:
""" """
A C function to put into a DynamicModule. A C function to put into a DynamicModule.
...@@ -118,10 +116,12 @@ class ExtFunction(object): ...@@ -118,10 +116,12 @@ class ExtFunction(object):
It goes into the DynamicModule's method table. It goes into the DynamicModule's method table.
""" """
return '\t{"%s", %s, %s, "%s"}' % (self.name, self.name, self.method, self.doc) return '\t{{"{}", {}, {}, "{}"}}'.format(
self.name, self.name, self.method, self.doc
)
class DynamicModule(object): class DynamicModule:
def __init__(self, name=None): def __init__(self, name=None):
assert name is None, ( assert name is None, (
"The 'name' parameter of DynamicModule" "The 'name' parameter of DynamicModule"
...@@ -436,7 +436,7 @@ def get_module_hash(src_code, key): ...@@ -436,7 +436,7 @@ def get_module_hash(src_code, key):
# This should be the C++ compilation command line parameters or the # This should be the C++ compilation command line parameters or the
# libraries to link against. # libraries to link against.
to_hash += list(key_element) to_hash += list(key_element)
elif isinstance(key_element, string_types): elif isinstance(key_element, str):
if key_element.startswith("md5:") or key_element.startswith("hash:"): if key_element.startswith("md5:") or key_element.startswith("hash:"):
# This is actually a sha256 hash of the config options. # This is actually a sha256 hash of the config options.
# Currently, we still keep md5 to don't break old Theano. # Currently, we still keep md5 to don't break old Theano.
...@@ -481,7 +481,7 @@ def get_safe_part(key): ...@@ -481,7 +481,7 @@ def get_safe_part(key):
# rest of the cache mechanism will just skip that key. # rest of the cache mechanism will just skip that key.
hash = None hash = None
for key_element in c_link_key[1:]: for key_element in c_link_key[1:]:
if isinstance(key_element, string_types): if isinstance(key_element, str):
if key_element.startswith("md5:"): if key_element.startswith("md5:"):
hash = key_element[4:] hash = key_element[4:]
break break
...@@ -492,7 +492,7 @@ def get_safe_part(key): ...@@ -492,7 +492,7 @@ def get_safe_part(key):
return key[0] + (hash,) return key[0] + (hash,)
class KeyData(object): class KeyData:
""" """
Used to store the key information in the cache. Used to store the key information in the cache.
...@@ -594,7 +594,7 @@ class KeyData(object): ...@@ -594,7 +594,7 @@ class KeyData(object):
pass pass
class ModuleCache(object): class ModuleCache:
""" """
Interface to the cache of dynamically compiled modules on disk. Interface to the cache of dynamically compiled modules on disk.
...@@ -1011,7 +1011,7 @@ class ModuleCache(object): ...@@ -1011,7 +1011,7 @@ class ModuleCache(object):
# Test to see that the file is [present and] readable. # Test to see that the file is [present and] readable.
open(entry).close() open(entry).close()
gone = False gone = False
except IOError: except OSError:
gone = True gone = True
if gone: if gone:
...@@ -1140,7 +1140,7 @@ class ModuleCache(object): ...@@ -1140,7 +1140,7 @@ class ModuleCache(object):
key_pkl = os.path.join(location, "key.pkl") key_pkl = os.path.join(location, "key.pkl")
assert not os.path.exists(key_pkl) assert not os.path.exists(key_pkl)
key_data = KeyData( key_data = KeyData(
keys=set([key]), module_hash=module_hash, key_pkl=key_pkl, entry=name keys={key}, module_hash=module_hash, key_pkl=key_pkl, entry=name
) )
key_broken = False key_broken = False
...@@ -1518,7 +1518,7 @@ class ModuleCache(object): ...@@ -1518,7 +1518,7 @@ class ModuleCache(object):
fname = os.path.join(self.dirname, filename, "key.pkl") fname = os.path.join(self.dirname, filename, "key.pkl")
open(fname).close() open(fname).close()
has_key = True has_key = True
except IOError: except OSError:
has_key = False has_key = False
if not has_key: if not has_key:
# Use the compiled file by default # Use the compiled file by default
...@@ -1724,12 +1724,10 @@ def std_lib_dirs_and_libs(): ...@@ -1724,12 +1724,10 @@ def std_lib_dirs_and_libs():
for f, lib in [("libpython27.a", "libpython 1.2")]: for f, lib in [("libpython27.a", "libpython 1.2")]:
if not os.path.exists(os.path.join(libdir, f)): if not os.path.exists(os.path.join(libdir, f)):
print( print(
( "Your Python version is from Canopy. "
"Your Python version is from Canopy. " + "You need to install the package '"
+ "You need to install the package '" + lib
+ lib + "' from Canopy package manager."
+ "' from Canopy package manager."
)
) )
libdirs = [ libdirs = [
# Used in older Canopy # Used in older Canopy
...@@ -1747,12 +1745,10 @@ def std_lib_dirs_and_libs(): ...@@ -1747,12 +1745,10 @@ def std_lib_dirs_and_libs():
] ]
): ):
print( print(
( "Your Python version is from Canopy. "
"Your Python version is from Canopy. " + "You need to install the package '"
+ "You need to install the package '" + lib
+ lib + "' from Canopy package manager."
+ "' from Canopy package manager."
)
) )
python_lib_dirs.insert(0, libdir) python_lib_dirs.insert(0, libdir)
std_lib_dirs_and_libs.data = [libname], python_lib_dirs std_lib_dirs_and_libs.data = [libname], python_lib_dirs
...@@ -1833,15 +1829,15 @@ def gcc_llvm(): ...@@ -1833,15 +1829,15 @@ def gcc_llvm():
# Normally this should not happen as we should not try to # Normally this should not happen as we should not try to
# compile when g++ is not available. If this happen, it # compile when g++ is not available. If this happen, it
# will crash later so supposing it is not llvm is "safe". # will crash later so supposing it is not llvm is "safe".
output = b("") output = b""
gcc_llvm.is_llvm = b("llvm") in output gcc_llvm.is_llvm = b"llvm" in output
return gcc_llvm.is_llvm return gcc_llvm.is_llvm
gcc_llvm.is_llvm = None gcc_llvm.is_llvm = None
class Compiler(object): class Compiler:
""" """
Meta compiler that offer some generic function. Meta compiler that offer some generic function.
...@@ -2077,7 +2073,7 @@ class GCC_compiler(Compiler): ...@@ -2077,7 +2073,7 @@ class GCC_compiler(Compiler):
# as stdin (which is the default) results in the process # as stdin (which is the default) results in the process
# waiting forever without returning. For that reason, # waiting forever without returning. For that reason,
# we use a pipe, and use the empty string as input. # we use a pipe, and use the empty string as input.
(stdout, stderr) = p.communicate(input=b("")) (stdout, stderr) = p.communicate(input=b"")
if p.returncode != 0: if p.returncode != 0:
return None return None
...@@ -2355,7 +2351,7 @@ class GCC_compiler(Compiler): ...@@ -2355,7 +2351,7 @@ class GCC_compiler(Compiler):
line.startswith("#define hypot _hypot") for line in config_h line.startswith("#define hypot _hypot") for line in config_h
): ):
cxxflags.append("-D_hypot=hypot") cxxflags.append("-D_hypot=hypot")
except IOError: except OSError:
pass pass
return cxxflags return cxxflags
...@@ -2472,9 +2468,9 @@ class GCC_compiler(Compiler): ...@@ -2472,9 +2468,9 @@ class GCC_compiler(Compiler):
if dist_suffix is not None and dist_suffix != "": if dist_suffix is not None and dist_suffix != "":
suffix = dist_suffix suffix = dist_suffix
filepath = "%s%s" % (module_name, suffix) filepath = "{}{}".format(module_name, suffix)
else: else:
filepath = "%s.%s" % (module_name, get_lib_extension()) filepath = "{}.{}".format(module_name, get_lib_extension())
lib_filename = os.path.join(location, filepath) lib_filename = os.path.join(location, filepath)
...@@ -2488,10 +2484,13 @@ class GCC_compiler(Compiler): ...@@ -2488,10 +2484,13 @@ class GCC_compiler(Compiler):
# to support path that includes spaces, we need to wrap it with double quotes on Windows # to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper = '"' if os.name == "nt" else "" path_wrapper = '"' if os.name == "nt" else ""
cmd.extend( cmd.extend(
["-I%s%s%s" % (path_wrapper, idir, path_wrapper) for idir in include_dirs] [
"-I{}{}{}".format(path_wrapper, idir, path_wrapper)
for idir in include_dirs
]
) )
cmd.extend( cmd.extend(
["-L%s%s%s" % (path_wrapper, ldir, path_wrapper) for ldir in lib_dirs] ["-L{}{}{}".format(path_wrapper, ldir, path_wrapper) for ldir in lib_dirs]
) )
if hide_symbols and sys.platform != "win32": if hide_symbols and sys.platform != "win32":
# This has been available since gcc 4.0 so we suppose it # This has been available since gcc 4.0 so we suppose it
...@@ -2501,8 +2500,8 @@ class GCC_compiler(Compiler): ...@@ -2501,8 +2500,8 @@ class GCC_compiler(Compiler):
# improved loading times on most platforms (win32 is # improved loading times on most platforms (win32 is
# different, as usual). # different, as usual).
cmd.append("-fvisibility=hidden") cmd.append("-fvisibility=hidden")
cmd.extend(["-o", "%s%s%s" % (path_wrapper, lib_filename, path_wrapper)]) cmd.extend(["-o", "{}{}{}".format(path_wrapper, lib_filename, path_wrapper)])
cmd.append("%s%s%s" % (path_wrapper, cppfilename, path_wrapper)) cmd.append("{}{}{}".format(path_wrapper, cppfilename, path_wrapper))
cmd.extend(["-l%s" % l for l in libs]) cmd.extend(["-l%s" % l for l in libs])
# print >> sys.stderr, 'COMPILING W CMD', cmd # print >> sys.stderr, 'COMPILING W CMD', cmd
_logger.debug("Running cmd: %s", " ".join(cmd)) _logger.debug("Running cmd: %s", " ".join(cmd))
......
...@@ -4,7 +4,6 @@ import shutil ...@@ -4,7 +4,6 @@ import shutil
import numpy as np import numpy as np
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
from six import string_types
import theano import theano
from theano import config from theano import config
...@@ -46,7 +45,7 @@ def cleanup(): ...@@ -46,7 +45,7 @@ def cleanup():
# force the removing of key # force the removing of key
have_npy_abi_version = False have_npy_abi_version = False
break break
elif isinstance(obj, string_types): elif isinstance(obj, str):
if obj.startswith("NPY_ABI_VERSION=0x"): if obj.startswith("NPY_ABI_VERSION=0x"):
have_npy_abi_version = True have_npy_abi_version = True
elif obj.startswith("c_compiler_str="): elif obj.startswith("c_compiler_str="):
...@@ -67,7 +66,7 @@ def cleanup(): ...@@ -67,7 +66,7 @@ def cleanup():
if keydata.key_pkl != filename: if keydata.key_pkl != filename:
keydata.key_pkl = filename keydata.key_pkl = filename
keydata.remove_key(key) keydata.remove_key(key)
except IOError: except OSError:
_logger.error( _logger.error(
"Could not remove file '%s'. To complete " "Could not remove file '%s'. To complete "
"the clean-up, please remove manually " "the clean-up, please remove manually "
...@@ -84,7 +83,7 @@ def cleanup(): ...@@ -84,7 +83,7 @@ def cleanup():
"the directory containing it.", "the directory containing it.",
filename, filename,
) )
except IOError: except OSError:
_logger.error( _logger.error(
"Could not clean up this directory: '%s'. To complete " "Could not clean up this directory: '%s'. To complete "
"the clean-up, please remove it manually.", "the clean-up, please remove it manually.",
...@@ -126,29 +125,21 @@ def print_compiledir_content(): ...@@ -126,29 +125,21 @@ def print_compiledir_content():
try: try:
keydata = pickle.load(file) keydata = pickle.load(file)
ops = list( ops = list(
set( {x for x in flatten(keydata.keys) if isinstance(x, theano.gof.Op)}
[
x
for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Op)
]
)
) )
# Whatever the case, we count compilations for OP classes. # Whatever the case, we count compilations for OP classes.
for op_class in set([op.__class__ for op in ops]): for op_class in {op.__class__ for op in ops}:
table_op_class.setdefault(op_class, 0) table_op_class.setdefault(op_class, 0)
table_op_class[op_class] += 1 table_op_class[op_class] += 1
if len(ops) == 0: if len(ops) == 0:
zeros_op += 1 zeros_op += 1
else: else:
types = list( types = list(
set( {
[ x
x for x in flatten(keydata.keys)
for x in flatten(keydata.keys) if isinstance(x, theano.gof.Type)
if isinstance(x, theano.gof.Type) }
]
)
) )
compile_start = compile_end = float("nan") compile_start = compile_end = float("nan")
for fn in os.listdir(os.path.join(compiledir, dir)): for fn in os.listdir(os.path.join(compiledir, dir)):
...@@ -177,7 +168,7 @@ def print_compiledir_content(): ...@@ -177,7 +168,7 @@ def print_compiledir_content():
nb_keys.setdefault(len(keydata.keys), 0) nb_keys.setdefault(len(keydata.keys), 0)
nb_keys[len(keydata.keys)] += 1 nb_keys[len(keydata.keys)] += 1
except IOError: except OSError:
pass pass
except AttributeError: except AttributeError:
_logger.error("Could not read key file '%s'.", filename) _logger.error("Could not read key file '%s'.", filename)
...@@ -221,16 +212,12 @@ def print_compiledir_content(): ...@@ -221,16 +212,12 @@ def print_compiledir_content():
big_key_files = sorted(big_key_files, key=lambda t: str(t[1])) big_key_files = sorted(big_key_files, key=lambda t: str(t[1]))
big_total_size = sum([sz for _, sz, _ in big_key_files]) big_total_size = sum([sz for _, sz, _ in big_key_files])
print( print(
( "There are directories with key files bigger than %d bytes "
"There are directories with key files bigger than %d bytes " "(they probably contain big tensor constants)" % max_key_file_size
"(they probably contain big tensor constants)" % max_key_file_size
)
) )
print( print(
( "They use %d bytes out of %d (total size used by all key files)"
"They use %d bytes out of %d (total size used by all key files)" "" % (big_total_size, total_key_sizes)
"" % (big_total_size, total_key_sizes)
)
) )
for dir, size, ops in big_key_files: for dir, size, ops in big_key_files:
...@@ -246,10 +233,8 @@ def print_compiledir_content(): ...@@ -246,10 +233,8 @@ def print_compiledir_content():
print(n_k, n_m) print(n_k, n_m)
print() print()
print( print(
( "Skipped %d files that contained 0 op "
"Skipped %d files that contained 0 op " "(are they always theano.scalar ops?)" % zeros_op
"(are they always theano.scalar ops?)" % zeros_op
)
) )
......
...@@ -10,7 +10,6 @@ import time ...@@ -10,7 +10,6 @@ import time
from contextlib import contextmanager from contextlib import contextmanager
import numpy as np import numpy as np
from six import PY3
from theano import config from theano import config
...@@ -283,14 +282,9 @@ def lock(tmp_dir, timeout=notset, min_wait=None, max_wait=None, verbosity=1): ...@@ -283,14 +282,9 @@ def lock(tmp_dir, timeout=notset, min_wait=None, max_wait=None, verbosity=1):
nb_wait += 1 nb_wait += 1
time.sleep(random.uniform(min_wait, max_wait)) time.sleep(random.uniform(min_wait, max_wait))
if PY3:
exception = FileExistsError # noqa
else:
exception = OSError
try: try:
os.mkdir(tmp_dir) os.mkdir(tmp_dir)
except exception: except FileExistsError:
# Error while creating the directory: someone else # Error while creating the directory: someone else
# must have tried at the exact same time. # must have tried at the exact same time.
nb_error += 1 nb_error += 1
...@@ -332,7 +326,7 @@ def refresh_lock(lock_file): ...@@ -332,7 +326,7 @@ def refresh_lock(lock_file):
unique id, using a new (randomly generated) id, which is also returned. unique id, using a new (randomly generated) id, which is also returned.
""" """
unique_id = "%s_%s_%s" % ( unique_id = "{}_{}_{}".format(
os.getpid(), os.getpid(),
"".join([str(random.randint(0, 9)) for i in range(10)]), "".join([str(random.randint(0, 9)) for i in range(10)]),
hostname, hostname,
...@@ -355,7 +349,7 @@ def refresh_lock(lock_file): ...@@ -355,7 +349,7 @@ def refresh_lock(lock_file):
return unique_id return unique_id
class Unlocker(object): class Unlocker:
""" """
Class wrapper around release mechanism so that the lock is automatically Class wrapper around release mechanism so that the lock is automatically
released when the program exits (even when crashing or being interrupted), released when the program exits (even when crashing or being interrupted),
......
...@@ -22,8 +22,6 @@ class ProtocolError(Exception): ...@@ -22,8 +22,6 @@ class ProtocolError(Exception):
""" """
pass
def _contains_cycle(fgraph, orderings): def _contains_cycle(fgraph, orderings):
""" """
...@@ -762,17 +760,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa ...@@ -762,17 +760,17 @@ class DestroyHandler(toolbox.Bookkeeper): # noqa
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", []) tolerate_same = getattr(app.op, "destroyhandler_tolerate_same", [])
assert isinstance(tolerate_same, list) assert isinstance(tolerate_same, list)
tolerated = set( tolerated = {
idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx idx1 for idx0, idx1 in tolerate_same if idx0 == destroyed_idx
) }
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr( tolerate_aliased = getattr(
app.op, "destroyhandler_tolerate_aliased", [] app.op, "destroyhandler_tolerate_aliased", []
) )
assert isinstance(tolerate_aliased, list) assert isinstance(tolerate_aliased, list)
ignored = set( ignored = {
idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx idx1 for idx0, idx1 in tolerate_aliased if idx0 == destroyed_idx
) }
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
if i in ignored: if i in ignored:
continue continue
......
...@@ -26,8 +26,6 @@ class InconsistencyError(Exception): ...@@ -26,8 +26,6 @@ class InconsistencyError(Exception):
""" """
pass
class MissingInputError(Exception): class MissingInputError(Exception):
""" """
......
...@@ -7,8 +7,6 @@ from collections import deque ...@@ -7,8 +7,6 @@ from collections import deque
from copy import copy from copy import copy
from itertools import count from itertools import count
from six import integer_types, string_types
import theano import theano
from theano import config from theano import config
from theano.gof.utils import ( from theano.gof.utils import (
...@@ -174,7 +172,7 @@ class Apply(Node): ...@@ -174,7 +172,7 @@ class Apply(Node):
raise ValueError( raise ValueError(
"%s.default_output should be an output index." % self.op "%s.default_output should be an output index." % self.op
) )
elif not isinstance(do, integer_types): elif not isinstance(do, int):
raise ValueError("%s.default_output should be an int or long" % self.op) raise ValueError("%s.default_output should be an int or long" % self.op)
elif do < 0 or do >= len(self.outputs): elif do < 0 or do >= len(self.outputs):
raise ValueError("%s.default_output is out of range." % self.op) raise ValueError("%s.default_output is out of range." % self.op)
...@@ -395,11 +393,11 @@ class Variable(Node): ...@@ -395,11 +393,11 @@ class Variable(Node):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance", owner)
self.owner = owner self.owner = owner
if index is not None and not isinstance(index, integer_types): if index is not None and not isinstance(index, int):
raise TypeError("index must be an int", index) raise TypeError("index must be an int", index)
self.index = index self.index = index
if name is not None and not isinstance(name, string_types): if name is not None and not isinstance(name, str):
raise TypeError("name must be a string", name) raise TypeError("name must be a string", name)
self.name = name self.name = name
...@@ -1156,7 +1154,7 @@ default_leaf_formatter = str ...@@ -1156,7 +1154,7 @@ default_leaf_formatter = str
def default_node_formatter(op, argstrings): def default_node_formatter(op, argstrings):
return "%s(%s)" % (op.op, ", ".join(argstrings)) return "{}({})".format(op.op, ", ".join(argstrings))
def io_connection_pattern(inputs, outputs): def io_connection_pattern(inputs, outputs):
...@@ -1331,7 +1329,7 @@ def view_roots(r): ...@@ -1331,7 +1329,7 @@ def view_roots(r):
if owner is not None: if owner is not None:
try: try:
view_map = owner.op.view_map view_map = owner.op.view_map
view_map = dict((owner.outputs[o], i) for o, i in view_map.items()) view_map = {owner.outputs[o]: i for o, i in view_map.items()}
except AttributeError: except AttributeError:
return [r] return [r]
if r in view_map: if r in view_map:
......
...@@ -59,11 +59,11 @@ try: ...@@ -59,11 +59,11 @@ try:
if not os.path.exists(init_file): if not os.path.exists(init_file):
try: try:
open(init_file, "w").close() open(init_file, "w").close()
except IOError as e: except OSError as e:
if os.path.exists(init_file): if os.path.exists(init_file):
pass # has already been created pass # has already been created
else: else:
e.args += ("%s exist? %s" % (location, os.path.exists(location)),) e.args += ("{} exist? {}".format(location, os.path.exists(location)),)
raise raise
_need_reload = False _need_reload = False
......
...@@ -4,7 +4,6 @@ from copy import copy, deepcopy ...@@ -4,7 +4,6 @@ from copy import copy, deepcopy
from sys import getsizeof from sys import getsizeof
import numpy as np import numpy as np
from six import reraise
from six.moves import StringIO from six.moves import StringIO
import theano import theano
...@@ -123,7 +122,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -123,7 +122,7 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
exc_type, exc_value, exc_trace = exc_info exc_type, exc_value, exc_trace = exc_info
if exc_type == KeyboardInterrupt: if exc_type == KeyboardInterrupt:
# print a simple traceback from KeyboardInterrupt # print a simple traceback from KeyboardInterrupt
reraise(exc_type, exc_value, exc_trace) raise exc_value.with_traceback(exc_trace)
try: try:
trace = node.outputs[0].tag.trace trace = node.outputs[0].tag.trace
except AttributeError: except AttributeError:
...@@ -315,11 +314,11 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -315,11 +314,11 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
detailed_err_msg += ", TotalSize: %s Byte(s)\n" % item[3] detailed_err_msg += ", TotalSize: %s Byte(s)\n" % item[3]
else: else:
detailed_err_msg += "\n" detailed_err_msg += "\n"
detailed_err_msg += " TotalSize: %s Byte(s) %.3f GB\n" % ( detailed_err_msg += " TotalSize: {} Byte(s) {:.3f} GB\n".format(
total_size, total_size,
total_size / 1024.0 / 1024 / 1024, total_size / 1024.0 / 1024 / 1024,
) )
detailed_err_msg += " TotalSize inputs: %s Byte(s) %.3f GB\n" % ( detailed_err_msg += " TotalSize inputs: {} Byte(s) {:.3f} GB\n".format(
total_size_inputs, total_size_inputs,
total_size_inputs / 1024.0 / 1024 / 1024, total_size_inputs / 1024.0 / 1024 / 1024,
) )
...@@ -341,11 +340,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None): ...@@ -341,11 +340,10 @@ def raise_with_op(node, thunk=None, exc_info=None, storage_map=None):
) )
# Some exception need extra parameter in inputs. So forget the # Some exception need extra parameter in inputs. So forget the
# extra long error message in that case. # extra long error message in that case.
pass raise exc_value.with_traceback(exc_trace)
reraise(exc_type, exc_value, exc_trace)
class Linker(object): class Linker:
""" """
WRITEME WRITEME
...@@ -434,7 +432,7 @@ class Linker(object): ...@@ -434,7 +432,7 @@ class Linker(object):
# TODO: Move this class to the compile module, where it is used (and for which it exists). # TODO: Move this class to the compile module, where it is used (and for which it exists).
class Container(object): class Container:
""" """
This class joins a variable with its computed value. This class joins a variable with its computed value.
......
...@@ -122,7 +122,7 @@ def compute_test_value(node): ...@@ -122,7 +122,7 @@ def compute_test_value(node):
output.tag.test_value = storage_map[output][0] output.tag.test_value = storage_map[output][0]
class CLinkerObject(object): class CLinkerObject:
""" """
Standard elements of an Op or Type used with the CLinker. Standard elements of an Op or Type used with the CLinker.
...@@ -550,7 +550,7 @@ class CLinkerOp(CLinkerObject): ...@@ -550,7 +550,7 @@ class CLinkerOp(CLinkerObject):
) )
class PureOp(object): class PureOp:
"""A class that models and constructs operations in a graph. """A class that models and constructs operations in a graph.
A `PureOp` instance has several responsibilities: A `PureOp` instance has several responsibilities:
...@@ -842,7 +842,6 @@ class Op(object2, PureOp, CLinkerOp): ...@@ -842,7 +842,6 @@ class Op(object2, PureOp, CLinkerOp):
good to do so. good to do so.
""" """
pass
def make_c_thunk(self, node, storage_map, compute_map, no_recycling): def make_c_thunk(self, node, storage_map, compute_map, no_recycling):
"""Like make_thunk, but will only try to make a C thunk.""" """Like make_thunk, but will only try to make a C thunk."""
...@@ -1263,19 +1262,17 @@ class COp(Op): ...@@ -1263,19 +1262,17 @@ class COp(Op):
section_re = re.compile(r"^#section ([a-zA-Z0-9_]+)$", re.MULTILINE) section_re = re.compile(r"^#section ([a-zA-Z0-9_]+)$", re.MULTILINE)
backward_re = re.compile(r"^THEANO_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE) backward_re = re.compile(r"^THEANO_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE)
# This is the set of allowed markers # This is the set of allowed markers
SECTIONS = set( SECTIONS = {
[ "init_code",
"init_code", "init_code_apply",
"init_code_apply", "init_code_struct",
"init_code_struct", "support_code",
"support_code", "support_code_apply",
"support_code_apply", "support_code_struct",
"support_code_struct", "cleanup_code_struct",
"cleanup_code_struct", "code",
"code", "code_cleanup",
"code_cleanup", }
]
)
@classmethod @classmethod
def get_path(cls, f): def get_path(cls, f):
...@@ -1535,10 +1532,10 @@ class COp(Op): ...@@ -1535,10 +1532,10 @@ class COp(Op):
def get_sub_macros(self, sub): def get_sub_macros(self, sub):
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
define_macros.append("#define FAIL %s" % (self._lquote_macro(sub["fail"]),)) define_macros.append("#define FAIL {}".format(self._lquote_macro(sub["fail"])))
undef_macros.append("#undef FAIL") undef_macros.append("#undef FAIL")
if "params" in sub: if "params" in sub:
define_macros.append("#define PARAMS %s" % (sub["params"],)) define_macros.append("#define PARAMS {}".format(sub["params"]))
undef_macros.append("#undef PARAMS") undef_macros.append("#undef PARAMS")
return "\n".join(define_macros), "\n".join(undef_macros) return "\n".join(define_macros), "\n".join(undef_macros)
...@@ -1584,7 +1581,7 @@ class COp(Op): ...@@ -1584,7 +1581,7 @@ class COp(Op):
params = "" params = ""
if "params" in sub: if "params" in sub:
params = ", %s" % (sub["params"],) params = ", {}".format(sub["params"])
# Generate the C code # Generate the C code
return """ return """
......
差异被折叠。
...@@ -2,7 +2,7 @@ import copy ...@@ -2,7 +2,7 @@ import copy
import math import math
import sys import sys
from six import StringIO, integer_types from six import StringIO
from theano import config from theano import config
from theano.compat import DefaultOrderedDict from theano.compat import DefaultOrderedDict
...@@ -10,7 +10,7 @@ from theano.gof import opt ...@@ -10,7 +10,7 @@ from theano.gof import opt
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
class DB(object): class DB:
def __hash__(self): def __hash__(self):
if not hasattr(self, "_optimizer_idx"): if not hasattr(self, "_optimizer_idx"):
self._optimizer_idx = opt._optimizer_idx[0] self._optimizer_idx = opt._optimizer_idx[0]
...@@ -169,7 +169,7 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s". ...@@ -169,7 +169,7 @@ multiple time in a DB. Tryed to register "%s" again under the new name "%s".
print(" db", self.__db__, file=stream) print(" db", self.__db__, file=stream)
class Query(object): class Query:
""" """
Parameters Parameters
...@@ -296,7 +296,7 @@ class EquilibriumDB(DB): ...@@ -296,7 +296,7 @@ class EquilibriumDB(DB):
""" """
def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False): def __init__(self, ignore_newtrees=True, tracks_on_change_inputs=False):
super(EquilibriumDB, self).__init__() super().__init__()
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
self.tracks_on_change_inputs = tracks_on_change_inputs self.tracks_on_change_inputs = tracks_on_change_inputs
self.__final__ = {} self.__final__ = {}
...@@ -307,12 +307,12 @@ class EquilibriumDB(DB): ...@@ -307,12 +307,12 @@ class EquilibriumDB(DB):
cleanup = kwtags.pop("cleanup", False) cleanup = kwtags.pop("cleanup", False)
# An opt should not be final and clean up # An opt should not be final and clean up
assert not (final_opt and cleanup) assert not (final_opt and cleanup)
super(EquilibriumDB, self).register(name, obj, *tags, **kwtags) super().register(name, obj, *tags, **kwtags)
self.__final__[name] = final_opt self.__final__[name] = final_opt
self.__cleanup__[name] = cleanup self.__cleanup__[name] = cleanup
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
_opts = super(EquilibriumDB, self).query(*tags, **kwtags) _opts = super().query(*tags, **kwtags)
final_opts = [o for o in _opts if self.__final__.get(o.name, False)] final_opts = [o for o in _opts if self.__final__.get(o.name, False)]
cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name, False)] cleanup_opts = [o for o in _opts if self.__cleanup__.get(o.name, False)]
opts = [o for o in _opts if o not in final_opts and o not in cleanup_opts] opts = [o for o in _opts if o not in final_opts and o not in cleanup_opts]
...@@ -349,19 +349,19 @@ class SequenceDB(DB): ...@@ -349,19 +349,19 @@ class SequenceDB(DB):
seq_opt = opt.SeqOptimizer seq_opt = opt.SeqOptimizer
def __init__(self, failure_callback=opt.SeqOptimizer.warn): def __init__(self, failure_callback=opt.SeqOptimizer.warn):
super(SequenceDB, self).__init__() super().__init__()
self.__position__ = {} self.__position__ = {}
self.failure_callback = failure_callback self.failure_callback = failure_callback
def register(self, name, obj, position, *tags): def register(self, name, obj, position, *tags):
super(SequenceDB, self).register(name, obj, *tags) super().register(name, obj, *tags)
if position == "last": if position == "last":
if len(self.__position__) == 0: if len(self.__position__) == 0:
self.__position__[name] = 0 self.__position__[name] = 0
else: else:
self.__position__[name] = max(self.__position__.values()) + 1 self.__position__[name] = max(self.__position__.values()) + 1
else: else:
assert isinstance(position, (integer_types, float)) assert isinstance(position, ((int,), float))
self.__position__[name] = position self.__position__[name] = position
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
...@@ -373,7 +373,7 @@ class SequenceDB(DB): ...@@ -373,7 +373,7 @@ class SequenceDB(DB):
Only optimizations with position less than the cutoff are returned. Only optimizations with position less than the cutoff are returned.
""" """
opts = super(SequenceDB, self).query(*tags, **kwtags) opts = super().query(*tags, **kwtags)
position_cutoff = kwtags.pop("position_cutoff", config.optdb.position_cutoff) position_cutoff = kwtags.pop("position_cutoff", config.optdb.position_cutoff)
position_dict = self.__position__ position_dict = self.__position__
...@@ -442,7 +442,7 @@ class LocalGroupDB(DB): ...@@ -442,7 +442,7 @@ class LocalGroupDB(DB):
def __init__( def __init__(
self, apply_all_opts=False, profile=False, local_opt=opt.LocalOptGroup self, apply_all_opts=False, profile=False, local_opt=opt.LocalOptGroup
): ):
super(LocalGroupDB, self).__init__() super().__init__()
self.failure_callback = None self.failure_callback = None
self.apply_all_opts = apply_all_opts self.apply_all_opts = apply_all_opts
self.profile = profile self.profile = profile
...@@ -450,7 +450,7 @@ class LocalGroupDB(DB): ...@@ -450,7 +450,7 @@ class LocalGroupDB(DB):
self.local_opt = local_opt self.local_opt = local_opt
def register(self, name, obj, *tags, **kwargs): def register(self, name, obj, *tags, **kwargs):
super(LocalGroupDB, self).register(name, obj, *tags) super().register(name, obj, *tags)
position = kwargs.pop("position", "last") position = kwargs.pop("position", "last")
if position == "last": if position == "last":
if len(self.__position__) == 0: if len(self.__position__) == 0:
...@@ -458,12 +458,12 @@ class LocalGroupDB(DB): ...@@ -458,12 +458,12 @@ class LocalGroupDB(DB):
else: else:
self.__position__[name] = max(self.__position__.values()) + 1 self.__position__[name] = max(self.__position__.values()) + 1
else: else:
assert isinstance(position, (integer_types, float)) assert isinstance(position, ((int,), float))
self.__position__[name] = position self.__position__[name] = position
def query(self, *tags, **kwtags): def query(self, *tags, **kwtags):
# For the new `useless` optimizer # For the new `useless` optimizer
opts = list(super(LocalGroupDB, self).query(*tags, **kwtags)) opts = list(super().query(*tags, **kwtags))
opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name)) opts.sort(key=lambda obj: (self.__position__[obj.name], obj.name))
ret = self.local_opt( ret = self.local_opt(
...@@ -482,7 +482,7 @@ class TopoDB(DB): ...@@ -482,7 +482,7 @@ class TopoDB(DB):
def __init__( def __init__(
self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None self, db, order="in_to_out", ignore_newtrees=False, failure_callback=None
): ):
super(TopoDB, self).__init__() super().__init__()
self.db = db self.db = db
self.order = order self.order = order
self.ignore_newtrees = ignore_newtrees self.ignore_newtrees = ignore_newtrees
......
...@@ -153,13 +153,13 @@ class Params(dict): ...@@ -153,13 +153,13 @@ class Params(dict):
raise TypeError( raise TypeError(
'Params: ParamsType attribute "%s" not in Params args.' % field 'Params: ParamsType attribute "%s" not in Params args.' % field
) )
super(Params, self).__init__(**kwargs) super().__init__(**kwargs)
self.__dict__.update(__params_type__=params_type, __signatures__=None) self.__dict__.update(__params_type__=params_type, __signatures__=None)
def __repr__(self): def __repr__(self):
return "Params(%s)" % ", ".join( return "Params(%s)" % ", ".join(
[ [
("%s:%s:%s" % (k, type(self[k]).__name__, self[k])) ("{}:{}:{}".format(k, type(self[k]).__name__, self[k]))
for k in sorted(self.keys()) for k in sorted(self.keys())
] ]
) )
...@@ -270,14 +270,14 @@ class ParamsType(Type): ...@@ -270,14 +270,14 @@ class ParamsType(Type):
if enum_types: if enum_types:
# We don't want same enum names in different enum types. # We don't want same enum names in different enum types.
if sum(len(t) for t in enum_types) != len( if sum(len(t) for t in enum_types) != len(
set(k for t in enum_types for k in t) {k for t in enum_types for k in t}
): ):
raise AttributeError( raise AttributeError(
"ParamsType: found different enum types with common constants names." "ParamsType: found different enum types with common constants names."
) )
# We don't want same aliases in different enum types. # We don't want same aliases in different enum types.
if sum(len(t.aliases) for t in enum_types) != len( if sum(len(t.aliases) for t in enum_types) != len(
set(alias for t in enum_types for alias in t.aliases) {alias for t in enum_types for alias in t.aliases}
): ):
raise AttributeError( raise AttributeError(
"ParamsType: found different enum types with common constants aliases." "ParamsType: found different enum types with common constants aliases."
...@@ -319,11 +319,14 @@ class ParamsType(Type): ...@@ -319,11 +319,14 @@ class ParamsType(Type):
# Now we can access value of each enum defined inside enum types wrapped into the current ParamsType. # Now we can access value of each enum defined inside enum types wrapped into the current ParamsType.
if key in self.__const_to_enum: if key in self.__const_to_enum:
return self.__const_to_enum[key][key] return self.__const_to_enum[key][key]
return super(ParamsType, self).__getattr__(self, key) return super().__getattr__(self, key)
def __repr__(self): def __repr__(self):
return "ParamsType<%s>" % ", ".join( return "ParamsType<%s>" % ", ".join(
[("%s:%s" % (self.fields[i], self.types[i])) for i in range(self.length)] [
("{}:{}".format(self.fields[i], self.types[i]))
for i in range(self.length)
]
) )
def __eq__(self, other): def __eq__(self, other):
...@@ -345,7 +348,7 @@ class ParamsType(Type): ...@@ -345,7 +348,7 @@ class ParamsType(Type):
types_string = ",".join(str(t) for t in self.types).encode("utf-8") types_string = ",".join(str(t) for t in self.types).encode("utf-8")
fields_hex = hashlib.sha256(fields_string).hexdigest() fields_hex = hashlib.sha256(fields_string).hexdigest()
types_hex = hashlib.sha256(types_string).hexdigest() types_hex = hashlib.sha256(types_string).hexdigest()
return "_Params_%s_%s" % (fields_hex, types_hex) return "_Params_{}_{}".format(fields_hex, types_hex)
def has_type(self, theano_type): def has_type(self, theano_type):
""" """
......
...@@ -139,8 +139,8 @@ def _toposort(edges): ...@@ -139,8 +139,8 @@ def _toposort(edges):
""" """
incoming_edges = reverse_dict(edges) incoming_edges = reverse_dict(edges)
incoming_edges = dict((k, set(val)) for k, val in incoming_edges.items()) incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
S = set((v for v in edges if v not in incoming_edges)) S = {v for v in edges if v not in incoming_edges}
L = [] L = []
while S: while S:
...@@ -189,8 +189,8 @@ def posort(nodes, *cmps): ...@@ -189,8 +189,8 @@ def posort(nodes, *cmps):
[0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15] [0, 8, 2, 4, 6, 1, 3, 5, 7, 9, 16, 18, 10, 12, 14, 17, 19, 11, 13, 15]
""" """
comes_before = dict((a, set()) for a in nodes) comes_before = {a: set() for a in nodes}
comes_after = dict((a, set()) for a in nodes) comes_after = {a: set() for a in nodes}
def add_links(a, b): # b depends on a def add_links(a, b): # b depends on a
comes_after[a].add(b) comes_after[a].add(b)
......
...@@ -21,8 +21,6 @@ class AlreadyThere(Exception): ...@@ -21,8 +21,6 @@ class AlreadyThere(Exception):
""" """
pass
class ReplacementDidntRemovedError(Exception): class ReplacementDidntRemovedError(Exception):
""" """
...@@ -32,8 +30,6 @@ class ReplacementDidntRemovedError(Exception): ...@@ -32,8 +30,6 @@ class ReplacementDidntRemovedError(Exception):
""" """
pass
class BadOptimization(Exception): class BadOptimization(Exception):
""" """
...@@ -103,7 +99,7 @@ class BadOptimization(Exception): ...@@ -103,7 +99,7 @@ class BadOptimization(Exception):
old_graph=None, old_graph=None,
new_graph=None, new_graph=None,
): ):
super(BadOptimization, self).__init__() super().__init__()
self.old_r = old_r self.old_r = old_r
self.new_r = new_r self.new_r = new_r
self.old_r_val = old_r_val self.old_r_val = old_r_val
...@@ -139,7 +135,7 @@ class BadOptimization(Exception): ...@@ -139,7 +135,7 @@ class BadOptimization(Exception):
return self.full_err return self.full_err
sio = StringIO() sio = StringIO()
val_str_len_limit = 800 val_str_len_limit = 800
print("BadOptimization Error", super(BadOptimization, self).__str__(), file=sio) print("BadOptimization Error", super().__str__(), file=sio)
print(" Variable: id", id(self.new_r), self.new_r, file=sio) print(" Variable: id", id(self.new_r), self.new_r, file=sio)
print(" Op", self.new_r.owner, file=sio) print(" Op", self.new_r.owner, file=sio)
print(" Value Type:", type(self.new_r_val), file=sio) print(" Value Type:", type(self.new_r_val), file=sio)
...@@ -225,7 +221,7 @@ class BadOptimization(Exception): ...@@ -225,7 +221,7 @@ class BadOptimization(Exception):
return sio.getvalue() return sio.getvalue()
class Feature(object): class Feature:
""" """
Base class for FunctionGraph extensions. Base class for FunctionGraph extensions.
...@@ -466,7 +462,9 @@ class Validator(Feature): ...@@ -466,7 +462,9 @@ class Validator(Feature):
r = uf.f_locals.get("r", "") r = uf.f_locals.get("r", "")
reason = uf_info.function reason = uf_info.function
print( print(
"validate failed on node %s.\n Reason: %s, %s" % (r, reason, e) "validate failed on node {}.\n Reason: {}, {}".format(
r, reason, e
)
) )
raise raise
t1 = time.time() t1 = time.time()
...@@ -578,7 +576,9 @@ class ReplaceValidate(History, Validator): ...@@ -578,7 +576,9 @@ class ReplaceValidate(History, Validator):
except Exception as e: except Exception as e:
fgraph.revert(chk) fgraph.revert(chk)
if verbose: if verbose:
print("validate failed on node %s.\n Reason: %s, %s" % (r, reason, e)) print(
"validate failed on node {}.\n Reason: {}, {}".format(r, reason, e)
)
raise raise
if config.scan.debug: if config.scan.debug:
scans2 = [ scans2 = [
...@@ -731,15 +731,15 @@ class PrintListener(Feature): ...@@ -731,15 +731,15 @@ class PrintListener(Feature):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
if self.active: if self.active:
print("-- importing: %s, reason: %s" % (node, reason)) print("-- importing: {}, reason: {}".format(node, reason))
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
if self.active: if self.active:
print("-- pruning: %s, reason: %s" % (node, reason)) print("-- pruning: {}, reason: {}".format(node, reason))
def on_change_input(self, fgraph, node, i, r, new_r, reason=None): def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active: if self.active:
print("-- changing (%s.inputs[%s]) from %s to %s" % (node, i, r, new_r)) print("-- changing ({}.inputs[{}]) from {} to {}".format(node, i, r, new_r))
class PreserveNames(Feature): class PreserveNames(Feature):
...@@ -918,9 +918,9 @@ def is_same_graph(var1, var2, givens=None): ...@@ -918,9 +918,9 @@ def is_same_graph(var1, var2, givens=None):
for to_replace, replace_by in givens.items(): for to_replace, replace_by in givens.items():
# Map a substitution variable to the computational graphs it # Map a substitution variable to the computational graphs it
# belongs to. # belongs to.
inside = dict( inside = {
(v, [in_var(v, k) for k in (1, 2)]) for v in (to_replace, replace_by) v: [in_var(v, k) for k in (1, 2)] for v in (to_replace, replace_by)
) }
if ( if (
inside[to_replace][0] inside[to_replace][0]
and not inside[to_replace][1] and not inside[to_replace][1]
......
...@@ -10,8 +10,6 @@ import ctypes ...@@ -10,8 +10,6 @@ import ctypes
import platform import platform
import re import re
from six import string_types
import theano import theano
from theano import change_flags from theano import change_flags
from theano.gof import graph, utils from theano.gof import graph, utils
...@@ -272,7 +270,7 @@ class CLinkerType(CLinkerObject): ...@@ -272,7 +270,7 @@ class CLinkerType(CLinkerObject):
return () return ()
class PureType(object): class PureType:
""" """
Interface specification for variable type instances. Interface specification for variable type instances.
...@@ -707,10 +705,10 @@ class CDataType(Type): ...@@ -707,10 +705,10 @@ class CDataType(Type):
extra_support_code="", extra_support_code="",
version=None, version=None,
): ):
assert isinstance(ctype, string_types) assert isinstance(ctype, str)
self.ctype = ctype self.ctype = ctype
if freefunc is not None: if freefunc is not None:
assert isinstance(freefunc, string_types) assert isinstance(freefunc, str)
self.freefunc = freefunc self.freefunc = freefunc
self.headers = tuple(headers) self.headers = tuple(headers)
self.header_dirs = tuple(header_dirs) self.header_dirs = tuple(header_dirs)
...@@ -848,7 +846,7 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); } ...@@ -848,7 +846,7 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
return v return v
def __str__(self): def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.ctype) return "{}{{{}}}".format(self.__class__.__name__, self.ctype)
def __setstate__(self, dct): def __setstate__(self, dct):
self.__dict__.update(dct) self.__dict__.update(dct)
...@@ -1034,7 +1032,7 @@ class EnumType(Type, dict): ...@@ -1034,7 +1032,7 @@ class EnumType(Type, dict):
raise TypeError( raise TypeError(
"%s: some aliases have same names as constants." % type(self).__name__ "%s: some aliases have same names as constants." % type(self).__name__
) )
super(EnumType, self).__init__(**kwargs) super().__init__(**kwargs)
def fromalias(self, alias): def fromalias(self, alias):
""" """
...@@ -1060,11 +1058,11 @@ class EnumType(Type, dict): ...@@ -1060,11 +1058,11 @@ class EnumType(Type, dict):
names_to_aliases = {constant_name: "" for constant_name in self} names_to_aliases = {constant_name: "" for constant_name in self}
for alias in self.aliases: for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = "(%s)" % alias names_to_aliases[self.aliases[alias]] = "(%s)" % alias
return "%s<%s>(%s)" % ( return "{}<{}>({})".format(
type(self).__name__, type(self).__name__,
self.ctype, self.ctype,
", ".join( ", ".join(
"%s%s:%s" % (k, names_to_aliases[k], self[k]) "{}{}:{}".format(k, names_to_aliases[k], self[k])
for k in sorted(self.keys()) for k in sorted(self.keys())
), ),
) )
...@@ -1298,7 +1296,7 @@ class EnumList(EnumType): ...@@ -1298,7 +1296,7 @@ class EnumList(EnumType):
kwargs.update(ctype=ctype) kwargs.update(ctype=ctype)
if cname is not None: if cname is not None:
kwargs.update(cname=cname) kwargs.update(cname=cname)
super(EnumList, self).__init__(**kwargs) super().__init__(**kwargs)
class CEnumType(EnumList): class CEnumType(EnumList):
...@@ -1336,7 +1334,7 @@ class CEnumType(EnumList): ...@@ -1336,7 +1334,7 @@ class CEnumType(EnumList):
return self.pyint_compat_code + self.c_to_string() return self.pyint_compat_code + self.c_to_string()
def c_extract(self, name, sub, check_input=True): def c_extract(self, name, sub, check_input=True):
swapped_dict = dict((v, k) for (k, v) in self.items()) swapped_dict = {v: k for (k, v) in self.items()}
# swapped_dict's keys are integers. # swapped_dict's keys are integers.
return """ return """
...@@ -1360,4 +1358,4 @@ class CEnumType(EnumList): ...@@ -1360,4 +1358,4 @@ class CEnumType(EnumList):
) )
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, super(CEnumType, self).c_code_cache_version()) return (1, super().c_code_cache_version())
...@@ -19,7 +19,7 @@ from theano.gof.utils import ANY_TYPE, FALL_THROUGH, comm_guard ...@@ -19,7 +19,7 @@ from theano.gof.utils import ANY_TYPE, FALL_THROUGH, comm_guard
################################ ################################
class Variable(object): class Variable:
""" """
Serves as a base class of variables for the purpose of unification. Serves as a base class of variables for the purpose of unification.
"Unification" here basically means matching two patterns, see the "Unification" here basically means matching two patterns, see the
...@@ -46,7 +46,9 @@ class Variable(object): ...@@ -46,7 +46,9 @@ class Variable(object):
return ( return (
self.__class__.__name__ self.__class__.__name__
+ "(" + "("
+ ", ".join("%s=%s" % (key, value) for key, value in self.__dict__.items()) + ", ".join(
"{}={}".format(key, value) for key, value in self.__dict__.items()
)
+ ")" + ")"
) )
...@@ -60,8 +62,6 @@ class FreeVariable(Variable): ...@@ -60,8 +62,6 @@ class FreeVariable(Variable):
""" """
pass
class BoundVariable(Variable): class BoundVariable(Variable):
""" """
...@@ -70,7 +70,7 @@ class BoundVariable(Variable): ...@@ -70,7 +70,7 @@ class BoundVariable(Variable):
""" """
def __init__(self, name, value): def __init__(self, name, value):
super(BoundVariable, self).__init__(name=name) super().__init__(name=name)
self.value = value self.value = value
...@@ -82,7 +82,7 @@ class OrVariable(Variable): ...@@ -82,7 +82,7 @@ class OrVariable(Variable):
""" """
def __init__(self, name, options): def __init__(self, name, options):
super(OrVariable, self).__init__(name=name) super().__init__(name=name)
self.options = options self.options = options
...@@ -94,7 +94,7 @@ class NotVariable(Variable): ...@@ -94,7 +94,7 @@ class NotVariable(Variable):
""" """
def __init__(self, name, not_options): def __init__(self, name, not_options):
super(NotVariable, self).__init__(name=name) super().__init__(name=name)
self.not_options = not_options self.not_options = not_options
......
...@@ -3,7 +3,6 @@ import sys ...@@ -3,7 +3,6 @@ import sys
import traceback import traceback
import numpy as np import numpy as np
from six import integer_types, string_types, with_metaclass
from six.moves import StringIO from six.moves import StringIO
from theano import config from theano import config
...@@ -161,8 +160,6 @@ undef = object() ...@@ -161,8 +160,6 @@ undef = object()
class TestValueError(Exception): class TestValueError(Exception):
"""Base exception class for all test value errors.""" """Base exception class for all test value errors."""
pass
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
...@@ -173,8 +170,6 @@ class MethodNotDefined(Exception): ...@@ -173,8 +170,6 @@ class MethodNotDefined(Exception):
""" """
pass
class MetaObject(type): class MetaObject(type):
def __new__(cls, name, bases, dct): def __new__(cls, name, bases, dct):
...@@ -182,7 +177,7 @@ class MetaObject(type): ...@@ -182,7 +177,7 @@ class MetaObject(type):
if props is not None: if props is not None:
if not isinstance(props, tuple): if not isinstance(props, tuple):
raise TypeError("__props__ has to be a tuple") raise TypeError("__props__ has to be a tuple")
if not all(isinstance(p, string_types) for p in props): if not all(isinstance(p, str) for p in props):
raise TypeError("elements of __props__ have to be strings") raise TypeError("elements of __props__ have to be strings")
def _props(self): def _props(self):
...@@ -201,7 +196,7 @@ class MetaObject(type): ...@@ -201,7 +196,7 @@ class MetaObject(type):
least all the original props. least all the original props.
""" """
return dict([(a, getattr(self, a)) for a in props]) return {a: getattr(self, a) for a in props}
dct["_props_dict"] = _props_dict dct["_props_dict"] = _props_dict
...@@ -225,14 +220,16 @@ class MetaObject(type): ...@@ -225,14 +220,16 @@ class MetaObject(type):
if len(props) == 0: if len(props) == 0:
def __str__(self): def __str__(self):
return "%s" % (self.__class__.__name__,) return "{}".format(self.__class__.__name__)
else: else:
def __str__(self): def __str__(self):
return "%s{%s}" % ( return "{}{{{}}}".format(
self.__class__.__name__, self.__class__.__name__,
", ".join("%s=%r" % (p, getattr(self, p)) for p in props), ", ".join(
"{}={!r}".format(p, getattr(self, p)) for p in props
),
) )
dct["__str__"] = __str__ dct["__str__"] = __str__
...@@ -240,14 +237,14 @@ class MetaObject(type): ...@@ -240,14 +237,14 @@ class MetaObject(type):
return type.__new__(cls, name, bases, dct) return type.__new__(cls, name, bases, dct)
class object2(with_metaclass(MetaObject, object)): class object2(metaclass=MetaObject):
__slots__ = [] __slots__ = []
def __ne__(self, other): def __ne__(self, other):
return not self == other return not self == other
class Scratchpad(object): class Scratchpad:
def clear(self): def clear(self):
self.__dict__.clear() self.__dict__.clear()
...@@ -264,7 +261,7 @@ class Scratchpad(object): ...@@ -264,7 +261,7 @@ class Scratchpad(object):
def info(self): def info(self):
print("<theano.gof.utils.scratchpad instance at %i>" % id(self)) print("<theano.gof.utils.scratchpad instance at %i>" % id(self))
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
print(" %s: %s" % (k, v)) print(" {}: {}".format(k, v))
class ValidatingScratchpad(Scratchpad): class ValidatingScratchpad(Scratchpad):
...@@ -330,7 +327,7 @@ def deprecated(filename, msg=""): ...@@ -330,7 +327,7 @@ def deprecated(filename, msg=""):
def g(*args, **kwargs): def g(*args, **kwargs):
if printme[0]: if printme[0]:
print("WARNING: %s.%s deprecated. %s" % (filename, f.__name__, msg)) print("WARNING: {}.{} deprecated. {}".format(filename, f.__name__, msg))
printme[0] = False printme[0] = False
return f(*args, **kwargs) return f(*args, **kwargs)
...@@ -404,7 +401,7 @@ def toposort(prereqs_d): ...@@ -404,7 +401,7 @@ def toposort(prereqs_d):
for x, prereqs in prereqs_d.items(): for x, prereqs in prereqs_d.items():
for prereq in prereqs: for prereq in prereqs:
postreqs_d.setdefault(prereq, set()).add(x) postreqs_d.setdefault(prereq, set()).add(x)
next = set([k for k in prereqs_d if not prereqs_d[k]]) next = {k for k in prereqs_d if not prereqs_d[k]}
while next: while next:
bases = next bases = next
next = set() next = set()
...@@ -449,7 +446,7 @@ RETRY = Keyword("RETRY", False) ...@@ -449,7 +446,7 @@ RETRY = Keyword("RETRY", False)
FAILURE = Keyword("FAILURE", False) FAILURE = Keyword("FAILURE", False)
simple_types = integer_types + string_types + (float, bool, None.__class__, Keyword) simple_types = (int, str, float, bool, type(None), Keyword)
ANY_TYPE = Keyword("ANY_TYPE") ANY_TYPE = Keyword("ANY_TYPE")
......
...@@ -30,8 +30,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -30,8 +30,8 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
for var in fgraph.variables: for var in fgraph.variables:
viewed_by[var] = [] viewed_by[var] = []
view_of = {} view_of = {}
pre_allocated = set([]) pre_allocated = set()
allocated = set([]) allocated = set()
for idx in range(len(order)): for idx in range(len(order)):
node = order[idx] node = order[idx]
...@@ -120,7 +120,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend ...@@ -120,7 +120,7 @@ def calculate_reallocate_info(order, fgraph, storage_map, compute_map_re, depend
return reallocated_info return reallocated_info
class VM(object): class VM:
""" """
A VM object's __call__ method evaluates a Theano program. A VM object's __call__ method evaluates a Theano program.
...@@ -273,7 +273,7 @@ class LoopGC(VM): ...@@ -273,7 +273,7 @@ class LoopGC(VM):
""" """
def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear): def __init__(self, nodes, thunks, pre_call_clear, post_thunk_clear):
super(LoopGC, self).__init__(nodes, thunks, pre_call_clear) super().__init__(nodes, thunks, pre_call_clear)
self.post_thunk_clear = post_thunk_clear self.post_thunk_clear = post_thunk_clear
# Some other part of Theano query that information # Some other part of Theano query that information
self.allow_gc = True self.allow_gc = True
...@@ -353,7 +353,7 @@ class Stack(VM): ...@@ -353,7 +353,7 @@ class Stack(VM):
callback=None, callback=None,
callback_input=None, callback_input=None,
): ):
super(Stack, self).__init__(nodes, thunks, pre_call_clear) super().__init__(nodes, thunks, pre_call_clear)
self.allow_gc = allow_gc self.allow_gc = allow_gc
self.message = "" self.message = ""
...@@ -712,7 +712,6 @@ except (OSError, theano.gof.cmodule.MissingGXX) as e: ...@@ -712,7 +712,6 @@ except (OSError, theano.gof.cmodule.MissingGXX) as e:
assert not [x for x in _config_var_list if x.fullname == "linker"][ assert not [x for x in _config_var_list if x.fullname == "linker"][
0 0
].default.startswith("cvm"), e ].default.startswith("cvm"), e
pass
class VM_Linker(link.LocalLinker): class VM_Linker(link.LocalLinker):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论