提交 d69eaabd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use context managers with open

上级 b4912d97
......@@ -33,14 +33,13 @@ def cleanup():
"""
compiledir = config.compiledir
for directory in os.listdir(compiledir):
file = None
try:
try:
filename = os.path.join(compiledir, directory, "key.pkl")
file = open(filename, "rb")
# print file
with open(filename, "rb") as file:
try:
keydata = pickle.load(file)
for key in list(keydata.keys):
have_npy_abi_version = False
have_c_compiler = False
......@@ -91,9 +90,6 @@ def cleanup():
f"Could not clean up this directory: '{directory}'. To complete "
"the clean-up, please remove it manually."
)
finally:
if file is not None:
file.close()
def print_title(title, overline="", underline=""):
......
......@@ -15,6 +15,7 @@ import operator
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
import numpy as np
......@@ -25,6 +26,17 @@ from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies
@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
yield sys.stdout
elif filename == "<stderr>":
yield sys.stderr
else:
with open(filename, mode=mode) as f:
yield f
logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time()
......@@ -37,19 +49,18 @@ _atexit_registered = False
def _atexit_print_fn():
"""
Print ProfileStat objects in _atexit_print_list to _atexit_print_file.
"""
"""Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`."""
if config.profile:
to_sum = []
if config.profiling__destination == "stderr":
destination_file = sys.stderr
destination_file = "<stderr>"
elif config.profiling__destination == "stdout":
destination_file = sys.stdout
destination_file = "<stdout>"
else:
destination_file = open(config.profiling__destination, "w")
destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"):
# Reverse sort in the order of compile+exec time
for ps in sorted(
......@@ -139,12 +150,13 @@ def print_global_stats():
"""
if config.profiling__destination == "stderr":
destination_file = sys.stderr
destination_file = "<stderr>"
elif config.profiling__destination == "stdout":
destination_file = sys.stdout
destination_file = "<stdout>"
else:
destination_file = open(config.profiling__destination, "w")
destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"):
print("=" * 50, file=destination_file)
print(
(
......
......@@ -1300,7 +1300,8 @@ def _filter_compiledir(path):
init_file = os.path.join(path, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
with open(init_file, "w"):
pass
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
......
......@@ -1008,7 +1008,7 @@ class ModuleCache:
entry = key_data.get_entry()
try:
# Test to see that the file is [present and] readable.
open(entry).close()
with open(entry):
gone = False
except OSError:
gone = True
......@@ -1505,7 +1505,7 @@ class ModuleCache:
if filename.startswith("tmp"):
try:
fname = os.path.join(self.dirname, filename, "key.pkl")
open(fname).close()
with open(fname):
has_key = True
except OSError:
has_key = False
......@@ -1599,7 +1599,8 @@ def _rmtree(
if os.path.exists(parent):
try:
_logger.info(f'placing "delete.me" in {parent}')
open(os.path.join(parent, "delete.me"), "w").close()
with open(os.path.join(parent, "delete.me"), "w"):
pass
except Exception as ee:
_logger.warning(
f"Failed to remove or mark cache directory {parent} for removal {ee}"
......@@ -2641,7 +2642,8 @@ class GCC_compiler(Compiler):
if py_module:
# touch the __init__ file
open(os.path.join(location, "__init__.py"), "w").close()
with open(os.path.join(location, "__init__.py"), "w"):
pass
assert os.path.isfile(lib_filename)
return dlimport(lib_filename)
......
......@@ -96,7 +96,8 @@ try:
assert e.errno == errno.EEXIST
assert os.path.exists(location), location
if not os.path.exists(os.path.join(location, "__init__.py")):
open(os.path.join(location, "__init__.py"), "w").close()
with open(os.path.join(location, "__init__.py"), "w"):
pass
try:
from cutils_ext.cutils_ext import * # noqa
......
......@@ -59,7 +59,8 @@ try:
init_file = os.path.join(location, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
with open(init_file, "w"):
pass
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
......@@ -126,10 +127,12 @@ except ImportError:
"code generation."
)
raise ImportError("The file lazylinker_c.c is not available.")
code = open(cfile).read()
with open(cfile) as f:
code = f.read()
loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc):
try:
os.mkdir(loc)
except OSError as e:
......@@ -140,14 +143,17 @@ except ImportError:
GCC_compiler.compile_str(dirname, code, location=loc, preargs=args)
# Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py")
with open(init_py, "w") as f:
f.write(f"_version = {version}\n")
# If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not
# reload the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc")
if os.path.isfile(init_pyc):
os.remove(init_pyc)
try_import()
try_reload()
from lazylinker_ext import lazylinker_ext as lazy_c
......
......@@ -42,21 +42,19 @@ Pickler = pickle.Pickler
class StripPickler(Pickler):
"""
Subclass of Pickler that strips unnecessary attributes from Aesara objects.
.. versionadded:: 0.8
"""Subclass of `Pickler` that strips unnecessary attributes from Aesara objects.
Example of use::
Example
-------
fn_args = dict(inputs=inputs,
outputs=outputs,
updates=updates)
dest_pkl = 'my_test.pkl'
f = open(dest_pkl, 'wb')
with open(dest_pkl, 'wb') as f:
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(fn_args)
f.close()
"""
def __init__(self, file, protocol=0, extra_tag_to_remove=None):
......
......@@ -118,7 +118,8 @@ For instance, you can define functions along the lines of:
def __setstate__(self, d):
self.__dict__.update(d)
self.training_set = cPickle.load(open(self.training_set_file, 'rb'))
with open(self.training_set_file, 'rb') as f:
self.training_set = cPickle.load(f)
Robust Serialization
......
......@@ -66,6 +66,6 @@ class TestStripPickler:
with open("test.pkl", "wb") as f:
m = matrix()
dest_pkl = "my_test.pkl"
f = open(dest_pkl, "wb")
with open(dest_pkl, "wb") as f:
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(m)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论