提交 2b658090 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Apply pyupgrade to theano.misc

上级 bcccce7a
...@@ -293,7 +293,7 @@ if __name__ == "__main__": ...@@ -293,7 +293,7 @@ if __name__ == "__main__":
print("(%d, %d) and (%d, %d)." % (M, N, N, K)) print("(%d, %d) and (%d, %d)." % (M, N, N, K))
print() print()
print("Total execution time: %.2fs on %s." % (t, impl)) print("Total execution time: {:.2f}s on {}.".format(t, impl))
print() print()
print( print(
"Try to run this script a few times. Experience shows that" "Try to run this script a few times. Experience shows that"
......
...@@ -27,7 +27,7 @@ for dir in dirs: ...@@ -27,7 +27,7 @@ for dir in dirs:
keys.setdefault(key, 0) keys.setdefault(key, 0)
keys[key] += 1 keys[key] += 1
del f del f
except IOError: except OSError:
# print dir, "don't have a key.pkl file" # print dir, "don't have a key.pkl file"
pass pass
try: try:
...@@ -41,9 +41,8 @@ for dir in dirs: ...@@ -41,9 +41,8 @@ for dir in dirs:
del mod del mod
del f del f
del path del path
except IOError: except OSError:
print(dir, "don't have a mod.{cpp,cu} file") print(dir, "don't have a mod.{cpp,cu} file")
pass
if DISPLAY_DUPLICATE_KEYS: if DISPLAY_DUPLICATE_KEYS:
for k, v in keys.items(): for k, v in keys.items():
......
...@@ -56,7 +56,7 @@ def main(dev1, dev2): ...@@ -56,7 +56,7 @@ def main(dev1, dev2):
t2 = time.time() t2 = time.time()
r = None r = None
print("one ctx async %f" % (t2 - t,)) print("one ctx async {:f}".format(t2 - t))
t = time.time() t = time.time()
r = f2.fn() r = f2.fn()
...@@ -64,7 +64,7 @@ def main(dev1, dev2): ...@@ -64,7 +64,7 @@ def main(dev1, dev2):
t2 = time.time() t2 = time.time()
r = None r = None
print("two ctx async %f" % (t2 - t,)) print("two ctx async {:f}".format(t2 - t))
t = time.time() t = time.time()
r = f3.fn() r = f3.fn()
...@@ -74,14 +74,14 @@ def main(dev1, dev2): ...@@ -74,14 +74,14 @@ def main(dev1, dev2):
t2 = time.time() t2 = time.time()
r = None r = None
print("two ctx, 2 fct async %f" % (t2 - t,)) print("two ctx, 2 fct async {:f}".format(t2 - t))
t = time.time() t = time.time()
r = f5.fn() r = f5.fn()
r2 = f6.fn() r2 = f6.fn()
t2 = time.time() t2 = time.time()
r = None r = None
print("two ctx, 2 fct with transfer %f" % (t2 - t,)) print("two ctx, 2 fct with transfer {:f}".format(t2 - t))
# Multi-thread version # Multi-thread version
class myThread(threading.Thread): class myThread(threading.Thread):
...@@ -110,7 +110,7 @@ def main(dev1, dev2): ...@@ -110,7 +110,7 @@ def main(dev1, dev2):
thread2.join() thread2.join()
t2 = time.time() t2 = time.time()
print("two ctx, 2 fct async, 2 threads %f" % (t2 - t,)) print("two ctx, 2 fct async, 2 threads {:f}".format(t2 - t))
thread1 = myThread("Thread-5", f5, False) thread1 = myThread("Thread-5", f5, False)
thread2 = myThread("Thread-6", f6, False) thread2 = myThread("Thread-6", f6, False)
...@@ -121,7 +121,7 @@ def main(dev1, dev2): ...@@ -121,7 +121,7 @@ def main(dev1, dev2):
thread2.join() thread2.join()
t2 = time.time() t2 = time.time()
print("two ctx, 2 fct with transfer, 2 threads %f" % (t2 - t,)) print("two ctx, 2 fct with transfer, 2 threads {:f}".format(t2 - t))
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -67,5 +67,5 @@ if __name__ == "__main__": ...@@ -67,5 +67,5 @@ if __name__ == "__main__":
(cheapTime, costlyTime) = ElemwiseOpTime(N=options.N, script=options.script) (cheapTime, costlyTime) = ElemwiseOpTime(N=options.N, script=options.script)
if options.script: if options.script:
sys.stdout.write("%2.9f %2.9f\n" % (cheapTime, costlyTime)) sys.stdout.write("{:2.9f} {:2.9f}\n".format(cheapTime, costlyTime))
sys.stdout.flush() sys.stdout.flush()
...@@ -36,7 +36,7 @@ class frozendict(Mapping): ...@@ -36,7 +36,7 @@ class frozendict(Mapping):
return len(self._dict) return len(self._dict)
def __repr__(self): def __repr__(self):
return "<%s %r>" % (self.__class__.__name__, self._dict) return "<{} {!r}>".format(self.__class__.__name__, self._dict)
def __hash__(self): def __hash__(self):
if self._hash is None: if self._hash is None:
......
...@@ -150,7 +150,7 @@ def is_python_file(filename): ...@@ -150,7 +150,7 @@ def is_python_file(filename):
def get_file_contents(filename, revision="tip"): def get_file_contents(filename, revision="tip"):
hg_out = run_mercurial_command("cat -r %s %s" % (revision, filename)) hg_out = run_mercurial_command("cat -r {} {}".format(revision, filename))
return hg_out return hg_out
...@@ -169,7 +169,7 @@ def save_diffs(diffs, filename): ...@@ -169,7 +169,7 @@ def save_diffs(diffs, filename):
def should_skip_commit(): def should_skip_commit():
if not os.path.exists(SKIP_WHITESPACE_CHECK_FILENAME): if not os.path.exists(SKIP_WHITESPACE_CHECK_FILENAME):
return False return False
with open(SKIP_WHITESPACE_CHECK_FILENAME, "r") as whitespace_check_file: with open(SKIP_WHITESPACE_CHECK_FILENAME) as whitespace_check_file:
whitespace_check_changeset = whitespace_check_file.read() whitespace_check_changeset = whitespace_check_file.read()
return whitespace_check_changeset == parent_commit() return whitespace_check_changeset == parent_commit()
...@@ -253,7 +253,8 @@ def main(argv=None): ...@@ -253,7 +253,8 @@ def main(argv=None):
parse_error = get_parse_error(code) parse_error = get_parse_error(code)
if parse_error is not None: if parse_error is not None:
print( print(
"*** %s has parse error: %s" % (filename, parse_error), file=sys.stderr "*** {} has parse error: {}".format(filename, parse_error),
file=sys.stderr,
) )
block_commit = True block_commit = True
else: else:
......
...@@ -120,8 +120,8 @@ def check(file): ...@@ -120,8 +120,8 @@ def check(file):
print("checking", file, "...", end=" ") print("checking", file, "...", end=" ")
try: try:
f = open(file) f = open(file)
except IOError as msg: except OSError as msg:
errprint("%s: I/O Error: %s" % (file, str(msg))) errprint("{}: I/O Error: {}".format(file, str(msg)))
return return
r = Reindenter(f) r = Reindenter(f)
......
...@@ -2,8 +2,6 @@ import types ...@@ -2,8 +2,6 @@ import types
import weakref import weakref
from collections.abc import MutableSet from collections.abc import MutableSet
from six import string_types
def check_deterministic(iterable): def check_deterministic(iterable):
# Most places where OrderedSet is used, theano interprets any exception # Most places where OrderedSet is used, theano interprets any exception
...@@ -14,7 +12,7 @@ def check_deterministic(iterable): ...@@ -14,7 +12,7 @@ def check_deterministic(iterable):
# theano to use exceptions correctly, so that this can be a TypeError. # theano to use exceptions correctly, so that this can be a TypeError.
if iterable is not None: if iterable is not None:
if not isinstance( if not isinstance(
iterable, (list, tuple, OrderedSet, types.GeneratorType, string_types) iterable, (list, tuple, OrderedSet, types.GeneratorType, str)
): ):
if len(iterable) > 1: if len(iterable) > 1:
# We need to accept length 1 size to allow unpickle in tests. # We need to accept length 1 size to allow unpickle in tests.
...@@ -45,7 +43,7 @@ def check_deterministic(iterable): ...@@ -45,7 +43,7 @@ def check_deterministic(iterable):
# {{{ http://code.activestate.com/recipes/576696/ (r5) # {{{ http://code.activestate.com/recipes/576696/ (r5)
class Link(object): class Link:
# This make that we need to use a different pickle protocol # This make that we need to use a different pickle protocol
# then the default. Othewise, there is pickling errors # then the default. Othewise, there is pickling errors
__slots__ = "prev", "next", "key", "__weakref__" __slots__ = "prev", "next", "key", "__weakref__"
...@@ -173,8 +171,8 @@ class OrderedSet(MutableSet): ...@@ -173,8 +171,8 @@ class OrderedSet(MutableSet):
def __repr__(self): def __repr__(self):
if not self: if not self:
return "%s()" % (self.__class__.__name__,) return "{}()".format(self.__class__.__name__)
return "%s(%r)" % (self.__class__.__name__, list(self)) return "{}({!r})".format(self.__class__.__name__, list(self))
def __eq__(self, other): def __eq__(self, other):
# Note that we implement only the comparison to another # Note that we implement only the comparison to another
......
...@@ -82,7 +82,7 @@ class StripPickler(Pickler): ...@@ -82,7 +82,7 @@ class StripPickler(Pickler):
return Pickler.save(self, obj) return Pickler.save(self, obj)
class PersistentNdarrayID(object): class PersistentNdarrayID:
"""Persist ndarrays in an object by saving them to a zip file. """Persist ndarrays in an object by saving them to a zip file.
:param zip_file: A zip file handle that the NumPy arrays will be saved to. :param zip_file: A zip file handle that the NumPy arrays will be saved to.
...@@ -104,7 +104,7 @@ class PersistentNdarrayID(object): ...@@ -104,7 +104,7 @@ class PersistentNdarrayID(object):
def _resolve_name(self, obj): def _resolve_name(self, obj):
"""Determine the name the object should be saved under.""" """Determine the name the object should be saved under."""
name = "array_{0}".format(self.count) name = "array_{}".format(self.count)
self.count += 1 self.count += 1
return name return name
...@@ -117,7 +117,7 @@ class PersistentNdarrayID(object): ...@@ -117,7 +117,7 @@ class PersistentNdarrayID(object):
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = "ndarray.{0}".format(name) self.seen[id(obj)] = "ndarray.{}".format(name)
return self.seen[id(obj)] return self.seen[id(obj)]
...@@ -139,9 +139,9 @@ class PersistentGpuArrayID(PersistentNdarrayID): ...@@ -139,9 +139,9 @@ class PersistentGpuArrayID(PersistentNdarrayID):
name = self._resolve_name(obj) name = self._resolve_name(obj)
zipadd(write_array, self.zip_file, name) zipadd(write_array, self.zip_file, name)
self.seen[id(obj)] = "gpuarray.{0}".format(name) self.seen[id(obj)] = "gpuarray.{}".format(name)
return self.seen[id(obj)] return self.seen[id(obj)]
return super(PersistentGpuArrayID, self).__call__(obj) return super().__call__(obj)
class PersistentSharedVariableID(PersistentGpuArrayID): class PersistentSharedVariableID(PersistentGpuArrayID):
...@@ -169,7 +169,7 @@ class PersistentSharedVariableID(PersistentGpuArrayID): ...@@ -169,7 +169,7 @@ class PersistentSharedVariableID(PersistentGpuArrayID):
""" """
def __init__(self, zip_file, allow_unnamed=True, allow_duplicates=True): def __init__(self, zip_file, allow_unnamed=True, allow_duplicates=True):
super(PersistentSharedVariableID, self).__init__(zip_file) super().__init__(zip_file)
self.name_counter = defaultdict(int) self.name_counter = defaultdict(int)
self.ndarray_names = {} self.ndarray_names = {}
self.allow_unnamed = allow_unnamed self.allow_unnamed = allow_unnamed
...@@ -184,11 +184,11 @@ class PersistentSharedVariableID(PersistentGpuArrayID): ...@@ -184,11 +184,11 @@ class PersistentSharedVariableID(PersistentGpuArrayID):
if not self.allow_duplicates: if not self.allow_duplicates:
raise ValueError( raise ValueError(
"multiple shared variables with the name " "multiple shared variables with the name "
"`{0}` found".format(name) "`{}` found".format(name)
) )
name = "{0}_{1}".format(name, count + 1) name = "{}_{}".format(name, count + 1)
return name return name
return super(PersistentSharedVariableID, self)._resolve_name(obj) return super()._resolve_name(obj)
def __call__(self, obj): def __call__(self, obj):
if isinstance(obj, SharedVariable): if isinstance(obj, SharedVariable):
...@@ -197,11 +197,11 @@ class PersistentSharedVariableID(PersistentGpuArrayID): ...@@ -197,11 +197,11 @@ class PersistentSharedVariableID(PersistentGpuArrayID):
ValueError("can't pickle shared variable with name `pkl`") ValueError("can't pickle shared variable with name `pkl`")
self.ndarray_names[id(obj.container.storage[0])] = obj.name self.ndarray_names[id(obj.container.storage[0])] = obj.name
elif not self.allow_unnamed: elif not self.allow_unnamed:
raise ValueError("unnamed shared variable, {0}".format(obj)) raise ValueError("unnamed shared variable, {}".format(obj))
return super(PersistentSharedVariableID, self).__call__(obj) return super().__call__(obj)
class PersistentNdarrayLoad(object): class PersistentNdarrayLoad:
"""Load NumPy arrays that were persisted to a zip file when pickling. """Load NumPy arrays that were persisted to a zip file when pickling.
:param zip_file: The zip file handle in which the NumPy arrays are saved. :param zip_file: The zip file handle in which the NumPy arrays are saved.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论