提交 ffa5e139 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to theano.compile

上级 86dbc392
...@@ -37,7 +37,7 @@ _logger = logging.getLogger("theano.compile.debugmode") ...@@ -37,7 +37,7 @@ _logger = logging.getLogger("theano.compile.debugmode")
# Filter to avoid duplicating optimization warnings # Filter to avoid duplicating optimization warnings
class NoDuplicateOptWarningFilter(logging.Filter): class NoDuplicateOptWarningFilter(logging.Filter):
prev_msgs = set([]) prev_msgs = set()
def filter(self, record): def filter(self, record):
msg = record.getMessage() msg = record.getMessage()
...@@ -64,8 +64,6 @@ class DebugModeError(Exception): ...@@ -64,8 +64,6 @@ class DebugModeError(Exception):
""" """
pass
class BadThunkOutput(DebugModeError): class BadThunkOutput(DebugModeError):
""" """
...@@ -99,7 +97,7 @@ class BadThunkOutput(DebugModeError): ...@@ -99,7 +97,7 @@ class BadThunkOutput(DebugModeError):
""" """
def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()): def __init__(self, r, thunk1, val1, thunk2, val2, inputs_val=()):
super(BadThunkOutput, self).__init__() super().__init__()
self.r = r self.r = r
self.thunk1 = thunk1 self.thunk1 = thunk1
self.val1 = val1 self.val1 = val1
...@@ -170,7 +168,7 @@ class BadDestroyMap(DebugModeError): ...@@ -170,7 +168,7 @@ class BadDestroyMap(DebugModeError):
""" """
def __init__(self, node, idx, old_val, new_val, perform): def __init__(self, node, idx, old_val, new_val, perform):
super(BadDestroyMap, self).__init__() super().__init__()
self.node = node self.node = node
self.idx = idx self.idx = idx
self.old_val = old_val self.old_val = old_val
...@@ -254,7 +252,7 @@ class BadViewMap(DebugModeError): ...@@ -254,7 +252,7 @@ class BadViewMap(DebugModeError):
def __init__( def __init__(
self, node, output_idx, out_storage, in_alias_idx=None, out_alias_idx=None self, node, output_idx, out_storage, in_alias_idx=None, out_alias_idx=None
): ):
super(BadViewMap, self).__init__() super().__init__()
self.node = node self.node = node
self.output_idx = output_idx self.output_idx = output_idx
self.out_storage = out_storage self.out_storage = out_storage
...@@ -290,8 +288,6 @@ class StochasticOrder(DebugModeError): ...@@ -290,8 +288,6 @@ class StochasticOrder(DebugModeError):
""" """
pass
class InvalidValueError(DebugModeError): class InvalidValueError(DebugModeError):
""" """
...@@ -304,7 +300,7 @@ class InvalidValueError(DebugModeError): ...@@ -304,7 +300,7 @@ class InvalidValueError(DebugModeError):
""" """
def __init__(self, r, v=None, client_node=None, hint="none", specific_hint="none"): def __init__(self, r, v=None, client_node=None, hint="none", specific_hint="none"):
super(InvalidValueError, self).__init__() super().__init__()
self.r = r self.r = r
self.v = v self.v = v
self.client_node = client_node self.client_node = client_node
...@@ -718,7 +714,7 @@ def debugprint( ...@@ -718,7 +714,7 @@ def debugprint(
else: else:
outer_id_str = get_id_str(outer_r) outer_id_str = get_id_str(outer_r)
print( print(
"%s%s %s%s -> %s" % (prefix, r, id_str, type_str, outer_id_str), "{}{} {}{} -> {}".format(prefix, r, id_str, type_str, outer_id_str),
file=file, file=file,
) )
else: else:
...@@ -727,7 +723,7 @@ def debugprint( ...@@ -727,7 +723,7 @@ def debugprint(
if smap: if smap:
data = " " + str(smap.get(r, "")) data = " " + str(smap.get(r, ""))
id_str = get_id_str(r) id_str = get_id_str(r)
print("%s%s %s%s%s" % (prefix, r, id_str, type_str, data), file=file) print("{}{} {}{}{}".format(prefix, r, id_str, type_str, data), file=file)
return file return file
...@@ -1073,9 +1069,9 @@ def _find_bad_optimizations1(order, reasons, r_vals): ...@@ -1073,9 +1069,9 @@ def _find_bad_optimizations1(order, reasons, r_vals):
for i, node in enumerate(order): for i, node in enumerate(order):
program_position[node] = i program_position[node] = i
for new_r in node.outputs: for new_r in node.outputs:
equivalence_sets.setdefault(new_r, set([new_r])) equivalence_sets.setdefault(new_r, {new_r})
for reason, r, old_graph_str, new_graph_str in reasons[new_r]: for reason, r, old_graph_str, new_graph_str in reasons[new_r]:
equivalence_sets[new_r].update(equivalence_sets.setdefault(r, set([r]))) equivalence_sets[new_r].update(equivalence_sets.setdefault(r, {r}))
for er in equivalence_sets[r]: for er in equivalence_sets[r]:
equivalence_sets[er] = equivalence_sets[new_r] equivalence_sets[er] = equivalence_sets[new_r]
...@@ -1474,7 +1470,7 @@ def _check_preallocated_output( ...@@ -1474,7 +1470,7 @@ def _check_preallocated_output(
): ):
_logger.debug(" name = %s", name) _logger.debug(" name = %s", name)
thunk_name = "%s with %s output" % (perform, name) thunk_name = "{} with {} output".format(perform, name)
if not out_map: if not out_map:
# Map is empty, there is no need to execute thunk() again # Map is empty, there is no need to execute thunk() again
...@@ -1541,7 +1537,7 @@ def _check_preallocated_output( ...@@ -1541,7 +1537,7 @@ def _check_preallocated_output(
fn.maker.mode = backup_mode fn.maker.mode = backup_mode
class _FunctionGraphEvent(object): class _FunctionGraphEvent:
""" """
A record of an event in the life of an FunctionGraph. A record of an event in the life of an FunctionGraph.
...@@ -1613,7 +1609,7 @@ class _FunctionGraphEvent(object): ...@@ -1613,7 +1609,7 @@ class _FunctionGraphEvent(object):
return not (self == other) return not (self == other)
class _VariableEquivalenceTracker(object): class _VariableEquivalenceTracker:
""" """
A FunctionGraph Feature that keeps tabs on an FunctionGraph and A FunctionGraph Feature that keeps tabs on an FunctionGraph and
tries to detect problems. tries to detect problems.
...@@ -1684,7 +1680,7 @@ class _VariableEquivalenceTracker(object): ...@@ -1684,7 +1680,7 @@ class _VariableEquivalenceTracker(object):
else: else:
for r in node.outputs: for r in node.outputs:
assert r not in self.equiv assert r not in self.equiv
self.equiv[r] = set([r]) self.equiv[r] = {r}
self.all_variables_ever.append(r) self.all_variables_ever.append(r)
self.reasons.setdefault(r, []) self.reasons.setdefault(r, [])
self.replaced_by.setdefault(r, []) self.replaced_by.setdefault(r, [])
...@@ -1740,13 +1736,13 @@ class _VariableEquivalenceTracker(object): ...@@ -1740,13 +1736,13 @@ class _VariableEquivalenceTracker(object):
if r in self.equiv: if r in self.equiv:
r_set = self.equiv[r] r_set = self.equiv[r]
else: else:
r_set = self.equiv.setdefault(r, set([r])) r_set = self.equiv.setdefault(r, {r})
self.all_variables_ever.append(r) self.all_variables_ever.append(r)
if new_r in self.equiv: if new_r in self.equiv:
new_r_set = self.equiv[new_r] new_r_set = self.equiv[new_r]
else: else:
new_r_set = self.equiv.setdefault(new_r, set([new_r])) new_r_set = self.equiv.setdefault(new_r, {new_r})
self.all_variables_ever.append(new_r) self.all_variables_ever.append(new_r)
assert new_r in new_r_set assert new_r in new_r_set
...@@ -1779,7 +1775,7 @@ default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk)] ...@@ -1779,7 +1775,7 @@ default_make_thunk = [get_unbound_function(theano.gof.Op.make_thunk)]
# the external requirements of the .linker attribute of a mode # the external requirements of the .linker attribute of a mode
# 1) it's a class instance # 1) it's a class instance
# 2) it a has a .clone() method # 2) it a has a .clone() method
class _DummyLinker(object): class _DummyLinker:
# This is not a real linker anyway # This is not a real linker anyway
def clone(self, allow_gc=None): def clone(self, allow_gc=None):
return self return self
...@@ -2746,7 +2742,7 @@ class DebugMode(Mode): ...@@ -2746,7 +2742,7 @@ class DebugMode(Mode):
linker, linker,
) )
super(DebugMode, self).__init__(optimizer=optimizer, linker=linker) super().__init__(optimizer=optimizer, linker=linker)
if stability_patience is not None: if stability_patience is not None:
self.stability_patience = stability_patience self.stability_patience = stability_patience
...@@ -2771,7 +2767,7 @@ class DebugMode(Mode): ...@@ -2771,7 +2767,7 @@ class DebugMode(Mode):
raise ValueError("DebugMode has to check at least one of c and py " "code") raise ValueError("DebugMode has to check at least one of c and py " "code")
def __str__(self): def __str__(self):
return "DebugMode(linker=%s, optimizer=%s)" % ( return "DebugMode(linker={}, optimizer={})".format(
self.provided_linker, self.provided_linker,
self.provided_optimizer, self.provided_optimizer,
) )
......
...@@ -9,8 +9,6 @@ import re ...@@ -9,8 +9,6 @@ import re
import traceback as tb import traceback as tb
import warnings import warnings
from six import string_types
from theano import compat from theano import compat
from theano.compile.function_module import orig_function from theano.compile.function_module import orig_function
from theano.compile.pfunc import pfunc from theano.compile.pfunc import pfunc
...@@ -67,7 +65,7 @@ def function_dump( ...@@ -67,7 +65,7 @@ def function_dump(
`['annotations', 'replacement_of', 'aggregation_scheme', 'roles']` `['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
""" """
assert isinstance(filename, string_types) assert isinstance(filename, str)
d = dict( d = dict(
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
...@@ -258,7 +256,7 @@ def function( ...@@ -258,7 +256,7 @@ def function(
output_items = list(outputs.items()) output_items = list(outputs.items())
for item_pair in output_items: for item_pair in output_items:
assert isinstance(item_pair[0], string_types) assert isinstance(item_pair[0], str)
output_items_sorted = sorted(output_items) output_items_sorted = sorted(output_items)
......
...@@ -12,7 +12,6 @@ from itertools import chain ...@@ -12,7 +12,6 @@ from itertools import chain
import numpy as np import numpy as np
import six.moves.copyreg as copyreg import six.moves.copyreg as copyreg
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
from six import string_types
import theano import theano
import theano.compile.profiling import theano.compile.profiling
...@@ -35,8 +34,6 @@ class UnusedInputError(Exception): ...@@ -35,8 +34,6 @@ class UnusedInputError(Exception):
""" """
pass
def alias_root(v): def alias_root(v):
""" """
...@@ -94,7 +91,7 @@ def infer_reuse_pattern(fgraph, outputs_to_disown): ...@@ -94,7 +91,7 @@ def infer_reuse_pattern(fgraph, outputs_to_disown):
for o in outputs_to_disown: for o in outputs_to_disown:
view_tree_set(alias_root(o), rval) view_tree_set(alias_root(o), rval)
# remove from rval all of the inputs, constants, values. # remove from rval all of the inputs, constants, values.
rval = set(r for r in rval if r.owner is not None) rval = {r for r in rval if r.owner is not None}
return rval return rval
...@@ -219,8 +216,6 @@ class AliasedMemoryError(Exception): ...@@ -219,8 +216,6 @@ class AliasedMemoryError(Exception):
""" """
pass
### ###
# Function # Function
...@@ -230,7 +225,7 @@ class AliasedMemoryError(Exception): ...@@ -230,7 +225,7 @@ class AliasedMemoryError(Exception):
DUPLICATE = ["DUPLICATE"] DUPLICATE = ["DUPLICATE"]
class Function(object): class Function:
""" """
Type of the functions returned by theano.function or Type of the functions returned by theano.function or
theano.FunctionMaker.create. theano.FunctionMaker.create.
...@@ -478,7 +473,7 @@ class Function(object): ...@@ -478,7 +473,7 @@ class Function(object):
# this class is important in overriding the square-bracket notation: # this class is important in overriding the square-bracket notation:
# fn.value[x] # fn.value[x]
# self reference is available via the closure on the class # self reference is available via the closure on the class
class ValueAttribute(object): class ValueAttribute:
def __getitem__(self, item): def __getitem__(self, item):
try: try:
s = finder[item] s = finder[item]
...@@ -501,7 +496,9 @@ class Function(object): ...@@ -501,7 +496,9 @@ class Function(object):
except KeyError: except KeyError:
# Print informative error message. # Print informative error message.
msg = get_info_on_inputs(named_inputs, n_unnamed_inputs) msg = get_info_on_inputs(named_inputs, n_unnamed_inputs)
raise TypeError("Unknown input or state: %s. %s" % (str(item), msg)) raise TypeError(
"Unknown input or state: {}. {}".format(str(item), msg)
)
if s is DUPLICATE: if s is DUPLICATE:
raise TypeError( raise TypeError(
"Ambiguous name: %s - please check the " "Ambiguous name: %s - please check the "
...@@ -520,7 +517,7 @@ class Function(object): ...@@ -520,7 +517,7 @@ class Function(object):
# this class is important in overriding the square-bracket notation: # this class is important in overriding the square-bracket notation:
# fn.container[x] # fn.container[x]
# self reference is available via the closure on the class # self reference is available via the closure on the class
class ContainerAttribute(object): class ContainerAttribute:
def __getitem__(self, item): def __getitem__(self, item):
return finder[item] return finder[item]
...@@ -1065,10 +1062,10 @@ class Function(object): ...@@ -1065,10 +1062,10 @@ class Function(object):
if output_subset is None: if output_subset is None:
return dict(zip(self.output_keys, outputs)) return dict(zip(self.output_keys, outputs))
else: else:
return dict( return {
(self.output_keys[index], outputs[index]) self.output_keys[index]: outputs[index]
for index in output_subset for index in output_subset
) }
if output_subset is None: if output_subset is None:
return outputs return outputs
...@@ -1201,13 +1198,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1201,13 +1198,11 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
assert len(wrapped_inputs) == len(fgraph.inputs) assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs) assert len(wrapped_outputs) == len(fgraph.outputs)
reason = "insert_deepcopy" reason = "insert_deepcopy"
updated_fgraph_inputs = set( updated_fgraph_inputs = {
[ fgraph_i
fgraph_i for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs)
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs) if getattr(i, "update", False)
if getattr(i, "update", False) }
]
)
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs) all_graph_inputs = gof.graph.inputs(fgraph.outputs)
...@@ -1286,7 +1281,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1286,7 +1281,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
NODEFAULT = ["NODEFAULT"] NODEFAULT = ["NODEFAULT"]
class FunctionMaker(object): class FunctionMaker:
""" """
`FunctionMaker` is the class to `create` `Function` instances. `FunctionMaker` is the class to `create` `Function` instances.
...@@ -2041,7 +2036,7 @@ def convert_function_input(input): ...@@ -2041,7 +2036,7 @@ def convert_function_input(input):
orig = input orig = input
if not input: if not input:
raise TypeError("Nonsensical input specification: %s" % input) raise TypeError("Nonsensical input specification: %s" % input)
if isinstance(input[0], string_types): if isinstance(input[0], str):
name = input[0] name = input[0]
input = input[1:] input = input[1:]
else: else:
...@@ -2133,7 +2128,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs): ...@@ -2133,7 +2128,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
) )
else: else:
if n_unnamed_inputs == 0: if n_unnamed_inputs == 0:
msg = "The function has %s named input%s (%s)." % ( msg = "The function has {} named input{} ({}).".format(
n_named_inputs, n_named_inputs,
get_plural(n_named_inputs), get_plural(n_named_inputs),
", ".join(named_inputs), ", ".join(named_inputs),
......
...@@ -6,8 +6,6 @@ Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out`. ...@@ -6,8 +6,6 @@ Define `SymbolicInput`, `SymbolicOutput`, `In`, `Out`.
import logging import logging
from six import string_types
from theano import gof from theano import gof
from .sharedvalue import SharedVariable from .sharedvalue import SharedVariable
...@@ -18,7 +16,7 @@ _logger = logging.getLogger("theano.compile.io") ...@@ -18,7 +16,7 @@ _logger = logging.getLogger("theano.compile.io")
__docformat__ = "restructuredtext en" __docformat__ = "restructuredtext en"
class SymbolicInput(object): class SymbolicInput:
""" """
Represents a symbolic input for use with function or FunctionMaker. Represents a symbolic input for use with function or FunctionMaker.
...@@ -79,7 +77,7 @@ class SymbolicInput(object): ...@@ -79,7 +77,7 @@ class SymbolicInput(object):
else: else:
self.name = name self.name = name
if self.name is not None and not isinstance(self.name, string_types): if self.name is not None and not isinstance(self.name, str):
raise TypeError("name must be a string! (got: %s)" % self.name) raise TypeError("name must be a string! (got: %s)" % self.name)
self.update = update self.update = update
if update is not None: if update is not None:
...@@ -102,7 +100,7 @@ class SymbolicInput(object): ...@@ -102,7 +100,7 @@ class SymbolicInput(object):
def __str__(self): def __str__(self):
if self.update: if self.update:
return "In(%s -> %s)" % (self.variable, self.update) return "In({} -> {})".format(self.variable, self.update)
else: else:
return "In(%s)" % self.variable return "In(%s)" % self.variable
...@@ -216,7 +214,7 @@ class In(SymbolicInput): ...@@ -216,7 +214,7 @@ class In(SymbolicInput):
implicit = isinstance(value, gof.Container) or isinstance( implicit = isinstance(value, gof.Container) or isinstance(
value, SharedVariable value, SharedVariable
) )
super(In, self).__init__( super().__init__(
variable=variable, variable=variable,
name=name, name=name,
update=update, update=update,
...@@ -231,7 +229,7 @@ class In(SymbolicInput): ...@@ -231,7 +229,7 @@ class In(SymbolicInput):
raise TypeError("An implicit input must be given a default value") raise TypeError("An implicit input must be given a default value")
class SymbolicOutput(object): class SymbolicOutput:
""" """
Represents a symbolic output for use with function or FunctionMaker. Represents a symbolic output for use with function or FunctionMaker.
...@@ -250,10 +248,10 @@ class SymbolicOutput(object): ...@@ -250,10 +248,10 @@ class SymbolicOutput(object):
self.borrow = borrow self.borrow = borrow
def __str__(self): def __str__(self):
return "Out(%s,%s)" % (self.variable, self.borrow) return "Out({},{})".format(self.variable, self.borrow)
def __repr__(self): def __repr__(self):
return "Out(%s,%s)" % (self.variable, self.borrow) return "Out({},{})".format(self.variable, self.borrow)
Out = SymbolicOutput Out = SymbolicOutput
...@@ -6,8 +6,6 @@ WRITEME ...@@ -6,8 +6,6 @@ WRITEME
import logging import logging
import warnings import warnings
from six import string_types
import theano import theano
import theano.gof.vm import theano.gof.vm
from theano import config, gof from theano import config, gof
...@@ -132,7 +130,7 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -132,7 +130,7 @@ class AddDestroyHandler(gof.Optimizer):
) )
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super(AddDestroyHandler, self).add_requirements(fgraph) super().add_requirements(fgraph)
fgraph.attach_feature(gof.DestroyHandler()) fgraph.attach_feature(gof.DestroyHandler())
...@@ -145,7 +143,7 @@ class AddFeatureOptimizer(gof.Optimizer): ...@@ -145,7 +143,7 @@ class AddFeatureOptimizer(gof.Optimizer):
self.feature = feature self.feature = feature
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super(AddFeatureOptimizer, self).add_requirements(fgraph) super().add_requirements(fgraph)
fgraph.attach_feature(self.feature) fgraph.attach_feature(self.feature)
...@@ -259,7 +257,7 @@ optdb.register("CheckStackTrace", gof.CheckStackTraceOptimization(), -1, *_tags) ...@@ -259,7 +257,7 @@ optdb.register("CheckStackTrace", gof.CheckStackTraceOptimization(), -1, *_tags)
del _tags del _tags
class Mode(object): class Mode:
""" """
The Mode represents a way to optimize and then link a computation graph. The Mode represents a way to optimize and then link a computation graph.
...@@ -303,10 +301,10 @@ class Mode(object): ...@@ -303,10 +301,10 @@ class Mode(object):
linker, optimizer = state linker, optimizer = state
self.provided_linker = linker self.provided_linker = linker
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
if isinstance(linker, string_types) or linker is None: if isinstance(linker, str) or linker is None:
linker = predefined_linkers[linker] linker = predefined_linkers[linker]
self.linker = linker self.linker = linker
if isinstance(optimizer, string_types) or optimizer is None: if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
if isinstance(optimizer, gof.Query): if isinstance(optimizer, gof.Query):
self.provided_optimizer = optimizer self.provided_optimizer = optimizer
...@@ -315,7 +313,7 @@ class Mode(object): ...@@ -315,7 +313,7 @@ class Mode(object):
self.fn_time = 0 self.fn_time = 0
def __str__(self): def __str__(self):
return "%s(linker = %s, optimizer = %s)" % ( return "{}(linker = {}, optimizer = {})".format(
self.__class__.__name__, self.__class__.__name__,
self.provided_linker, self.provided_linker,
self.provided_optimizer, self.provided_optimizer,
...@@ -330,9 +328,9 @@ class Mode(object): ...@@ -330,9 +328,9 @@ class Mode(object):
optimizer = property(__get_optimizer) optimizer = property(__get_optimizer)
def get_linker_optimizer(self, linker, optimizer): def get_linker_optimizer(self, linker, optimizer):
if isinstance(linker, string_types) or linker is None: if isinstance(linker, str) or linker is None:
linker = predefined_linkers[linker] linker = predefined_linkers[linker]
if isinstance(optimizer, string_types) or optimizer is None: if isinstance(optimizer, str) or optimizer is None:
optimizer = predefined_optimizers[optimizer] optimizer = predefined_optimizers[optimizer]
return (linker, optimizer) return (linker, optimizer)
...@@ -432,7 +430,7 @@ def get_mode(orig_string): ...@@ -432,7 +430,7 @@ def get_mode(orig_string):
string = config.mode string = config.mode
else: else:
string = orig_string string = orig_string
if not isinstance(string, string_types): if not isinstance(string, str):
return string # it is hopefully already a mode... return string # it is hopefully already a mode...
global instantiated_default_mode global instantiated_default_mode
......
...@@ -52,17 +52,17 @@ class MonitorMode(Mode): ...@@ -52,17 +52,17 @@ class MonitorMode(Mode):
linker, linker,
) )
super(MonitorMode, self).__init__(wrap_linker, optimizer=optimizer) super().__init__(wrap_linker, optimizer=optimizer)
def __getstate__(self): def __getstate__(self):
lnk, opt = super(MonitorMode, self).__getstate__() lnk, opt = super().__getstate__()
return (lnk, opt, self.pre_func, self.post_func) return (lnk, opt, self.pre_func, self.post_func)
def __setstate__(self, state): def __setstate__(self, state):
lnk, opt, pre_func, post_func = state lnk, opt, pre_func, post_func = state
self.pre_func = pre_func self.pre_func = pre_func
self.post_func = post_func self.post_func = post_func
super(MonitorMode, self).__setstate__((lnk, opt)) super().__setstate__((lnk, opt))
def eval(self, i, node, fn): def eval(self, i, node, fn):
""" """
......
...@@ -295,6 +295,4 @@ class NanGuardMode(Mode): ...@@ -295,6 +295,4 @@ class NanGuardMode(Mode):
wrap_linker = theano.gof.vm.VM_Linker( wrap_linker = theano.gof.vm.VM_Linker(
callback=nan_check, callback_input=nan_check_input callback=nan_check, callback_input=nan_check_input
) )
super(NanGuardMode, self).__init__( super().__init__(wrap_linker, optimizer=self.provided_optimizer)
wrap_linker, optimizer=self.provided_optimizer
)
...@@ -11,7 +11,6 @@ from collections import OrderedDict ...@@ -11,7 +11,6 @@ from collections import OrderedDict
import numpy as np import numpy as np
import six.moves.cPickle as pickle import six.moves.cPickle as pickle
from six import integer_types
import theano import theano
from theano.gof import Apply, Op, ParamsType, Variable from theano.gof import Apply, Op, ParamsType, Variable
...@@ -71,7 +70,7 @@ class ViewOp(Op): ...@@ -71,7 +70,7 @@ class ViewOp(Op):
return code % locals() return code % locals()
# Else, no C code # Else, no C code
return super(ViewOp, self).c_code(node, nodename, inp, out, sub) return super().c_code(node, nodename, inp, out, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
version = [] version = []
...@@ -206,7 +205,7 @@ class DeepCopyOp(Op): ...@@ -206,7 +205,7 @@ class DeepCopyOp(Op):
return code % locals() return code % locals()
# Else, no C code # Else, no C code
return super(DeepCopyOp, self).c_code(node, name, inames, onames, sub) return super().c_code(node, name, inames, onames, sub)
deep_copy_op = DeepCopyOp() deep_copy_op = DeepCopyOp()
...@@ -296,7 +295,7 @@ class Shape(Op): ...@@ -296,7 +295,7 @@ class Shape(Op):
return code % locals() return code % locals()
# Else, no C code # Else, no C code
return super(Shape, self).c_code(node, name, inames, onames, sub) return super().c_code(node, name, inames, onames, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
version = [] version = []
...@@ -423,7 +422,7 @@ class Shape_i(Op): ...@@ -423,7 +422,7 @@ class Shape_i(Op):
return (check_input + code) % locals() return (check_input + code) % locals()
# Else, no C code # Else, no C code
return super(Shape_i, self).c_code(node, name, inames, onames, sub) return super().c_code(node, name, inames, onames, sub)
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
return [()] return [()]
...@@ -583,7 +582,7 @@ class FromFunctionOp(Op): ...@@ -583,7 +582,7 @@ class FromFunctionOp(Op):
obj = load_back(mod, name) obj = load_back(mod, name)
except (ImportError, KeyError, AttributeError): except (ImportError, KeyError, AttributeError):
raise pickle.PicklingError( raise pickle.PicklingError(
"Can't pickle as_op(), not found as %s.%s" % (mod, name) "Can't pickle as_op(), not found as {}.{}".format(mod, name)
) )
else: else:
if obj is not self: if obj is not self:
...@@ -699,7 +698,7 @@ class Rebroadcast(Op): ...@@ -699,7 +698,7 @@ class Rebroadcast(Op):
items = sorted(axis) items = sorted(axis)
self.axis = OrderedDict(items) self.axis = OrderedDict(items)
for axis, broad in self.axis.items(): for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, integer_types)): if not isinstance(axis, (np.integer, int)):
raise TypeError( raise TypeError(
"Rebroadcast needs integer axes. " "Got {}".format(axis) "Rebroadcast needs integer axes. " "Got {}".format(axis)
) )
...@@ -723,7 +722,7 @@ class Rebroadcast(Op): ...@@ -723,7 +722,7 @@ class Rebroadcast(Op):
broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))] broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))]
for k, v in self.axis.items(): for k, v in self.axis.items():
broadcast_pattern[k] = str(int(v)) broadcast_pattern[k] = str(int(v))
return "%s{%s}" % (self.__class__.__name__, ",".join(broadcast_pattern)) return "{}{{{}}}".format(self.__class__.__name__, ",".join(broadcast_pattern))
def make_node(self, x): def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())): if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
...@@ -797,7 +796,7 @@ class Rebroadcast(Op): ...@@ -797,7 +796,7 @@ class Rebroadcast(Op):
""" """
% locals() % locals()
) )
return super(Rebroadcast, self).c_code(node, nodename, inp, out, sub) return super().c_code(node, nodename, inp, out, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
version = [] version = []
...@@ -929,7 +928,7 @@ class SpecifyShape(Op): ...@@ -929,7 +928,7 @@ class SpecifyShape(Op):
_, _, support_code = self.c_code_and_version[itype] _, _, support_code = self.c_code_and_version[itype]
if support_code: if support_code:
return support_code return support_code
return super(SpecifyShape, self).c_support_code_apply(node, name) return super().c_support_code_apply(node, name)
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
iname, shape = inames iname, shape = inames
...@@ -941,7 +940,7 @@ class SpecifyShape(Op): ...@@ -941,7 +940,7 @@ class SpecifyShape(Op):
code, version, _ = self.c_code_and_version[itype] code, version, _ = self.c_code_and_version[itype]
return code % locals() return code % locals()
return super(SpecifyShape, self).c_code(node, node, inames, onames, sub) return super().c_code(node, node, inames, onames, sub)
def c_code_cache_version(self): def c_code_cache_version(self):
version = [] version = []
......
...@@ -300,7 +300,7 @@ class Param(In): ...@@ -300,7 +300,7 @@ class Param(In):
" by theano.In(value=N)", " by theano.In(value=N)",
stacklevel=2, stacklevel=2,
) )
super(Param, self).__init__( super().__init__(
variable, variable,
name=name, name=name,
value=default, value=default,
...@@ -447,12 +447,10 @@ def pfunc( ...@@ -447,12 +447,10 @@ def pfunc(
if v in in_variables[(i + 1) :]: if v in in_variables[(i + 1) :]:
dup_v_i = in_variables.index(v, (i + 1)) dup_v_i = in_variables.index(v, (i + 1))
raise UnusedInputError( raise UnusedInputError(
( "Variable %s is used twice in inputs to theano.function, "
"Variable %s is used twice in inputs to theano.function, " "at indices %i and %i. This would result in values "
"at indices %i and %i. This would result in values " "provided for it being ignored. Please do not duplicate "
"provided for it being ignored. Please do not duplicate " "variables in the inputs list." % (v, i, dup_v_i)
"variables in the inputs list." % (v, i, dup_v_i)
)
) )
# Check that we are not using `givens` to replace input variables, because # Check that we are not using `givens` to replace input variables, because
......
...@@ -181,7 +181,7 @@ def register_profiler_printer(fct): ...@@ -181,7 +181,7 @@ def register_profiler_printer(fct):
return fct return fct
class ProfileStats(object): class ProfileStats:
""" """
Object to store runtime and memory profiling information for all of Object to store runtime and memory profiling information for all of
...@@ -851,7 +851,7 @@ class ProfileStats(object): ...@@ -851,7 +851,7 @@ class ProfileStats(object):
for node, t in sorted( for node, t in sorted(
self.linker_make_thunk_time.items(), key=operator.itemgetter(1) self.linker_make_thunk_time.items(), key=operator.itemgetter(1)
)[::-1][:5]: )[::-1][:5]:
print(" Node %s time %es" % (node, t), file=file) print(" Node {} time {:e}s".format(node, t), file=file)
print("", file=file) print("", file=file)
# The validation time is a subset of optimizer_time # The validation time is a subset of optimizer_time
...@@ -1071,7 +1071,7 @@ class ProfileStats(object): ...@@ -1071,7 +1071,7 @@ class ProfileStats(object):
mem_bound = np.inf mem_bound = np.inf
# This take only the inputs/outputs dependencies. # This take only the inputs/outputs dependencies.
dependencies = fgraph.profile.dependencies dependencies = fgraph.profile.dependencies
done_set = set([]) done_set = set()
done_dict = {} done_dict = {}
# Initial compute_map which is used to check if a node is valid # Initial compute_map which is used to check if a node is valid
...@@ -1451,7 +1451,10 @@ class ProfileStats(object): ...@@ -1451,7 +1451,10 @@ class ProfileStats(object):
else: else:
size = "%10s" % "Unknown" size = "%10s" % "Unknown"
print(" %s %s %s %s" % (size, shapes, " ".join(code), node), file=file) print(
" {} {} {} {}".format(size, shapes, " ".join(code), node),
file=file,
)
sum_remaining = sum(size for _, size in items[N:]) sum_remaining = sum(size for _, size in items[N:])
size_sum_dense = sum(node_mem.values()) size_sum_dense = sum(node_mem.values())
...@@ -1499,7 +1502,7 @@ class ProfileStats(object): ...@@ -1499,7 +1502,7 @@ class ProfileStats(object):
file=file, file=file,
) )
if config.profiling.debugprint: if config.profiling.debugprint:
fcts = set([n.fgraph for n in self.apply_time.keys()]) fcts = {n.fgraph for n in self.apply_time.keys()}
theano.printing.debugprint(fcts, print_type=True) theano.printing.debugprint(fcts, print_type=True)
if self.variable_shape or self.variable_strides: if self.variable_shape or self.variable_strides:
self.summary_memory(file, n_apply_to_print) self.summary_memory(file, n_apply_to_print)
...@@ -1686,10 +1689,7 @@ class ProfileStats(object): ...@@ -1686,10 +1689,7 @@ class ProfileStats(object):
# tip 6 # tip 6
for a in self.apply_time: for a in self.apply_time:
node = a node = a
if ( if isinstance(node.op, T.Dot) and len({i.dtype for i in node.inputs}) != 1:
isinstance(node.op, T.Dot)
and len(set(i.dtype for i in node.inputs)) != 1
):
print( print(
" - You have a dot operation that has different dtype " " - You have a dot operation that has different dtype "
" for inputs (%s). Make sure that the inputs have same " " for inputs (%s). Make sure that the inputs have same "
...@@ -1742,7 +1742,7 @@ class ScanProfileStats(ProfileStats): ...@@ -1742,7 +1742,7 @@ class ScanProfileStats(ProfileStats):
call_time = 0.0 call_time = 0.0
def __init__(self, atexit_print=True, name=None, **kwargs): def __init__(self, atexit_print=True, name=None, **kwargs):
super(ScanProfileStats, self).__init__(atexit_print, **kwargs) super().__init__(atexit_print, **kwargs)
self.name = name self.name = name
def summary_globals(self, file): def summary_globals(self, file):
......
...@@ -67,9 +67,7 @@ class SharedVariable(Variable): ...@@ -67,9 +67,7 @@ class SharedVariable(Variable):
# or the "no_default_updates" list passed to "function" contains it. # or the "no_default_updates" list passed to "function" contains it.
def __init__(self, name, type, value, strict, allow_downcast=None, container=None): def __init__(self, name, type, value, strict, allow_downcast=None, container=None):
super(SharedVariable, self).__init__( super().__init__(type=type, name=name, owner=None, index=None)
type=type, name=name, owner=None, index=None
)
if container is not None: if container is not None:
self.container = container self.container = container
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论