提交 ffa5e139 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to theano.compile

上级 86dbc392
......@@ -37,7 +37,7 @@ _logger = logging.getLogger("theano.compile.debugmode")
# Filter to avoid duplicating optimization warnings
class NoDuplicateOptWarningFilter(logging.Filter):
prev_msgs = set([])
prev_msgs = set()
def filter(self, record):
msg = record.getMessage()
......@@ -64,8 +64,6 @@ class DebugModeError(Exception):
"""
pass
class BadThunkOutput(DebugModeError):
"""
......@@ -99,7 +97,7 @@ class BadThunkOutput(DebugModeError):
"""
def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()):
super(BadThunkOutput, self).__init__()
super().__init__()
self.r = r
self.thunk1 = thunk1
self.val1 = val1
......@@ -170,7 +168,7 @@ class BadDestroyMap(DebugModeError):
"""
def __init__(self, node, idx, old_val, new_val, perform):
super(BadDestroyMap, self).__init__()
super().__init__()
self.node = node
self.idx = idx
self.old_val = old_val
......@@ -254,7 +252,7 @@ class BadViewMap(DebugModeError):
def __init__(
self, node, output_idx, out_storage, in_alias_idx=None, out_alias_idx=None
):
super(BadViewMap, self).__init__()
super().__init__()
self.node = node
self.output_idx = output_idx
self.out_storage = out_storage
......@@ -290,8 +288,6 @@ class StochasticOrder(DebugModeError):
"""
pass
class InvalidValueError(DebugModeError):
"""
......@@ -304,7 +300,7 @@ class InvalidValueError(DebugModeError):
"""
def __init__(self, r, v=None, client_node=None, hint="none", specific_hint="none"):
super(InvalidValueError, self).__init__()
super().__init__()
self.r = r
self.v = v
self.client_node = client_node
......@@ -718,7 +714,7 @@ def debugprint(
else:
outer_id_str = get_id_str(outer_r)
print(
"%s%s %s%s -> %s" % (prefix, r, id_str, type_str, outer_id_str),
"{}{} {}{} -> {}".format(prefix, r, id_str, type_str, outer_id_str),
file=file,
)
else:
......@@ -727,7 +723,7 @@ def debugprint(
if smap:
data = " " + str(smap.get(r, ""))
id_str = get_id_str(r)
print("%s%s %s%s%s" % (prefix, r, id_str, type_str, data), file=file)
print("{}{} {}{}{}".format(prefix, r, id_str, type_str, data), file=file)
return file
......@@ -1073,9 +1069,9 @@ def _find_bad_optimizations1(order, reasons, r_vals):
for i, node in enumerate(order):
program_position[node] = i
for new_r in node.outputs:
equivalence_sets.setdefault(new_r, set([new_r]))
equivalence_sets.setdefault(new_r, {new_r})
for reason, r, old_graph_str, new_graph_str in reasons[new_r]:
equivalence_sets[new_r].update(equivalence_sets.setdefault(r, set([r])))
equivalence_sets[new_r].update(equivalence_sets.setdefault(r, {r}))
for er in equivalence_sets[r]:
equivalence_sets[er] = equivalence_sets[new_r]
......@@ -1474,7 +1470,7 @@ def _check_preallocated_output(
):
_logger.debug(" name = %s", name)
thunk_name = "%s with %s output" % (perform, name)
thunk_name = "{} with {} output".format(perform, name)
if not out_map:
# Map is empty, there is no need to execute thunk() again
......@@ -1541,7 +1537,7 @@ def _check_preallocated_output(
fn.maker.mode = backup_mode
class _FunctionGraphEvent(object):
class _FunctionGraphEvent:
"""
A record of an event in the life of an FunctionGraph.
......@@ -1613,7 +1609,7 @@ class _FunctionGraphEvent(object):
return not (self == other)
class _VariableEquivalenceTracker(object):
class _VariableEquivalenceTracker:
"""
A FunctionGraph Feature that keeps tabs on an FunctionGraph and
tries to detect problems.
......@@ -1684,7 +1680,7 @@ class _VariableEquivalenceTracker(object):
else:
for r in node.outputs:
assert r not in self.equiv
self.equiv[r] = set([r])
self.equiv[r] = {r}
self.all_variables_ever.append(r)
self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, [])
......@@ -1740,13 +1736,13 @@ class _VariableEquivalenceTracker(object):
if r in self.equiv:
r_set = self.equiv[r]
else:
r_set = self.equiv.setdefault(r, set([r]))
r_set = self.equiv.setdefault(r, {r})
self.all_variables_ever.append(r)
if new_r in self.equiv:
new_r_set = self.equiv[new_r]
else:
new_r_set = self.equiv.setdefault(new_r, set([new_r]))
new_r_set = self.equiv.setdefault(new_r, {new_r})
self.all_variables_ever.append(new_r)
assert new_r in new_r_set
......@@ -1779,7 +1775,7 @@ default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk)]
# the external requirements of the .linker attribute of a mode
# 1) it's a class instance
# 2) it a has a .clone() method
class _DummyLinker(object):
class _DummyLinker:
# This is not a real linker anyway
def clone(self, allow_gc=None):
return self
......@@ -2746,7 +2742,7 @@ class DebugMode(Mode):
linker,
)
super(DebugMode, self).__init__(optimizer=optimizer, linker=linker)
super().__init__(optimizer=optimizer, linker=linker)
if stability_patience is not None:
self.stability_patience = stability_patience
......@@ -2771,7 +2767,7 @@ class DebugMode(Mode):
raise ValueError("DebugMode has to check at least one of c and py " "code")
def __str__(self):
return "DebugMode(linker=%s, optimizer=%s)" % (
return "DebugMode(linker={}, optimizer={})".format(
self.provided_linker,
self.provided_optimizer,
)
......
......@@ -9,8 +9,6 @@ import re
import traceback as tb
import warnings
from six import string_types
from theano import compat
from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc
......@@ -67,7 +65,7 @@ def function_dump(
`['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
"""
assert isinstance(filename, string_types)
assert isinstance(filename, str)
d = dict(
inputs=inputs,
outputs=outputs,
......@@ -258,7 +256,7 @@ def function(
output_items = list(outputs.items())
for item_pair in output_items:
assert isinstance(item_pair[0], string_types)
assert isinstance(item_pair[0], str)
output_items_sorted = sorted(output_items)
......
......@@ -12,7 +12,6 @@ from itertools import chain
import numpy as np
import six.moves.copyreg as copyreg
import six.moves.cPickle as pickle
from six import string_types
import theano
import theano.compile.profiling
......@@ -35,8 +34,6 @@ class UnusedInputError(Exception):
"""
pass
def alias_root(v):
"""
......@@ -94,7 +91,7 @@ def infer_reuse_pattern(fgraph, outputs_to_disown):
for o in outputs_to_disown:
view_tree_set(alias_root(o), rval)
# remove from rval all of the inputs, constants, values.
rval = set(r for r in rval if r.owner is not None)
rval = {r for r in rval if r.owner is not None}
return rval
......@@ -219,8 +216,6 @@ class AliasedMemoryError(Exception):
"""
pass
###
# Function
......@@ -230,7 +225,7 @@ class AliasedMemoryError(Exception):
DUPLICATE = ["DUPLICATE"]
class Function(object):
class Function:
"""
Type of the functions returned by theano.function or
theano.FunctionMaker.create.
......@@ -478,7 +473,7 @@ class Function(object):
# this class is important in overriding the square-bracket notation:
# fn.value[x]
# self reference is available via the closure on the class
class ValueAttribute(object):
class ValueAttribute:
def __getitem__(self, item):
try:
s = finder[item]
......@@ -501,7 +496,9 @@ class Function(object):
except KeyError:
# Print informative error message.
msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
raise TypeError("Unknown input or state: %s. %s" % (str(item), msg))
raise TypeError(
"Unknown input or state: {}. {}".format(str(item), msg)
)
if s is DUPLICATE:
raise TypeError(
"Ambiguous name: %s - please check the "
......@@ -520,7 +517,7 @@ class Function(object):
# this class is important in overriding the square-bracket notation:
# fn.container[x]
# self reference is available via the closure on the class
class ContainerAttribute(object):
class ContainerAttribute:
def __getitem__(self, item):
return finder[item]
......@@ -1065,10 +1062,10 @@ class Function(object):
if output_subset is None:
return dict(zip(self.output_keys, outputs))
else:
return dict(
(self.output_keys[index], outputs[index])
return {
self.output_keys[index]: outputs[index]
for index in output_subset
)
}
if output_subset is None:
return outputs
......@@ -1201,13 +1198,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs)
reason = "insert_deepcopy"
updated_fgraph_inputs = set(
[
fgraph_i
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
if getattr(i, "update", False)
]
)
updated_fgraph_inputs = {
fgraph_i
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
if getattr(i, "update", False)
}
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs)
......@@ -1286,7 +1281,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
NODEFAULT = ["NODEFAULT"]
class FunctionMaker(object):
class FunctionMaker:
"""
`FunctionMaker` is the class to `create` `Function` instances.
......@@ -2041,7 +2036,7 @@ def convert_function_input(input):
orig = input
if not input:
raise TypeError("Nonsensical input specification: %s" % input)
if isinstance(input[0], string_types):
if isinstance(input[0], str):
name = input[0]
input = input[1:]
else:
......@@ -2133,7 +2128,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
)
else:
if n_unnamed_inputs == 0:
msg = "The function has %s named input%s (%s)." % (
msg = "The function has {} named input{} ({}).".format(
n_named_inputs,
get_plural(n_named_inputs),
", ".join(named_inputs),
......
......@@ -6,8 +6,6 @@ Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out`.
import logging
from six import string_types
from theano import gof
from .sharedvalue import SharedVariable
......@@ -18,7 +16,7 @@ _logger = logging.getLogger("theano.compile.io")
__docformat__ = "restructuredtext en"
class SymbolicInput(object):
class SymbolicInput:
"""
Represents a symbolic input for use with function or FunctionMaker.
......@@ -79,7 +77,7 @@ class SymbolicInput(object):
else:
self.name = name
if self.name is not None and not isinstance(self.name, string_types):
if self.name is not None and not isinstance(self.name, str):
raise TypeError("name must be a string! (got: %s)" % self.name)
self.update = update
if update is not None:
......@@ -102,7 +100,7 @@ class SymbolicInput(object):
def __str__(self):
if self.update:
return "In(%s -> %s)" % (self.variable, self.update)
return "In({} -> {})".format(self.variable, self.update)
else:
return "In(%s)" % self.variable
......@@ -216,7 +214,7 @@ class In(SymbolicInput):
implicit = isinstance(value, gof.Container) or isinstance(
value, SharedVariable
)
super(In, self).__init__(
super().__init__(
variable=variable,
name=name,
update=update,
......@@ -231,7 +229,7 @@ class In(SymbolicInput):
raise TypeError("An implicit input must be given a default value")
class SymbolicOutput(object):
class SymbolicOutput:
"""
Represents a symbolic output for use with function or FunctionMaker.
......@@ -250,10 +248,10 @@ class SymbolicOutput(object):
self.borrow = borrow
def __str__(self):
return "Out(%s,%s)" % (self.variable, self.borrow)
return "Out({},{})".format(self.variable, self.borrow)
def __repr__(self):
return "Out(%s,%s)" % (self.variable, self.borrow)
return "Out({},{})".format(self.variable, self.borrow)
Out = SymbolicOutput
......@@ -6,8 +6,6 @@ WRITEME
import logging
import warnings
from six import string_types
import theano
import theano.gof.vm
from theano import config, gof
......@@ -132,7 +130,7 @@ class AddDestroyHandler(gof.Optimizer):
)
def add_requirements(self, fgraph):
super(AddDestroyHandler, self).add_requirements(fgraph)
super().add_requirements(fgraph)
fgraph.attach_feature(gof.DestroyHandler())
......@@ -145,7 +143,7 @@ class AddFeatureOptimizer(gof.Optimizer):
self.feature = feature
def add_requirements(self, fgraph):
super(AddFeatureOptimizer, self).add_requirements(fgraph)
super().add_requirements(fgraph)
fgraph.attach_feature(self.feature)
......@@ -259,7 +257,7 @@ optdb.register("CheckStackTrace", gof.CheckStackTraceOptimization(), -1, *_tags)
del _tags
class Mode(object):
class Mode:
"""
The Mode represents a way to optimize and then link a computation graph.
......@@ -303,10 +301,10 @@ class Mode(object):
linker, optimizer = state
self.provided_linker = linker
self.provided_optimizer = optimizer
if isinstance(linker, string_types) or linker is None:
if isinstance(linker, str) or linker is None:
linker = predefined_linkers[linker]
self.linker = linker
if isinstance(optimizer, string_types) or optimizer is None:
if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, gof.Query):
self.provided_optimizer = optimizer
......@@ -315,7 +313,7 @@ class Mode(object):
self.fn_time = 0
def __str__(self):
return "%s(linker = %s, optimizer = %s)" % (
return "{}(linker = {}, optimizer = {})".format(
self.__class__.__name__,
self.provided_linker,
self.provided_optimizer,
......@@ -330,9 +328,9 @@ class Mode(object):
optimizer = property(__get_optimizer)
def get_linker_optimizer(self, linker, optimizer):
if isinstance(linker, string_types) or linker is None:
if isinstance(linker, str) or linker is None:
linker = predefined_linkers[linker]
if isinstance(optimizer, string_types) or optimizer is None:
if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer]
return (linker, optimizer)
......@@ -432,7 +430,7 @@ def get_mode(orig_string):
string = config.mode
else:
string = orig_string
if not isinstance(string, string_types):
if not isinstance(string, str):
return string # it is hopefully already a mode...
global instantiated_default_mode
......
......@@ -52,17 +52,17 @@ class MonitorMode(Mode):
linker,
)
super(MonitorMode, self).__init__(wrap_linker, optimizer=optimizer)
super().__init__(wrap_linker, optimizer=optimizer)
def __getstate__(self):
lnk, opt = super(MonitorMode, self).__getstate__()
lnk, opt = super().__getstate__()
return (lnk, opt, self.pre_func, self.post_func)
def __setstate__(self, state):
lnk, opt, pre_func, post_func = state
self.pre_func = pre_func
self.post_func = post_func
super(MonitorMode, self).__setstate__((lnk, opt))
super().__setstate__((lnk, opt))
def eval(self, i, node, fn):
"""
......
......@@ -295,6 +295,4 @@ class NanGuardMode(Mode):
wrap_linker = theano.gof.vm.VM_Linker(
callback=nan_check, callback_input=nan_check_input
)
super(NanGuardMode, self).__init__(
wrap_linker, optimizer=self.provided_optimizer
)
super().__init__(wrap_linker, optimizer=self.provided_optimizer)
......@@ -11,7 +11,6 @@ from collections import OrderedDict
import numpy as np
import six.moves.cPickle as pickle
from six import integer_types
import theano
from theano.gof import Apply, Op, ParamsType, Variable
......@@ -71,7 +70,7 @@ class ViewOp(Op):
return code % locals()
# Else, no C code
return super(ViewOp, self).c_code(node, nodename, inp, out, sub)
return super().c_code(node, nodename, inp, out, sub)
def c_code_cache_version(self):
version = []
......@@ -206,7 +205,7 @@ class DeepCopyOp(Op):
return code % locals()
# Else, no C code
return super(DeepCopyOp, self).c_code(node, name, inames, onames, sub)
return super().c_code(node, name, inames, onames, sub)
deep_copy_op = DeepCopyOp()
......@@ -296,7 +295,7 @@ class Shape(Op):
return code % locals()
# Else, no C code
return super(Shape, self).c_code(node, name, inames, onames, sub)
return super().c_code(node, name, inames, onames, sub)
def c_code_cache_version(self):
version = []
......@@ -423,7 +422,7 @@ class Shape_i(Op):
return (check_input + code) % locals()
# Else, no C code
return super(Shape_i, self).c_code(node, name, inames, onames, sub)
return super().c_code(node, name, inames, onames, sub)
def infer_shape(self, node, input_shapes):
return [()]
......@@ -583,7 +582,7 @@ class FromFunctionOp(Op):
obj = load_back(mod, name)
except (ImportError, KeyError, AttributeError):
raise pickle.PicklingError(
"Can't pickle as_op(), not found as %s.%s" % (mod, name)
"Can't pickle as_op(), not found as {}.{}".format(mod, name)
)
else:
if obj is not self:
......@@ -699,7 +698,7 @@ class Rebroadcast(Op):
items = sorted(axis)
self.axis = OrderedDict(items)
for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, integer_types)):
if not isinstance(axis, (np.integer, int)):
raise TypeError(
"Rebroadcast needs integer axes. " "Got {}".format(axis)
)
......@@ -723,7 +722,7 @@ class Rebroadcast(Op):
broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))]
for k, v in self.axis.items():
broadcast_pattern[k] = str(int(v))
return "%s{%s}" % (self.__class__.__name__, ",".join(broadcast_pattern))
return "{}{{{}}}".format(self.__class__.__name__, ",".join(broadcast_pattern))
def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
......@@ -797,7 +796,7 @@ class Rebroadcast(Op):
"""
% locals()
)
return super(Rebroadcast, self).c_code(node, nodename, inp, out, sub)
return super().c_code(node, nodename, inp, out, sub)
def c_code_cache_version(self):
version = []
......@@ -929,7 +928,7 @@ class SpecifyShape(Op):
_, _, support_code = self.c_code_and_version[itype]
if support_code:
return support_code
return super(SpecifyShape, self).c_support_code_apply(node, name)
return super().c_support_code_apply(node, name)
def c_code(self, node, name, inames, onames, sub):
iname, shape = inames
......@@ -941,7 +940,7 @@ class SpecifyShape(Op):
code, version, _ = self.c_code_and_version[itype]
return code % locals()
return super(SpecifyShape, self).c_code(node, node, inames, onames, sub)
return super().c_code(node, node, inames, onames, sub)
def c_code_cache_version(self):
version = []
......
......@@ -300,7 +300,7 @@ class Param(In):
" by theano.In(value=N)",
stacklevel=2,
)
super(Param, self).__init__(
super().__init__(
variable,
name=name,
value=default,
......@@ -447,12 +447,10 @@ def pfunc(
if v in in_variables[(i + 1) :]:
dup_v_i = in_variables.index(v, (i + 1))
raise UnusedInputError(
(
"Variable %s is used twice in inputs to theano.function, "
"at indices %i and %i. This would result in values "
"provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i)
)
"Variable %s is used twice in inputs to theano.function, "
"at indices %i and %i. This would result in values "
"provided for it being ignored. Please do not duplicate "
"variables in the inputs list." % (v, i, dup_v_i)
)
# Check that we are not using `givens` to replace input variables, because
......
......@@ -181,7 +181,7 @@ def register_profiler_printer(fct):
return fct
class ProfileStats(object):
class ProfileStats:
"""
Object to store runtime and memory profiling information for all of
......@@ -851,7 +851,7 @@ class ProfileStats(object):
for node, t in sorted(
self.linker_make_thunk_time.items(), key=operator.itemgetter(1)
)[::-1][:5]:
print(" Node %s time %es" % (node, t), file=file)
print(" Node {} time {:e}s".format(node, t), file=file)
print("", file=file)
# The validation time is a subset of optimizer_time
......@@ -1071,7 +1071,7 @@ class ProfileStats(object):
mem_bound = np.inf
# This take only the inputs/outputs dependencies.
dependencies = fgraph.profile.dependencies
done_set = set([])
done_set = set()
done_dict = {}
# Initial compute_map which is used to check if a node is valid
......@@ -1451,7 +1451,10 @@ class ProfileStats(object):
else:
size = "%10s" % "Unknown"
print(" %s %s %s %s" % (size, shapes, " ".join(code), node), file=file)
print(
" {} {} {} {}".format(size, shapes, " ".join(code), node),
file=file,
)
sum_remaining = sum(size for _, size in items[N:])
size_sum_dense = sum(node_mem.values())
......@@ -1499,7 +1502,7 @@ class ProfileStats(object):
file=file,
)
if config.profiling.debugprint:
fcts = set([n.fgraph for n in self.apply_time.keys()])
fcts = {n.fgraph for n in self.apply_time.keys()}
theano.printing.debugprint(fcts, print_type=True)
if self.variable_shape or self.variable_strides:
self.summary_memory(file, n_apply_to_print)
......@@ -1686,10 +1689,7 @@ class ProfileStats(object):
# tip 6
for a in self.apply_time:
node = a
if (
isinstance(node.op, T.Dot)
and len(set(i.dtype for i in node.inputs)) != 1
):
if isinstance(node.op, T.Dot) and len({i.dtype for i in node.inputs}) != 1:
print(
" - You have a dot operation that has different dtype "
" for inputs (%s). Make sure that the inputs have same "
......@@ -1742,7 +1742,7 @@ class ScanProfileStats(ProfileStats):
call_time = 0.0
def __init__(self, atexit_print=True, name=None, **kwargs):
super(ScanProfileStats, self).__init__(atexit_print, **kwargs)
super().__init__(atexit_print, **kwargs)
self.name = name
def summary_globals(self, file):
......
......@@ -67,9 +67,7 @@ class SharedVariable(Variable):
# or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
super(SharedVariable, self).__init__(
type=type, name=name, owner=None, index=None
)
super().__init__(type=type, name=name, owner=None, index=None)
if container is not None:
self.container = container
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论