提交 86dbc392 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to top-level modules in theano package

上级 d6477d58
......@@ -84,7 +84,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
stderr=(subprocess.PIPE if hide_stderr else None),
)
break
except EnvironmentError:
except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
......@@ -94,7 +94,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
return None, None
else:
if verbose:
print("unable to find command, tried %s" % (commands,))
print("unable to find command, tried {}".format(commands))
return None, None
stdout = p.communicate()[0].strip().decode()
if p.returncode != 0:
......@@ -145,7 +145,7 @@ def git_get_keywords(versionfile_abs):
# _version.py.
keywords = {}
try:
f = open(versionfile_abs, "r")
f = open(versionfile_abs)
for line in f.readlines():
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
......@@ -160,7 +160,7 @@ def git_get_keywords(versionfile_abs):
if mo:
keywords["date"] = mo.group(1)
f.close()
except EnvironmentError:
except OSError:
pass
return keywords
......@@ -184,11 +184,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")])
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)])
tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
......@@ -197,7 +197,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r"\d", r)])
tags = {r for r in refs if re.search(r"\d", r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
......@@ -300,7 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
pieces["error"] = "tag '{}' doesn't start with prefix '{}'".format(
full_tag,
tag_prefix,
)
......
......@@ -5,7 +5,7 @@
from collections import OrderedDict
# Python 3.x compatibility
from six import PY3, BytesIO, b, next
from six import PY3, BytesIO, b
from six.moves import configparser
from six.moves import reload_module as reload
......@@ -24,59 +24,44 @@ except ImportError:
__all__ = ["PY3", "b", "BytesIO", "next", "configparser", "reload"]
if PY3:
from operator import truediv as operator_div
from operator import truediv as operator_div
# In python 3.x, when an exception is reraised it saves original
# exception in its args, therefore in order to find the actual
# message, we need to unpack arguments recursively.
def exc_message(e):
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
def cmp(x, y):
"""Return -1 if x < y, 0 if x == y, 1 if x > y."""
return (x > y) - (x < y)
# In python 3.x, when an exception is reraised it saves original
# exception in its args, therefore in order to find the actual
# message, we need to unpack arguments recursively.
def exc_message(e):
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
def get_unbound_function(unbound):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if hasattr(unbound, "__func__"):
return unbound.__func__
return unbound
def decode(x):
return x.decode()
def cmp(x, y):
"""Return -1 if x < y, 0 if x == y, 1 if x > y."""
return (x > y) - (x < y)
def decode_iter(itr):
for x in itr:
yield x.decode()
def decode_with(x, encoding):
return x.decode(encoding)
def get_unbound_function(unbound):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if hasattr(unbound, "__func__"):
return unbound.__func__
return unbound
else:
from operator import div as operator_div
def decode(x):
return x.decode()
from six import get_unbound_function
def exc_message(e):
return e[0]
def decode_iter(itr):
for x in itr:
yield x.decode()
cmp = cmp
def decode(x):
return x
def decode_iter(x):
return x
def decode_with(x, encoding):
return x
def decode_with(x, encoding):
return x.decode(encoding)
__all__ += [
......
......@@ -10,7 +10,6 @@ import textwrap
import warnings
import numpy as np
from six import string_types
import theano
from theano.compat import maybe_add_to_os_environ_pathlist
......@@ -142,18 +141,16 @@ class DeviceParam(ConfigParam):
)
else:
raise ValueError(
(
'Invalid value ("%s") for configuration '
'variable "%s". Valid options start with '
'one of "cpu", "opencl" or "cuda".' % (val, self.fullname)
)
'Invalid value ("%s") for configuration '
'variable "%s". Valid options start with '
'one of "cpu", "opencl" or "cuda".' % (val, self.fullname)
)
over = kwargs.get("allow_override", True)
super(DeviceParam, self).__init__(default, filter, over)
super().__init__(default, filter, over)
def __str__(self):
return "%s (%s, opencl*, cuda*) " % (self.fullname, self.default)
return "{} ({}, opencl*, cuda*) ".format(self.fullname, self.default)
AddConfigVar(
......@@ -211,13 +208,13 @@ class ContextsParam(ConfigParam):
for v in val.split(";"):
s = v.split("->")
if len(s) != 2:
raise ValueError("Malformed context map: %s" % (v,))
raise ValueError("Malformed context map: {}".format(v))
if (
s[0] == "cpu"
or s[0].startswith("cuda")
or s[0].startswith("opencl")
):
raise ValueError("Cannot use %s as context name" % (s[0],))
raise ValueError("Cannot use {} as context name".format(s[0]))
return val
ConfigParam.__init__(self, "", filter, False)
......@@ -1409,7 +1406,7 @@ AddConfigVar(
def is_valid_check_preallocated_output_param(param):
if not isinstance(param, string_types):
if not isinstance(param, str):
return False
valid = [
"initial",
......@@ -1821,7 +1818,7 @@ def default_blas_ldflags():
# we just pass the whole ldflags as the -l
# options part.
[
"-L%s%s%s" % (path_wrapper, l, path_wrapper)
"-L{}{}{}".format(path_wrapper, l, path_wrapper)
for l in blas_info.get("library_dirs", [])
]
+ ["-l%s" % l for l in blas_info.get("libraries", [])]
......@@ -1902,7 +1899,7 @@ def try_blas_flag(flags):
path_wrapper = '"' if os.name == "nt" else ""
cflags.extend(
[
"-L%s%s%s" % (path_wrapper, d, path_wrapper)
"-L{}{}{}".format(path_wrapper, d, path_wrapper)
for d in theano.gof.cmodule.std_lib_dirs()
]
)
......@@ -2311,11 +2308,11 @@ def filter_compiledir(path):
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
except IOError as e:
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
else:
e.args += ("%s exist? %s" % (path, os.path.exists(path)),)
e.args += ("{} exist? {}".format(path, os.path.exists(path)),)
raise
return path
......@@ -2390,4 +2387,4 @@ AddConfigVar(
# Check if there are remaining flags provided by the user through THEANO_FLAGS.
for key in THEANO_FLAGS_DICT.keys():
warnings.warn("Theano does not recognise this flag: {0}".format(key))
warnings.warn("Theano does not recognise this flag: {}".format(key))
......@@ -10,7 +10,7 @@ import sys
import warnings
from functools import wraps
from six import PY3, StringIO, string_types
from six import PY3, StringIO
import theano
from theano.compat import configparser as ConfigParser
......@@ -95,7 +95,7 @@ theano_raw_cfg = ConfigParser.RawConfigParser()
theano_raw_cfg.read(config_files)
class change_flags(object):
class change_flags:
"""
Use this as a decorator or context manager to change the value of
Theano config variables.
......@@ -204,12 +204,12 @@ def get_config_hash():
)
return theano.gof.utils.hash_from_code(
"\n".join(
["%s = %s" % (cv.fullname, cv.__get__(True, None)) for cv in all_opts]
["{} = {}".format(cv.fullname, cv.__get__(True, None)) for cv in all_opts]
)
)
class TheanoConfigParser(object):
class TheanoConfigParser:
# properties are installed by AddConfigVar
_i_am_a_config_class = True
......@@ -276,7 +276,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
if not hasattr(root, sections[0]):
# every internal node in the config tree is an instance of its own
# unique class
class SubObj(object):
class SubObj:
_i_am_a_config_class = True
setattr(root.__class__, sections[0], SubObj())
......@@ -312,7 +312,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
_config_var_list.append(configparam)
class ConfigParam(object):
class ConfigParam:
def __init__(self, default, filter=None, allow_override=True):
"""
If allow_override is False, we can't change the value after the import
......@@ -368,7 +368,7 @@ class EnumStr(ConfigParam):
# All options should be strings
for val in self.all:
if not isinstance(val, string_types):
if not isinstance(val, str):
raise ValueError(
"Valid values for an EnumStr parameter " "should be strings",
val,
......@@ -384,17 +384,15 @@ class EnumStr(ConfigParam):
return val
else:
raise ValueError(
(
'Invalid value ("%s") for configuration variable "%s". '
"Valid options are %s" % (val, self.fullname, self.all)
)
'Invalid value ("%s") for configuration variable "%s". '
"Valid options are %s" % (val, self.fullname, self.all)
)
over = kwargs.get("allow_override", True)
super(EnumStr, self).__init__(default, filter, over)
super().__init__(default, filter, over)
def __str__(self):
return "%s (%s) " % (self.fullname, self.all)
return "{} ({}) ".format(self.fullname, self.all)
class TypedParam(ConfigParam):
......@@ -414,10 +412,10 @@ class TypedParam(ConfigParam):
)
return cast_val
super(TypedParam, self).__init__(default, filter, allow_override=allow_override)
super().__init__(default, filter, allow_override=allow_override)
def __str__(self):
return "%s (%s) " % (self.fullname, self.mytype)
return "{} ({}) ".format(self.fullname, self.mytype)
def StrParam(default, is_valid=None, allow_override=True):
......
......@@ -130,20 +130,16 @@ class DisconnectedType(theano.gof.type.Type):
def filter(self, data, strict=False, allow_downcast=None):
raise AssertionError(
(
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
def fiter_variable(self, other):
raise AssertionError(
(
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
def may_share_memory(a, b):
......@@ -151,11 +147,9 @@ class DisconnectedType(theano.gof.type.Type):
def value_eq(a, b, force_same_dtype=True):
raise AssertionError(
(
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
)
def __str__(self):
......@@ -846,7 +840,7 @@ def _node_to_pattern(node):
raise TypeError(
"%s.connection_pattern should return" % node.op
+ " a list of lists, but element %d" % ii
+ "is %s of type %s." % (output_pattern, type(output_pattern))
+ "is {} of type {}.".format(output_pattern, type(output_pattern))
)
else:
connection_pattern = [[True for output in node.outputs] for ipt in node.inputs]
......@@ -933,7 +927,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# Note: we need to revisit the apply nodes repeatedly, because
# different outputs of the apply node are connected to
# different subsets of the inputs.
accounted_for = set([])
accounted_for = set()
def account_for(var):
# Don't visit the same variable twice
......@@ -984,7 +978,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# determine which variables have elements of wrt as a true
# ancestor. Do this with an upward pass starting from wrt,
# following only true connections
visited = set([])
visited = set()
def visit(var):
if var in visited:
......@@ -1458,7 +1452,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
grad_dict[var] = disconnected_type()
if cost_name is not None and var.name is not None:
grad_dict[var].name = "(d%s/d%s)" % (cost_name, var.name)
grad_dict[var].name = "(d{}/d{})".format(cost_name, var.name)
else:
# this variable isn't connected to the cost in the
# computational graph
......@@ -1494,7 +1488,7 @@ def _float_ones_like(x):
return x.ones_like(dtype=dtype)
class numeric_grad(object):
class numeric_grad:
"""
Compute the numeric derivative of a scalar-valued function at a particular
point.
......@@ -1818,13 +1812,11 @@ def verify_grad(
if rng is None:
raise TypeError(
(
"rng should be a valid instance of "
"numpy.random.RandomState. You may "
"want to use tests.unittest"
"_tools.verify_grad instead of "
"theano.gradient.verify_grad."
)
"rng should be a valid instance of "
"numpy.random.RandomState. You may "
"want to use tests.unittest"
"_tools.verify_grad instead of "
"theano.gradient.verify_grad."
)
# We allow input downcast in function, because numeric_grad works in the
......@@ -1853,7 +1845,7 @@ def verify_grad(
if isinstance(o_output, list):
raise NotImplementedError(
("cant (yet) autotest gradient of fun " "with multiple outputs")
"cant (yet) autotest gradient of fun " "with multiple outputs"
)
# we could make loop over outputs making random projections R for each,
# but this doesn't handle the case where not all the outputs are
......
......@@ -190,11 +190,9 @@ class IfElse(Op):
)
if c.ndim > 0:
raise TypeError(
(
"Condition given to the op has to be a scalar "
"with 0 standing for False, anything else "
"for True"
)
"Condition given to the op has to be a scalar "
"with 0 standing for False, anything else "
"for True"
)
return Apply(self, [c] + list(args), [t.type() for t in ts])
......@@ -401,13 +399,11 @@ def ifelse(condition, then_branch, else_branch, name=None):
if len(then_branch) != len(else_branch):
raise ValueError(
(
"The number of values on the `then` branch"
" should have the same number of variables as "
"the `else` branch : (variables on `then` "
"%d" % len(then_branch) + ", variables on `else` "
"%d" % len(else_branch) + ")"
)
"The number of values on the `then` branch"
" should have the same number of variables as "
"the `else` branch : (variables on `then` "
"%d" % len(then_branch) + ", variables on `else` "
"%d" % len(else_branch) + ")"
)
new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, gpu=False, name=name)
......
......@@ -2,7 +2,7 @@ import os
import sys
class PathParser(object):
class PathParser:
"""
Class that allows to modify system's PATH environment variable
at runtime. Currently used in ``theano.gpuarray.dnn`` module
......
......@@ -12,7 +12,6 @@ from copy import copy
from functools import reduce
import numpy as np
from six import integer_types, string_types
from six.moves import StringIO
import theano
......@@ -55,7 +54,7 @@ except ImportError:
_logger = logging.getLogger("theano.printing")
VALID_ASSOC = set(["left", "right", "either"])
VALID_ASSOC = {"left", "right", "either"}
def debugprint(
......@@ -121,7 +120,7 @@ def debugprint(
to the Apply's identifier, to indicate which output a line corresponds to.
"""
if not isinstance(depth, integer_types):
if not isinstance(depth, int):
raise Exception("depth parameter must be an int")
if file == "str":
_file = StringIO()
......@@ -168,7 +167,7 @@ def debugprint(
smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs])
topo = obj.toposort()
order.extend([topo for item in obj.outputs])
elif isinstance(obj, (integer_types, float, np.ndarray)):
elif isinstance(obj, (int, float, np.ndarray)):
print(obj, file=_file)
elif isinstance(obj, (theano.In, theano.Out)):
results_to_print.append(obj.variable)
......@@ -239,14 +238,10 @@ N.B.:
else:
inner_inputs = s.owner.op.inputs
outer_inputs = s.owner.inputs
inner_to_outer_inputs = dict(
[
(inner_inputs[i], outer_inputs[o])
for i, o in s.owner.op.var_mappings[
"outer_inp_from_inner_inp"
].items()
]
)
inner_to_outer_inputs = {
inner_inputs[i]: outer_inputs[o]
for i, o in s.owner.op.var_mappings["outer_inp_from_inner_inp"].items()
}
print("", file=_file)
debugmode.debugprint(
......@@ -440,7 +435,7 @@ class PatternPrinter:
def __init__(self, *patterns):
self.patterns = []
for pattern in patterns:
if isinstance(pattern, string_types):
if isinstance(pattern, str):
self.patterns.append((pattern, ()))
else:
self.patterns.append((pattern[0], pattern[1:]))
......@@ -469,13 +464,13 @@ class PatternPrinter:
return r
d = dict(
(str(i), x)
d = {
str(i): x
for i, x in enumerate(
pp_process(input, precedence)
for input, precedence in zip(node.inputs, precedences)
)
)
}
r = pattern % d
pstate.memo[output] = r
return r
......@@ -501,7 +496,7 @@ class FunctionPrinter:
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
r = "%s(%s)" % (
r = "{}({})".format(
name,
", ".join([pprinter.process(input, pstate) for input in node.inputs]),
)
......@@ -556,7 +551,7 @@ class DefaultPrinter:
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
r = "%s(%s)" % (
r = "{}({})".format(
str(node.op),
", ".join([pprinter.process(input, pstate) for input in node.inputs]),
)
......@@ -624,7 +619,7 @@ class PPrinter:
pprinter = self.clone_assign(
lambda pstate, r: r.name is not None and r is not current, leaf_printer
)
inv_updates = dict((b, a) for (a, b) in updates.items())
inv_updates = {b: a for (a, b) in updates.items()}
i = 1
for node in gof.graph.io_toposort(
list(inputs) + updates.keys(), list(outputs) + updates.values()
......@@ -633,7 +628,7 @@ class PPrinter:
if output in inv_updates:
name = str(inv_updates[output])
strings.append(
(i + 1000, "%s <- %s" % (name, pprinter.process(output)))
(i + 1000, "{} <- {}".format(name, pprinter.process(output)))
)
i += 1
if output.name is not None or output in outputs:
......@@ -653,7 +648,7 @@ class PPrinter:
strings.append((idx, "return %s" % pprinter.process(output)))
else:
strings.append(
(idx, "%s = %s" % (name, pprinter.process(output)))
(idx, "{} = {}".format(name, pprinter.process(output)))
)
i += 1
strings.sort()
......@@ -901,7 +896,7 @@ def pydotprint(
dstr = "val=" + str(np.asarray(var.data))
if "\n" in dstr:
dstr = dstr[: dstr.index("\n")]
varstr = "%s %s" % (dstr, str(var.type))
varstr = "{} {}".format(dstr, str(var.type))
elif var in input_update and input_update[var].name is not None:
varstr = input_update[var].name
if not var_with_name_simple:
......@@ -933,7 +928,7 @@ def pydotprint(
pf = 0
else:
pf = time * 100 / profile.fct_call_time
prof_str = " (%.3fs,%.3f%%)" % (time, pf)
prof_str = " ({:.3f}s,{:.3f}%)".format(time, pf)
applystr = str(node.op).replace(":", "_")
applystr += prof_str
if (applystr in all_strings) or with_ids:
......
......@@ -25,7 +25,7 @@ class Raise(gof.Op):
self.exc = exc
def __str__(self):
return "Raise{%s(%s)}" % (self.exc, self.msg)
return "Raise{{{}({})}}".format(self.exc, self.msg)
def make_node(self, x):
return gof.Apply(self, [x], [x.type()])
......
......@@ -42,7 +42,7 @@ class OrderedUpdates(OrderedDict):
"an OrderedDict that is available at "
"theano.compat.OrderedDict for python 2.6+."
)
super(OrderedUpdates, self).__init__(*key, **kwargs)
super().__init__(*key, **kwargs)
for key in self:
if not isinstance(key, SharedVariable):
raise TypeError(
......@@ -59,7 +59,7 @@ class OrderedUpdates(OrderedDict):
# value. Should it be cast to a GPU value right away? Should
# literals be transformed into constants immediately?
return super(OrderedUpdates, self).__setitem__(key, value)
return super().__setitem__(key, value)
else:
raise TypeError(
"OrderedUpdates keys must inherit from " "SharedVariable", key
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论