提交 d69eaabd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use context managers with open

上级 b4912d97
......@@ -33,14 +33,13 @@ def cleanup():
"""
compiledir = config.compiledir
for directory in os.listdir(compiledir):
file = None
try:
try:
filename = os.path.join(compiledir, directory, "key.pkl")
file = open(filename, "rb")
# print file
filename = os.path.join(compiledir, directory, "key.pkl")
# print file
with open(filename, "rb") as file:
try:
keydata = pickle.load(file)
for key in list(keydata.keys):
have_npy_abi_version = False
have_c_compiler = False
......@@ -86,14 +85,11 @@ def cleanup():
"the clean-up, please remove manually "
"the directory containing it."
)
except OSError:
_logger.error(
f"Could not clean up this directory: '{directory}'. To complete "
"the clean-up, please remove it manually."
)
finally:
if file is not None:
file.close()
except OSError:
_logger.error(
f"Could not clean up this directory: '{directory}'. To complete "
"the clean-up, please remove it manually."
)
def print_title(title, overline="", underline=""):
......
......@@ -15,6 +15,7 @@ import operator
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List
import numpy as np
......@@ -25,6 +26,17 @@ from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies
@contextmanager
def extended_open(filename, mode="r"):
if filename == "<stdout>":
yield sys.stdout
elif filename == "<stderr>":
yield sys.stderr
else:
with open(filename, mode=mode) as f:
yield f
logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time()
......@@ -37,93 +49,92 @@ _atexit_registered = False
def _atexit_print_fn():
"""
Print ProfileStat objects in _atexit_print_list to _atexit_print_file.
"""
"""Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`."""
if config.profile:
to_sum = []
if config.profiling__destination == "stderr":
destination_file = sys.stderr
destination_file = "<stderr>"
elif config.profiling__destination == "stdout":
destination_file = sys.stdout
destination_file = "<stdout>"
else:
destination_file = open(config.profiling__destination, "w")
# Reverse sort in the order of compile+exec time
for ps in sorted(
_atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
)[::-1]:
if (
ps.fct_callcount >= 1
or ps.compile_time > 1
or getattr(ps, "callcount", 0) > 1
):
ps.summary(
destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"):
# Reverse sort in the order of compile+exec time
for ps in sorted(
_atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
)[::-1]:
if (
ps.fct_callcount >= 1
or ps.compile_time > 1
or getattr(ps, "callcount", 0) > 1
):
ps.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)
if ps.show_sum:
to_sum.append(ps)
else:
# TODO print the name if there is one!
print("Skipping empty Profile")
if len(to_sum) > 1:
# Make a global profile
cum = copy.copy(to_sum[0])
msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
cum.message = msg
for ps in to_sum[1:]:
for attr in [
"compile_time",
"fct_call_time",
"fct_callcount",
"vm_call_time",
"optimizer_time",
"linker_time",
"validate_time",
"import_time",
"linker_node_make_thunks",
]:
setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))
# merge dictionary
for attr in [
"apply_time",
"apply_callcount",
"apply_cimpl",
"variable_shape",
"variable_strides",
"variable_offset",
"linker_make_thunk_time",
]:
cum_attr = getattr(cum, attr)
for key, val in getattr(ps, attr.items()):
assert key not in cum_attr, (key, cum_attr)
cum_attr[key] = val
if cum.optimizer_profile and ps.optimizer_profile:
try:
merge = cum.optimizer_profile[0].merge_profile(
cum.optimizer_profile[1], ps.optimizer_profile[1]
)
assert len(merge) == len(cum.optimizer_profile[1])
cum.optimizer_profile = (cum.optimizer_profile[0], merge)
except Exception as e:
print(e)
cum.optimizer_profile = None
else:
cum.optimizer_profile = None
cum.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)
if ps.show_sum:
to_sum.append(ps)
else:
# TODO print the name if there is one!
print("Skipping empty Profile")
if len(to_sum) > 1:
# Make a global profile
cum = copy.copy(to_sum[0])
msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
cum.message = msg
for ps in to_sum[1:]:
for attr in [
"compile_time",
"fct_call_time",
"fct_callcount",
"vm_call_time",
"optimizer_time",
"linker_time",
"validate_time",
"import_time",
"linker_node_make_thunks",
]:
setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))
# merge dictionary
for attr in [
"apply_time",
"apply_callcount",
"apply_cimpl",
"variable_shape",
"variable_strides",
"variable_offset",
"linker_make_thunk_time",
]:
cum_attr = getattr(cum, attr)
for key, val in getattr(ps, attr.items()):
assert key not in cum_attr, (key, cum_attr)
cum_attr[key] = val
if cum.optimizer_profile and ps.optimizer_profile:
try:
merge = cum.optimizer_profile[0].merge_profile(
cum.optimizer_profile[1], ps.optimizer_profile[1]
)
assert len(merge) == len(cum.optimizer_profile[1])
cum.optimizer_profile = (cum.optimizer_profile[0], merge)
except Exception as e:
print(e)
cum.optimizer_profile = None
else:
cum.optimizer_profile = None
cum.summary(
file=destination_file,
n_ops_to_print=config.profiling__n_ops,
n_apply_to_print=config.profiling__n_apply,
)
if config.print_global_stats:
print_global_stats()
......@@ -139,24 +150,25 @@ def print_global_stats():
"""
if config.profiling__destination == "stderr":
destination_file = sys.stderr
destination_file = "<stderr>"
elif config.profiling__destination == "stdout":
destination_file = sys.stdout
destination_file = "<stdout>"
else:
destination_file = open(config.profiling__destination, "w")
print("=" * 50, file=destination_file)
print(
(
"Global stats: ",
f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, "
f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, "
"Time spent compiling Aesara functions: "
f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ",
),
file=destination_file,
)
print("=" * 50, file=destination_file)
destination_file = config.profiling__destination
with extended_open(destination_file, mode="w"):
print("=" * 50, file=destination_file)
print(
(
"Global stats: ",
f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, "
f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, "
"Time spent compiling Aesara functions: "
f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ",
),
file=destination_file,
)
print("=" * 50, file=destination_file)
_profiler_printers = []
......
......@@ -1300,7 +1300,8 @@ def _filter_compiledir(path):
init_file = os.path.join(path, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
with open(init_file, "w"):
pass
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
......
......@@ -1008,8 +1008,8 @@ class ModuleCache:
entry = key_data.get_entry()
try:
# Test to see that the file is [present and] readable.
open(entry).close()
gone = False
with open(entry):
gone = False
except OSError:
gone = True
......@@ -1505,8 +1505,8 @@ class ModuleCache:
if filename.startswith("tmp"):
try:
fname = os.path.join(self.dirname, filename, "key.pkl")
open(fname).close()
has_key = True
with open(fname):
has_key = True
except OSError:
has_key = False
if not has_key:
......@@ -1599,7 +1599,8 @@ def _rmtree(
if os.path.exists(parent):
try:
_logger.info(f'placing "delete.me" in {parent}')
open(os.path.join(parent, "delete.me"), "w").close()
with open(os.path.join(parent, "delete.me"), "w"):
pass
except Exception as ee:
_logger.warning(
f"Failed to remove or mark cache directory {parent} for removal {ee}"
......@@ -2641,7 +2642,8 @@ class GCC_compiler(Compiler):
if py_module:
# touch the __init__ file
open(os.path.join(location, "__init__.py"), "w").close()
with open(os.path.join(location, "__init__.py"), "w"):
pass
assert os.path.isfile(lib_filename)
return dlimport(lib_filename)
......
......@@ -96,7 +96,8 @@ try:
assert e.errno == errno.EEXIST
assert os.path.exists(location), location
if not os.path.exists(os.path.join(location, "__init__.py")):
open(os.path.join(location, "__init__.py"), "w").close()
with open(os.path.join(location, "__init__.py"), "w"):
pass
try:
from cutils_ext.cutils_ext import * # noqa
......
......@@ -59,7 +59,8 @@ try:
init_file = os.path.join(location, "__init__.py")
if not os.path.exists(init_file):
try:
open(init_file, "w").close()
with open(init_file, "w"):
pass
except OSError as e:
if os.path.exists(init_file):
pass # has already been created
......@@ -126,10 +127,12 @@ except ImportError:
"code generation."
)
raise ImportError("The file lazylinker_c.c is not available.")
code = open(cfile).read()
with open(cfile) as f:
code = f.read()
loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc):
try:
os.mkdir(loc)
except OSError as e:
......@@ -140,14 +143,17 @@ except ImportError:
GCC_compiler.compile_str(dirname, code, location=loc, preargs=args)
# Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py")
with open(init_py, "w") as f:
f.write(f"_version = {version}\n")
# If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not
# reload the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc")
if os.path.isfile(init_pyc):
os.remove(init_pyc)
try_import()
try_reload()
from lazylinker_ext import lazylinker_ext as lazy_c
......
......@@ -42,21 +42,19 @@ Pickler = pickle.Pickler
class StripPickler(Pickler):
"""
Subclass of Pickler that strips unnecessary attributes from Aesara objects.
.. versionadded:: 0.8
"""Subclass of `Pickler` that strips unnecessary attributes from Aesara objects.
Example of use::
Example
-------
fn_args = dict(inputs=inputs,
outputs=outputs,
updates=updates)
dest_pkl = 'my_test.pkl'
f = open(dest_pkl, 'wb')
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(fn_args)
f.close()
with open(dest_pkl, 'wb') as f:
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(fn_args)
"""
def __init__(self, file, protocol=0, extra_tag_to_remove=None):
......
......@@ -118,7 +118,8 @@ For instance, you can define functions along the lines of:
def __setstate__(self, d):
self.__dict__.update(d)
self.training_set = cPickle.load(open(self.training_set_file, 'rb'))
with open(self.training_set_file, 'rb') as f:
self.training_set = cPickle.load(f)
Robust Serialization
......
......@@ -66,6 +66,6 @@ class TestStripPickler:
with open("test.pkl", "wb") as f:
m = matrix()
dest_pkl = "my_test.pkl"
f = open(dest_pkl, "wb")
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(m)
with open(dest_pkl, "wb") as f:
strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(m)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论