提交 d69eaabd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use context managers with open

上级 b4912d97
...@@ -33,14 +33,13 @@ def cleanup(): ...@@ -33,14 +33,13 @@ def cleanup():
""" """
compiledir = config.compiledir compiledir = config.compiledir
for directory in os.listdir(compiledir): for directory in os.listdir(compiledir):
file = None
try:
try: try:
filename = os.path.join(compiledir, directory, "key.pkl") filename = os.path.join(compiledir, directory, "key.pkl")
file = open(filename, "rb")
# print file # print file
with open(filename, "rb") as file:
try: try:
keydata = pickle.load(file) keydata = pickle.load(file)
for key in list(keydata.keys): for key in list(keydata.keys):
have_npy_abi_version = False have_npy_abi_version = False
have_c_compiler = False have_c_compiler = False
...@@ -91,9 +90,6 @@ def cleanup(): ...@@ -91,9 +90,6 @@ def cleanup():
f"Could not clean up this directory: '{directory}'. To complete " f"Could not clean up this directory: '{directory}'. To complete "
"the clean-up, please remove it manually." "the clean-up, please remove it manually."
) )
finally:
if file is not None:
file.close()
def print_title(title, overline="", underline=""): def print_title(title, overline="", underline=""):
......
...@@ -15,6 +15,7 @@ import operator ...@@ -15,6 +15,7 @@ import operator
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List from typing import Dict, List
import numpy as np import numpy as np
...@@ -25,6 +26,17 @@ from aesara.graph.basic import Constant, Variable ...@@ -25,6 +26,17 @@ from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies from aesara.link.utils import get_destroy_dependencies
@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
yield sys.stdout
elif filename == "<stderr>":
yield sys.stderr
else:
with open(filename, mode=mode) as f:
yield f
logger = logging.getLogger("aesara.compile.profiling") logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time() aesara_imported_time = time.time()
...@@ -37,19 +49,18 @@ _atexit_registered = False ...@@ -37,19 +49,18 @@ _atexit_registered = False
def _atexit_print_fn(): def _atexit_print_fn():
""" """Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`."""
Print ProfileStat objects in _atexit_print_list to _atexit_print_file.
"""
if config.profile: if config.profile:
to_sum = [] to_sum = []
if config.profiling__destination == "stderr": if config.profiling__destination == "stderr":
destination_file = sys.stderr destination_file = "<stderr>"
elif config.profiling__destination == "stdout": elif config.profiling__destination == "stdout":
destination_file = sys.stdout destination_file = "<stdout>"
else: else:
destination_file = open(config.profiling__destination, "w") destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"):
# Reverse sort in the order of compile+exec time # Reverse sort in the order of compile+exec time
for ps in sorted( for ps in sorted(
...@@ -139,12 +150,13 @@ def print_global_stats(): ...@@ -139,12 +150,13 @@ def print_global_stats():
""" """
if config.profiling__destination == "stderr": if config.profiling__destination == "stderr":
destination_file = sys.stderr destination_file = "<stderr>"
elif config.profiling__destination == "stdout": elif config.profiling__destination == "stdout":
destination_file = sys.stdout destination_file = "<stdout>"
else: else:
destination_file = open(config.profiling__destination, "w") destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"):
print("=" * 50, file=destination_file) print("=" * 50, file=destination_file)
print( print(
( (
......
...@@ -1300,7 +1300,8 @@ def _filter_compiledir(path): ...@@ -1300,7 +1300,8 @@ def _filter_compiledir(path):
init_file = os.path.join(path, "__init__.py") init_file = os.path.join(path, "__init__.py")
if not os.path.exists(init_file): if not os.path.exists(init_file):
try: try:
open(init_file, "w").close() with open(init_file, "w"):
pass
except OSError as e: except OSError as e:
if os.path.exists(init_file): if os.path.exists(init_file):
pass # has already been created pass # has already been created
......
...@@ -1008,7 +1008,7 @@ class ModuleCache: ...@@ -1008,7 +1008,7 @@ class ModuleCache:
entry = key_data.get_entry() entry = key_data.get_entry()
try: try:
# Test to see that the file is [present and] readable. # Test to see that the file is [present and] readable.
open(entry).close() with open(entry):
gone = False gone = False
except OSError: except OSError:
gone = True gone = True
...@@ -1505,7 +1505,7 @@ class ModuleCache: ...@@ -1505,7 +1505,7 @@ class ModuleCache:
if filename.startswith("tmp"): if filename.startswith("tmp"):
try: try:
fname = os.path.join(self.dirname, filename, "key.pkl") fname = os.path.join(self.dirname, filename, "key.pkl")
open(fname).close() with open(fname):
has_key = True has_key = True
except OSError: except OSError:
has_key = False has_key = False
...@@ -1599,7 +1599,8 @@ def _rmtree( ...@@ -1599,7 +1599,8 @@ def _rmtree(
if os.path.exists(parent): if os.path.exists(parent):
try: try:
_logger.info(f'placing "delete.me" in {parent}') _logger.info(f'placing "delete.me" in {parent}')
open(os.path.join(parent, "delete.me"), "w").close() with open(os.path.join(parent, "delete.me"), "w"):
pass
except Exception as ee: except Exception as ee:
_logger.warning( _logger.warning(
f"Failed to remove or mark cache directory {parent} for removal {ee}" f"Failed to remove or mark cache directory {parent} for removal {ee}"
...@@ -2641,7 +2642,8 @@ class GCC_compiler(Compiler): ...@@ -2641,7 +2642,8 @@ class GCC_compiler(Compiler):
if py_module: if py_module:
# touch the __init__ file # touch the __init__ file
open(os.path.join(location, "__init__.py"), "w").close() with open(os.path.join(location, "__init__.py"), "w"):
pass
assert os.path.isfile(lib_filename) assert os.path.isfile(lib_filename)
return dlimport(lib_filename) return dlimport(lib_filename)
......
...@@ -96,7 +96,8 @@ try: ...@@ -96,7 +96,8 @@ try:
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.exists(location), location assert os.path.exists(location), location
if not os.path.exists(os.path.join(location, "__init__.py")): if not os.path.exists(os.path.join(location, "__init__.py")):
open(os.path.join(location, "__init__.py"), "w").close() with open(os.path.join(location, "__init__.py"), "w"):
pass
try: try:
from cutils_ext.cutils_ext import * # noqa from cutils_ext.cutils_ext import * # noqa
......
...@@ -59,7 +59,8 @@ try: ...@@ -59,7 +59,8 @@ try:
init_file = os.path.join(location, "__init__.py") init_file = os.path.join(location, "__init__.py")
if not os.path.exists(init_file): if not os.path.exists(init_file):
try: try:
open(init_file, "w").close() with open(init_file, "w"):
pass
except OSError as e: except OSError as e:
if os.path.exists(init_file): if os.path.exists(init_file):
pass # has already been created pass # has already been created
...@@ -126,10 +127,12 @@ except ImportError: ...@@ -126,10 +127,12 @@ except ImportError:
"code generation." "code generation."
) )
raise ImportError("The file lazylinker_c.c is not available.") raise ImportError("The file lazylinker_c.c is not available.")
code = open(cfile).read()
with open(cfile) as f:
code = f.read()
loc = os.path.join(config.compiledir, dirname) loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc): if not os.path.exists(loc):
try: try:
os.mkdir(loc) os.mkdir(loc)
except OSError as e: except OSError as e:
...@@ -140,14 +143,17 @@ except ImportError: ...@@ -140,14 +143,17 @@ except ImportError:
GCC_compiler.compile_str(dirname, code, location=loc, preargs=args) GCC_compiler.compile_str(dirname, code, location=loc, preargs=args)
# Save version into the __init__.py file. # Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py") init_py = os.path.join(loc, "__init__.py")
with open(init_py, "w") as f: with open(init_py, "w") as f:
f.write(f"_version = {version}\n") f.write(f"_version = {version}\n")
# If we just compiled the module for the first time, then it was # If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not # imported at the same time: we need to make sure we do not
# reload the now outdated __init__.pyc below. # reload the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc") init_pyc = os.path.join(loc, "__init__.pyc")
if os.path.isfile(init_pyc): if os.path.isfile(init_pyc):
os.remove(init_pyc) os.remove(init_pyc)
try_import() try_import()
try_reload() try_reload()
from lazylinker_ext import lazylinker_ext as lazy_c from lazylinker_ext import lazylinker_ext as lazy_c
......
...@@ -42,21 +42,19 @@ Pickler = pickle.Pickler ...@@ -42,21 +42,19 @@ Pickler = pickle.Pickler
class StripPickler(Pickler): class StripPickler(Pickler):
""" """Subclass of `Pickler` that strips unnecessary attributes from Aesara objects.
Subclass of Pickler that strips unnecessary attributes from Aesara objects.
.. versionadded:: 0.8
Example of use:: Example
-------
fn_args = dict(inputs=inputs, fn_args = dict(inputs=inputs,
outputs=outputs, outputs=outputs,
updates=updates) updates=updates)
dest_pkl = 'my_test.pkl' dest_pkl = 'my_test.pkl'
f = open(dest_pkl, 'wb') with open(dest_pkl, 'wb') as f:
strip_pickler = StripPickler(f, protocol=-1) strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(fn_args) strip_pickler.dump(fn_args)
f.close()
""" """
def __init__(self, file, protocol=0, extra_tag_to_remove=None): def __init__(self, file, protocol=0, extra_tag_to_remove=None):
......
...@@ -118,7 +118,8 @@ For instance, you can define functions along the lines of: ...@@ -118,7 +118,8 @@ For instance, you can define functions along the lines of:
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self.training_set = cPickle.load(open(self.training_set_file, 'rb')) with open(self.training_set_file, 'rb') as f:
self.training_set = cPickle.load(f)
Robust Serialization Robust Serialization
......
...@@ -66,6 +66,6 @@ class TestStripPickler: ...@@ -66,6 +66,6 @@ class TestStripPickler:
with open("test.pkl", "wb") as f: with open("test.pkl", "wb") as f:
m = matrix() m = matrix()
dest_pkl = "my_test.pkl" dest_pkl = "my_test.pkl"
f = open(dest_pkl, "wb") with open(dest_pkl, "wb") as f:
strip_pickler = StripPickler(f, protocol=-1) strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(m) strip_pickler.dump(m)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论