提交 86dbc392 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to top-level modules in theano package

上级 d6477d58
...@@ -84,7 +84,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= ...@@ -84,7 +84,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
stderr=(subprocess.PIPE if hide_stderr else None), stderr=(subprocess.PIPE if hide_stderr else None),
) )
break break
except EnvironmentError: except OSError:
e = sys.exc_info()[1] e = sys.exc_info()[1]
if e.errno == errno.ENOENT: if e.errno == errno.ENOENT:
continue continue
...@@ -94,7 +94,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= ...@@ -94,7 +94,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
return None, None return None, None
else: else:
if verbose: if verbose:
print("unable to find command, tried %s" % (commands,)) print("unable to find command, tried {}".format(commands))
return None, None return None, None
stdout = p.communicate()[0].strip().decode() stdout = p.communicate()[0].strip().decode()
if p.returncode != 0: if p.returncode != 0:
...@@ -145,7 +145,7 @@ def git_get_keywords(versionfile_abs): ...@@ -145,7 +145,7 @@ def git_get_keywords(versionfile_abs):
# _version.py. # _version.py.
keywords = {} keywords = {}
try: try:
f = open(versionfile_abs, "r") f = open(versionfile_abs)
for line in f.readlines(): for line in f.readlines():
if line.strip().startswith("git_refnames ="): if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line) mo = re.search(r'=\s*"(.*)"', line)
...@@ -160,7 +160,7 @@ def git_get_keywords(versionfile_abs): ...@@ -160,7 +160,7 @@ def git_get_keywords(versionfile_abs):
if mo: if mo:
keywords["date"] = mo.group(1) keywords["date"] = mo.group(1)
f.close() f.close()
except EnvironmentError: except OSError:
pass pass
return keywords return keywords
...@@ -184,11 +184,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): ...@@ -184,11 +184,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose: if verbose:
print("keywords are unexpanded, not using") print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball") raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")]) refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those. # just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: " TAG = "tag: "
tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
if not tags: if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use # Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d # a heuristic: assume all version tags have a digit. The old git %d
...@@ -197,7 +197,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): ...@@ -197,7 +197,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we # between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and # filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master". # "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r"\d", r)]) tags = {r for r in refs if re.search(r"\d", r)}
if verbose: if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags)) print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose: if verbose:
...@@ -300,7 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): ...@@ -300,7 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if verbose: if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'" fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix)) print(fmt % (full_tag, tag_prefix))
pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( pieces["error"] = "tag '{}' doesn't start with prefix '{}'".format(
full_tag, full_tag,
tag_prefix, tag_prefix,
) )
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
from collections import OrderedDict from collections import OrderedDict
# Python 3.x compatibility # Python 3.x compatibility
from six import PY3, BytesIO, b, next from six import PY3, BytesIO, b
from six.moves import configparser from six.moves import configparser
from six.moves import reload_module as reload from six.moves import reload_module as reload
...@@ -24,59 +24,44 @@ except ImportError: ...@@ -24,59 +24,44 @@ except ImportError:
__all__ = ["PY3", "b", "BytesIO", "next", "configparser", "reload"] __all__ = ["PY3", "b", "BytesIO", "next", "configparser", "reload"]
if PY3: from operator import truediv as operator_div
from operator import truediv as operator_div
# In python 3.x, when an exception is reraised it saves original
# exception in its args, therefore in order to find the actual
# message, we need to unpack arguments recursively.
def exc_message(e):
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
def cmp(x, y): # In python 3.x, when an exception is reraised it saves original
"""Return -1 if x < y, 0 if x == y, 1 if x > y.""" # exception in its args, therefore in order to find the actual
return (x > y) - (x < y) # message, we need to unpack arguments recursively.
def exc_message(e):
msg = e.args[0]
if isinstance(msg, Exception):
return exc_message(msg)
return msg
def get_unbound_function(unbound):
# Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if hasattr(unbound, "__func__"):
return unbound.__func__
return unbound
def decode(x): def cmp(x, y):
return x.decode() """Return -1 if x < y, 0 if x == y, 1 if x > y."""
return (x > y) - (x < y)
def decode_iter(itr):
for x in itr:
yield x.decode()
def decode_with(x, encoding): def get_unbound_function(unbound):
return x.decode(encoding) # Op.make_thunk isn't bound, so don't have a __func__ attr.
# But bound method, have a __func__ method that point to the
# not bound method. That is what we want.
if hasattr(unbound, "__func__"):
return unbound.__func__
return unbound
else: def decode(x):
from operator import div as operator_div return x.decode()
from six import get_unbound_function
def exc_message(e): def decode_iter(itr):
return e[0] for x in itr:
yield x.decode()
cmp = cmp
def decode(x): def decode_with(x, encoding):
return x return x.decode(encoding)
def decode_iter(x):
return x
def decode_with(x, encoding):
return x
__all__ += [ __all__ += [
......
...@@ -10,7 +10,6 @@ import textwrap ...@@ -10,7 +10,6 @@ import textwrap
import warnings import warnings
import numpy as np import numpy as np
from six import string_types
import theano import theano
from theano.compat import maybe_add_to_os_environ_pathlist from theano.compat import maybe_add_to_os_environ_pathlist
...@@ -142,18 +141,16 @@ class DeviceParam(ConfigParam): ...@@ -142,18 +141,16 @@ class DeviceParam(ConfigParam):
) )
else: else:
raise ValueError( raise ValueError(
( 'Invalid value ("%s") for configuration '
'Invalid value ("%s") for configuration ' 'variable "%s". Valid options start with '
'variable "%s". Valid options start with ' 'one of "cpu", "opencl" or "cuda".' % (val, self.fullname)
'one of "cpu", "opencl" or "cuda".' % (val, self.fullname)
)
) )
over = kwargs.get("allow_override", True) over = kwargs.get("allow_override", True)
super(DeviceParam, self).__init__(default, filter, over) super().__init__(default, filter, over)
def __str__(self): def __str__(self):
return "%s (%s, opencl*, cuda*) " % (self.fullname, self.default) return "{} ({}, opencl*, cuda*) ".format(self.fullname, self.default)
AddConfigVar( AddConfigVar(
...@@ -211,13 +208,13 @@ class ContextsParam(ConfigParam): ...@@ -211,13 +208,13 @@ class ContextsParam(ConfigParam):
for v in val.split(";"): for v in val.split(";"):
s = v.split("->") s = v.split("->")
if len(s) != 2: if len(s) != 2:
raise ValueError("Malformed context map: %s" % (v,)) raise ValueError("Malformed context map: {}".format(v))
if ( if (
s[0] == "cpu" s[0] == "cpu"
or s[0].startswith("cuda") or s[0].startswith("cuda")
or s[0].startswith("opencl") or s[0].startswith("opencl")
): ):
raise ValueError("Cannot use %s as context name" % (s[0],)) raise ValueError("Cannot use {} as context name".format(s[0]))
return val return val
ConfigParam.__init__(self, "", filter, False) ConfigParam.__init__(self, "", filter, False)
...@@ -1409,7 +1406,7 @@ AddConfigVar( ...@@ -1409,7 +1406,7 @@ AddConfigVar(
def is_valid_check_preallocated_output_param(param): def is_valid_check_preallocated_output_param(param):
if not isinstance(param, string_types): if not isinstance(param, str):
return False return False
valid = [ valid = [
"initial", "initial",
...@@ -1821,7 +1818,7 @@ def default_blas_ldflags(): ...@@ -1821,7 +1818,7 @@ def default_blas_ldflags():
# we just pass the whole ldflags as the -l # we just pass the whole ldflags as the -l
# options part. # options part.
[ [
"-L%s%s%s" % (path_wrapper, l, path_wrapper) "-L{}{}{}".format(path_wrapper, l, path_wrapper)
for l in blas_info.get("library_dirs", []) for l in blas_info.get("library_dirs", [])
] ]
+ ["-l%s" % l for l in blas_info.get("libraries", [])] + ["-l%s" % l for l in blas_info.get("libraries", [])]
...@@ -1902,7 +1899,7 @@ def try_blas_flag(flags): ...@@ -1902,7 +1899,7 @@ def try_blas_flag(flags):
path_wrapper = '"' if os.name == "nt" else "" path_wrapper = '"' if os.name == "nt" else ""
cflags.extend( cflags.extend(
[ [
"-L%s%s%s" % (path_wrapper, d, path_wrapper) "-L{}{}{}".format(path_wrapper, d, path_wrapper)
for d in theano.gof.cmodule.std_lib_dirs() for d in theano.gof.cmodule.std_lib_dirs()
] ]
) )
...@@ -2311,11 +2308,11 @@ def filter_compiledir(path): ...@@ -2311,11 +2308,11 @@ def filter_compiledir(path):
if not os.path.exists(init_file): if not os.path.exists(init_file):
try: try:
open(init_file, "w").close() open(init_file, "w").close()
except IOError as e: except OSError as e:
if os.path.exists(init_file): if os.path.exists(init_file):
pass # has already been created pass # has already been created
else: else:
e.args += ("%s exist? %s" % (path, os.path.exists(path)),) e.args += ("{} exist? {}".format(path, os.path.exists(path)),)
raise raise
return path return path
...@@ -2390,4 +2387,4 @@ AddConfigVar( ...@@ -2390,4 +2387,4 @@ AddConfigVar(
# Check if there are remaining flags provided by the user through THEANO_FLAGS. # Check if there are remaining flags provided by the user through THEANO_FLAGS.
for key in THEANO_FLAGS_DICT.keys(): for key in THEANO_FLAGS_DICT.keys():
warnings.warn("Theano does not recognise this flag: {0}".format(key)) warnings.warn("Theano does not recognise this flag: {}".format(key))
...@@ -10,7 +10,7 @@ import sys ...@@ -10,7 +10,7 @@ import sys
import warnings import warnings
from functools import wraps from functools import wraps
from six import PY3, StringIO, string_types from six import PY3, StringIO
import theano import theano
from theano.compat import configparser as ConfigParser from theano.compat import configparser as ConfigParser
...@@ -95,7 +95,7 @@ theano_raw_cfg = ConfigParser.RawConfigParser() ...@@ -95,7 +95,7 @@ theano_raw_cfg = ConfigParser.RawConfigParser()
theano_raw_cfg.read(config_files) theano_raw_cfg.read(config_files)
class change_flags(object): class change_flags:
""" """
Use this as a decorator or context manager to change the value of Use this as a decorator or context manager to change the value of
Theano config variables. Theano config variables.
...@@ -204,12 +204,12 @@ def get_config_hash(): ...@@ -204,12 +204,12 @@ def get_config_hash():
) )
return theano.gof.utils.hash_from_code( return theano.gof.utils.hash_from_code(
"\n".join( "\n".join(
["%s = %s" % (cv.fullname, cv.__get__(True, None)) for cv in all_opts] ["{} = {}".format(cv.fullname, cv.__get__(True, None)) for cv in all_opts]
) )
) )
class TheanoConfigParser(object): class TheanoConfigParser:
# properties are installed by AddConfigVar # properties are installed by AddConfigVar
_i_am_a_config_class = True _i_am_a_config_class = True
...@@ -276,7 +276,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -276,7 +276,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
if not hasattr(root, sections[0]): if not hasattr(root, sections[0]):
# every internal node in the config tree is an instance of its own # every internal node in the config tree is an instance of its own
# unique class # unique class
class SubObj(object): class SubObj:
_i_am_a_config_class = True _i_am_a_config_class = True
setattr(root.__class__, sections[0], SubObj()) setattr(root.__class__, sections[0], SubObj())
...@@ -312,7 +312,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True): ...@@ -312,7 +312,7 @@ def AddConfigVar(name, doc, configparam, root=config, in_c_key=True):
_config_var_list.append(configparam) _config_var_list.append(configparam)
class ConfigParam(object): class ConfigParam:
def __init__(self, default, filter=None, allow_override=True): def __init__(self, default, filter=None, allow_override=True):
""" """
If allow_override is False, we can't change the value after the import If allow_override is False, we can't change the value after the import
...@@ -368,7 +368,7 @@ class EnumStr(ConfigParam): ...@@ -368,7 +368,7 @@ class EnumStr(ConfigParam):
# All options should be strings # All options should be strings
for val in self.all: for val in self.all:
if not isinstance(val, string_types): if not isinstance(val, str):
raise ValueError( raise ValueError(
"Valid values for an EnumStr parameter " "should be strings", "Valid values for an EnumStr parameter " "should be strings",
val, val,
...@@ -384,17 +384,15 @@ class EnumStr(ConfigParam): ...@@ -384,17 +384,15 @@ class EnumStr(ConfigParam):
return val return val
else: else:
raise ValueError( raise ValueError(
( 'Invalid value ("%s") for configuration variable "%s". '
'Invalid value ("%s") for configuration variable "%s". ' "Valid options are %s" % (val, self.fullname, self.all)
"Valid options are %s" % (val, self.fullname, self.all)
)
) )
over = kwargs.get("allow_override", True) over = kwargs.get("allow_override", True)
super(EnumStr, self).__init__(default, filter, over) super().__init__(default, filter, over)
def __str__(self): def __str__(self):
return "%s (%s) " % (self.fullname, self.all) return "{} ({}) ".format(self.fullname, self.all)
class TypedParam(ConfigParam): class TypedParam(ConfigParam):
...@@ -414,10 +412,10 @@ class TypedParam(ConfigParam): ...@@ -414,10 +412,10 @@ class TypedParam(ConfigParam):
) )
return cast_val return cast_val
super(TypedParam, self).__init__(default, filter, allow_override=allow_override) super().__init__(default, filter, allow_override=allow_override)
def __str__(self): def __str__(self):
return "%s (%s) " % (self.fullname, self.mytype) return "{} ({}) ".format(self.fullname, self.mytype)
def StrParam(default, is_valid=None, allow_override=True): def StrParam(default, is_valid=None, allow_override=True):
......
...@@ -130,20 +130,16 @@ class DisconnectedType(theano.gof.type.Type): ...@@ -130,20 +130,16 @@ class DisconnectedType(theano.gof.type.Type):
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
raise AssertionError( raise AssertionError(
( "If you're assigning to a DisconnectedType you're"
"If you're assigning to a DisconnectedType you're" " doing something wrong. It should only be used as"
" doing something wrong. It should only be used as" " a symbolic placeholder."
" a symbolic placeholder."
)
) )
def fiter_variable(self, other): def fiter_variable(self, other):
raise AssertionError( raise AssertionError(
( "If you're assigning to a DisconnectedType you're"
"If you're assigning to a DisconnectedType you're" " doing something wrong. It should only be used as"
" doing something wrong. It should only be used as" " a symbolic placeholder."
" a symbolic placeholder."
)
) )
def may_share_memory(a, b): def may_share_memory(a, b):
...@@ -151,11 +147,9 @@ class DisconnectedType(theano.gof.type.Type): ...@@ -151,11 +147,9 @@ class DisconnectedType(theano.gof.type.Type):
def value_eq(a, b, force_same_dtype=True): def value_eq(a, b, force_same_dtype=True):
raise AssertionError( raise AssertionError(
( "If you're assigning to a DisconnectedType you're"
"If you're assigning to a DisconnectedType you're" " doing something wrong. It should only be used as"
" doing something wrong. It should only be used as" " a symbolic placeholder."
" a symbolic placeholder."
)
) )
def __str__(self): def __str__(self):
...@@ -846,7 +840,7 @@ def _node_to_pattern(node): ...@@ -846,7 +840,7 @@ def _node_to_pattern(node):
raise TypeError( raise TypeError(
"%s.connection_pattern should return" % node.op "%s.connection_pattern should return" % node.op
+ " a list of lists, but element %d" % ii + " a list of lists, but element %d" % ii
+ "is %s of type %s." % (output_pattern, type(output_pattern)) + "is {} of type {}.".format(output_pattern, type(output_pattern))
) )
else: else:
connection_pattern = [[True for output in node.outputs] for ipt in node.inputs] connection_pattern = [[True for output in node.outputs] for ipt in node.inputs]
...@@ -933,7 +927,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -933,7 +927,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# Note: we need to revisit the apply nodes repeatedly, because # Note: we need to revisit the apply nodes repeatedly, because
# different outputs of the apply node are connected to # different outputs of the apply node are connected to
# different subsets of the inputs. # different subsets of the inputs.
accounted_for = set([]) accounted_for = set()
def account_for(var): def account_for(var):
# Don't visit the same variable twice # Don't visit the same variable twice
...@@ -984,7 +978,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant): ...@@ -984,7 +978,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
# determine which variables have elements of wrt as a true # determine which variables have elements of wrt as a true
# ancestor. Do this with an upward pass starting from wrt, # ancestor. Do this with an upward pass starting from wrt,
# following only true connections # following only true connections
visited = set([]) visited = set()
def visit(var): def visit(var):
if var in visited: if var in visited:
...@@ -1458,7 +1452,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1458,7 +1452,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
grad_dict[var] = disconnected_type() grad_dict[var] = disconnected_type()
if cost_name is not None and var.name is not None: if cost_name is not None and var.name is not None:
grad_dict[var].name = "(d%s/d%s)" % (cost_name, var.name) grad_dict[var].name = "(d{}/d{})".format(cost_name, var.name)
else: else:
# this variable isn't connected to the cost in the # this variable isn't connected to the cost in the
# computational graph # computational graph
...@@ -1494,7 +1488,7 @@ def _float_ones_like(x): ...@@ -1494,7 +1488,7 @@ def _float_ones_like(x):
return x.ones_like(dtype=dtype) return x.ones_like(dtype=dtype)
class numeric_grad(object): class numeric_grad:
""" """
Compute the numeric derivative of a scalar-valued function at a particular Compute the numeric derivative of a scalar-valued function at a particular
point. point.
...@@ -1818,13 +1812,11 @@ def verify_grad( ...@@ -1818,13 +1812,11 @@ def verify_grad(
if rng is None: if rng is None:
raise TypeError( raise TypeError(
( "rng should be a valid instance of "
"rng should be a valid instance of " "numpy.random.RandomState. You may "
"numpy.random.RandomState. You may " "want to use tests.unittest"
"want to use tests.unittest" "_tools.verify_grad instead of "
"_tools.verify_grad instead of " "theano.gradient.verify_grad."
"theano.gradient.verify_grad."
)
) )
# We allow input downcast in function, because numeric_grad works in the # We allow input downcast in function, because numeric_grad works in the
...@@ -1853,7 +1845,7 @@ def verify_grad( ...@@ -1853,7 +1845,7 @@ def verify_grad(
if isinstance(o_output, list): if isinstance(o_output, list):
raise NotImplementedError( raise NotImplementedError(
("cant (yet) autotest gradient of fun " "with multiple outputs") "cant (yet) autotest gradient of fun " "with multiple outputs"
) )
# we could make loop over outputs making random projections R for each, # we could make loop over outputs making random projections R for each,
# but this doesn't handle the case where not all the outputs are # but this doesn't handle the case where not all the outputs are
......
...@@ -190,11 +190,9 @@ class IfElse(Op): ...@@ -190,11 +190,9 @@ class IfElse(Op):
) )
if c.ndim > 0: if c.ndim > 0:
raise TypeError( raise TypeError(
( "Condition given to the op has to be a scalar "
"Condition given to the op has to be a scalar " "with 0 standing for False, anything else "
"with 0 standing for False, anything else " "for True"
"for True"
)
) )
return Apply(self, [c] + list(args), [t.type() for t in ts]) return Apply(self, [c] + list(args), [t.type() for t in ts])
...@@ -401,13 +399,11 @@ def ifelse(condition, then_branch, else_branch, name=None): ...@@ -401,13 +399,11 @@ def ifelse(condition, then_branch, else_branch, name=None):
if len(then_branch) != len(else_branch): if len(then_branch) != len(else_branch):
raise ValueError( raise ValueError(
( "The number of values on the `then` branch"
"The number of values on the `then` branch" " should have the same number of variables as "
" should have the same number of variables as " "the `else` branch : (variables on `then` "
"the `else` branch : (variables on `then` " "%d" % len(then_branch) + ", variables on `else` "
"%d" % len(then_branch) + ", variables on `else` " "%d" % len(else_branch) + ")"
"%d" % len(else_branch) + ")"
)
) )
new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, gpu=False, name=name) new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, gpu=False, name=name)
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import sys import sys
class PathParser(object): class PathParser:
""" """
Class that allows to modify system's PATH environment variable Class that allows to modify system's PATH environment variable
at runtime. Currently used in ``theano.gpuarray.dnn`` module at runtime. Currently used in ``theano.gpuarray.dnn`` module
......
...@@ -12,7 +12,6 @@ from copy import copy ...@@ -12,7 +12,6 @@ from copy import copy
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from six import integer_types, string_types
from six.moves import StringIO from six.moves import StringIO
import theano import theano
...@@ -55,7 +54,7 @@ except ImportError: ...@@ -55,7 +54,7 @@ except ImportError:
_logger = logging.getLogger("theano.printing") _logger = logging.getLogger("theano.printing")
VALID_ASSOC = set(["left", "right", "either"]) VALID_ASSOC = {"left", "right", "either"}
def debugprint( def debugprint(
...@@ -121,7 +120,7 @@ def debugprint( ...@@ -121,7 +120,7 @@ def debugprint(
to the Apply's identifier, to indicate which output a line corresponds to. to the Apply's identifier, to indicate which output a line corresponds to.
""" """
if not isinstance(depth, integer_types): if not isinstance(depth, int):
raise Exception("depth parameter must be an int") raise Exception("depth parameter must be an int")
if file == "str": if file == "str":
_file = StringIO() _file = StringIO()
...@@ -168,7 +167,7 @@ def debugprint( ...@@ -168,7 +167,7 @@ def debugprint(
smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs]) smap.extend([getattr(obj, "storage_map", None) for item in obj.outputs])
topo = obj.toposort() topo = obj.toposort()
order.extend([topo for item in obj.outputs]) order.extend([topo for item in obj.outputs])
elif isinstance(obj, (integer_types, float, np.ndarray)): elif isinstance(obj, (int, float, np.ndarray)):
print(obj, file=_file) print(obj, file=_file)
elif isinstance(obj, (theano.In, theano.Out)): elif isinstance(obj, (theano.In, theano.Out)):
results_to_print.append(obj.variable) results_to_print.append(obj.variable)
...@@ -239,14 +238,10 @@ N.B.: ...@@ -239,14 +238,10 @@ N.B.:
else: else:
inner_inputs = s.owner.op.inputs inner_inputs = s.owner.op.inputs
outer_inputs = s.owner.inputs outer_inputs = s.owner.inputs
inner_to_outer_inputs = dict( inner_to_outer_inputs = {
[ inner_inputs[i]: outer_inputs[o]
(inner_inputs[i], outer_inputs[o]) for i, o in s.owner.op.var_mappings["outer_inp_from_inner_inp"].items()
for i, o in s.owner.op.var_mappings[ }
"outer_inp_from_inner_inp"
].items()
]
)
print("", file=_file) print("", file=_file)
debugmode.debugprint( debugmode.debugprint(
...@@ -440,7 +435,7 @@ class PatternPrinter: ...@@ -440,7 +435,7 @@ class PatternPrinter:
def __init__(self, *patterns): def __init__(self, *patterns):
self.patterns = [] self.patterns = []
for pattern in patterns: for pattern in patterns:
if isinstance(pattern, string_types): if isinstance(pattern, str):
self.patterns.append((pattern, ())) self.patterns.append((pattern, ()))
else: else:
self.patterns.append((pattern[0], pattern[1:])) self.patterns.append((pattern[0], pattern[1:]))
...@@ -469,13 +464,13 @@ class PatternPrinter: ...@@ -469,13 +464,13 @@ class PatternPrinter:
return r return r
d = dict( d = {
(str(i), x) str(i): x
for i, x in enumerate( for i, x in enumerate(
pp_process(input, precedence) pp_process(input, precedence)
for input, precedence in zip(node.inputs, precedences) for input, precedence in zip(node.inputs, precedences)
) )
) }
r = pattern % d r = pattern % d
pstate.memo[output] = r pstate.memo[output] = r
return r return r
...@@ -501,7 +496,7 @@ class FunctionPrinter: ...@@ -501,7 +496,7 @@ class FunctionPrinter:
try: try:
old_precedence = getattr(pstate, "precedence", None) old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence pstate.precedence = new_precedence
r = "%s(%s)" % ( r = "{}({})".format(
name, name,
", ".join([pprinter.process(input, pstate) for input in node.inputs]), ", ".join([pprinter.process(input, pstate) for input in node.inputs]),
) )
...@@ -556,7 +551,7 @@ class DefaultPrinter: ...@@ -556,7 +551,7 @@ class DefaultPrinter:
try: try:
old_precedence = getattr(pstate, "precedence", None) old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence pstate.precedence = new_precedence
r = "%s(%s)" % ( r = "{}({})".format(
str(node.op), str(node.op),
", ".join([pprinter.process(input, pstate) for input in node.inputs]), ", ".join([pprinter.process(input, pstate) for input in node.inputs]),
) )
...@@ -624,7 +619,7 @@ class PPrinter: ...@@ -624,7 +619,7 @@ class PPrinter:
pprinter = self.clone_assign( pprinter = self.clone_assign(
lambda pstate, r: r.name is not None and r is not current, leaf_printer lambda pstate, r: r.name is not None and r is not current, leaf_printer
) )
inv_updates = dict((b, a) for (a, b) in updates.items()) inv_updates = {b: a for (a, b) in updates.items()}
i = 1 i = 1
for node in gof.graph.io_toposort( for node in gof.graph.io_toposort(
list(inputs) + updates.keys(), list(outputs) + updates.values() list(inputs) + updates.keys(), list(outputs) + updates.values()
...@@ -633,7 +628,7 @@ class PPrinter: ...@@ -633,7 +628,7 @@ class PPrinter:
if output in inv_updates: if output in inv_updates:
name = str(inv_updates[output]) name = str(inv_updates[output])
strings.append( strings.append(
(i + 1000, "%s <- %s" % (name, pprinter.process(output))) (i + 1000, "{} <- {}".format(name, pprinter.process(output)))
) )
i += 1 i += 1
if output.name is not None or output in outputs: if output.name is not None or output in outputs:
...@@ -653,7 +648,7 @@ class PPrinter: ...@@ -653,7 +648,7 @@ class PPrinter:
strings.append((idx, "return %s" % pprinter.process(output))) strings.append((idx, "return %s" % pprinter.process(output)))
else: else:
strings.append( strings.append(
(idx, "%s = %s" % (name, pprinter.process(output))) (idx, "{} = {}".format(name, pprinter.process(output)))
) )
i += 1 i += 1
strings.sort() strings.sort()
...@@ -901,7 +896,7 @@ def pydotprint( ...@@ -901,7 +896,7 @@ def pydotprint(
dstr = "val=" + str(np.asarray(var.data)) dstr = "val=" + str(np.asarray(var.data))
if "\n" in dstr: if "\n" in dstr:
dstr = dstr[: dstr.index("\n")] dstr = dstr[: dstr.index("\n")]
varstr = "%s %s" % (dstr, str(var.type)) varstr = "{} {}".format(dstr, str(var.type))
elif var in input_update and input_update[var].name is not None: elif var in input_update and input_update[var].name is not None:
varstr = input_update[var].name varstr = input_update[var].name
if not var_with_name_simple: if not var_with_name_simple:
...@@ -933,7 +928,7 @@ def pydotprint( ...@@ -933,7 +928,7 @@ def pydotprint(
pf = 0 pf = 0
else: else:
pf = time * 100 / profile.fct_call_time pf = time * 100 / profile.fct_call_time
prof_str = " (%.3fs,%.3f%%)" % (time, pf) prof_str = " ({:.3f}s,{:.3f}%)".format(time, pf)
applystr = str(node.op).replace(":", "_") applystr = str(node.op).replace(":", "_")
applystr += prof_str applystr += prof_str
if (applystr in all_strings) or with_ids: if (applystr in all_strings) or with_ids:
......
...@@ -25,7 +25,7 @@ class Raise(gof.Op): ...@@ -25,7 +25,7 @@ class Raise(gof.Op):
self.exc = exc self.exc = exc
def __str__(self): def __str__(self):
return "Raise{%s(%s)}" % (self.exc, self.msg) return "Raise{{{}({})}}".format(self.exc, self.msg)
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
......
...@@ -42,7 +42,7 @@ class OrderedUpdates(OrderedDict): ...@@ -42,7 +42,7 @@ class OrderedUpdates(OrderedDict):
"an OrderedDict that is available at " "an OrderedDict that is available at "
"theano.compat.OrderedDict for python 2.6+." "theano.compat.OrderedDict for python 2.6+."
) )
super(OrderedUpdates, self).__init__(*key, **kwargs) super().__init__(*key, **kwargs)
for key in self: for key in self:
if not isinstance(key, SharedVariable): if not isinstance(key, SharedVariable):
raise TypeError( raise TypeError(
...@@ -59,7 +59,7 @@ class OrderedUpdates(OrderedDict): ...@@ -59,7 +59,7 @@ class OrderedUpdates(OrderedDict):
# value. Should it be cast to a GPU value right away? Should # value. Should it be cast to a GPU value right away? Should
# literals be transformed into constants immediately? # literals be transformed into constants immediately?
return super(OrderedUpdates, self).__setitem__(key, value) return super().__setitem__(key, value)
else: else:
raise TypeError( raise TypeError(
"OrderedUpdates keys must inherit from " "SharedVariable", key "OrderedUpdates keys must inherit from " "SharedVariable", key
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论