提交 d69eaabd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use context managers with open

上级 b4912d97
...@@ -33,14 +33,13 @@ def cleanup(): ...@@ -33,14 +33,13 @@ def cleanup():
""" """
compiledir = config.compiledir compiledir = config.compiledir
for directory in os.listdir(compiledir): for directory in os.listdir(compiledir):
file = None
try: try:
try: filename = os.path.join(compiledir, directory, "key.pkl")
filename = os.path.join(compiledir, directory, "key.pkl") # print file
file = open(filename, "rb") with open(filename, "rb") as file:
# print file
try: try:
keydata = pickle.load(file) keydata = pickle.load(file)
for key in list(keydata.keys): for key in list(keydata.keys):
have_npy_abi_version = False have_npy_abi_version = False
have_c_compiler = False have_c_compiler = False
...@@ -86,14 +85,11 @@ def cleanup(): ...@@ -86,14 +85,11 @@ def cleanup():
"the clean-up, please remove manually " "the clean-up, please remove manually "
"the directory containing it." "the directory containing it."
) )
except OSError: except OSError:
_logger.error( _logger.error(
f"Could not clean up this directory: '{directory}'. To complete " f"Could not clean up this directory: '{directory}'. To complete "
"the clean-up, please remove it manually." "the clean-up, please remove it manually."
) )
finally:
if file is not None:
file.close()
def print_title(title, overline="", underline=""): def print_title(title, overline="", underline=""):
......
...@@ -15,6 +15,7 @@ import operator ...@@ -15,6 +15,7 @@ import operator
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List from typing import Dict, List
import numpy as np import numpy as np
...@@ -25,6 +26,17 @@ from aesara.graph.basic import Constant, Variable ...@@ -25,6 +26,17 @@ from aesara.graph.basic import Constant, Variable
from aesara.link.utils import get_destroy_dependencies from aesara.link.utils import get_destroy_dependencies
@contextmanager
def extended_open(filename, mode="r"):
    """Open `filename`, treating ``"<stdout>"`` and ``"<stderr>"`` specially.

    The sentinel names yield the corresponding standard stream as-is (and
    never close it); any other name is opened with `open` and closed on exit.

    Parameters
    ----------
    filename : str
        A path, or one of the sentinels ``"<stdout>"`` / ``"<stderr>"``.
    mode : str
        Mode passed through to `open` (ignored for the standard streams).

    Yields
    ------
    A writable/readable file object.
    """
    std_streams = {"<stdout>": sys.stdout, "<stderr>": sys.stderr}
    stream = std_streams.get(filename)
    if stream is not None:
        # Standard streams are shared globals: yield them without closing.
        yield stream
    else:
        with open(filename, mode=mode) as f:
            yield f
logger = logging.getLogger("aesara.compile.profiling") logger = logging.getLogger("aesara.compile.profiling")
aesara_imported_time = time.time() aesara_imported_time = time.time()
...@@ -37,93 +49,92 @@ _atexit_registered = False ...@@ -37,93 +49,92 @@ _atexit_registered = False
def _atexit_print_fn():
    """Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`."""
    if config.profile:
        to_sum = []

        if config.profiling__destination == "stderr":
            destination_file = "<stderr>"
        elif config.profiling__destination == "stdout":
            destination_file = "<stdout>"
        else:
            destination_file = config.profiling__destination

        # FIX: bind the yielded file object; previously the `with` result was
        # discarded and the *string* `destination_file` was passed as `file=`,
        # which breaks as soon as the destination is a real path.
        with extended_open(destination_file, mode="w") as dest:
            # Reverse sort in the order of compile+exec time
            for ps in sorted(
                _atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time
            )[::-1]:
                if (
                    ps.fct_callcount >= 1
                    or ps.compile_time > 1
                    or getattr(ps, "callcount", 0) > 1
                ):
                    ps.summary(
                        file=dest,
                        n_ops_to_print=config.profiling__n_ops,
                        n_apply_to_print=config.profiling__n_apply,
                    )
                    if ps.show_sum:
                        to_sum.append(ps)
                else:
                    # TODO print the name if there is one!
                    print("Skipping empty Profile")

            if len(to_sum) > 1:
                # Make a global profile that aggregates all printed profiles.
                cum = copy.copy(to_sum[0])
                msg = f"Sum of all({len(to_sum)}) printed profiles at exit."
                cum.message = msg
                for ps in to_sum[1:]:
                    # Scalar timing/count attributes are summed.
                    for attr in [
                        "compile_time",
                        "fct_call_time",
                        "fct_callcount",
                        "vm_call_time",
                        "optimizer_time",
                        "linker_time",
                        "validate_time",
                        "import_time",
                        "linker_node_make_thunks",
                    ]:
                        setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr))

                    # Dictionary attributes are merged (keys must be disjoint).
                    for attr in [
                        "apply_time",
                        "apply_callcount",
                        "apply_cimpl",
                        "variable_shape",
                        "variable_strides",
                        "variable_offset",
                        "linker_make_thunk_time",
                    ]:
                        cum_attr = getattr(cum, attr)
                        # FIX: was `getattr(ps, attr.items())`, which calls
                        # `.items()` on the attribute *name* (always raises).
                        for key, val in getattr(ps, attr).items():
                            assert key not in cum_attr, (key, cum_attr)
                            cum_attr[key] = val

                    if cum.optimizer_profile and ps.optimizer_profile:
                        try:
                            merge = cum.optimizer_profile[0].merge_profile(
                                cum.optimizer_profile[1], ps.optimizer_profile[1]
                            )
                            assert len(merge) == len(cum.optimizer_profile[1])
                            cum.optimizer_profile = (cum.optimizer_profile[0], merge)
                        except Exception as e:
                            # Best-effort merge: report and drop on failure.
                            print(e)
                            cum.optimizer_profile = None
                    else:
                        cum.optimizer_profile = None

                cum.summary(
                    file=dest,
                    n_ops_to_print=config.profiling__n_ops,
                    n_apply_to_print=config.profiling__n_apply,
                )

    if config.print_global_stats:
        print_global_stats()
...@@ -139,24 +150,25 @@ def print_global_stats(): ...@@ -139,24 +150,25 @@ def print_global_stats():
""" """
if config.profiling__destination == "stderr": if config.profiling__destination == "stderr":
destination_file = sys.stderr destination_file = "<stderr>"
elif config.profiling__destination == "stdout": elif config.profiling__destination == "stdout":
destination_file = sys.stdout destination_file = "<stdout>"
else: else:
destination_file = open(config.profiling__destination, "w") destination_file = config.profiling__destination
print("=" * 50, file=destination_file) with extended_open(destination_file, mode="w"):
print( print("=" * 50, file=destination_file)
( print(
"Global stats: ", (
f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, " "Global stats: ",
f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, " f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, "
"Time spent compiling Aesara functions: " f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, "
f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ", "Time spent compiling Aesara functions: "
), f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ",
file=destination_file, ),
) file=destination_file,
print("=" * 50, file=destination_file) )
print("=" * 50, file=destination_file)
_profiler_printers = [] _profiler_printers = []
......
...@@ -1300,7 +1300,8 @@ def _filter_compiledir(path): ...@@ -1300,7 +1300,8 @@ def _filter_compiledir(path):
init_file = os.path.join(path, "__init__.py") init_file = os.path.join(path, "__init__.py")
if not os.path.exists(init_file): if not os.path.exists(init_file):
try: try:
open(init_file, "w").close() with open(init_file, "w"):
pass
except OSError as e: except OSError as e:
if os.path.exists(init_file): if os.path.exists(init_file):
pass # has already been created pass # has already been created
......
...@@ -1008,8 +1008,8 @@ class ModuleCache: ...@@ -1008,8 +1008,8 @@ class ModuleCache:
entry = key_data.get_entry() entry = key_data.get_entry()
try: try:
# Test to see that the file is [present and] readable. # Test to see that the file is [present and] readable.
open(entry).close() with open(entry):
gone = False gone = False
except OSError: except OSError:
gone = True gone = True
...@@ -1505,8 +1505,8 @@ class ModuleCache: ...@@ -1505,8 +1505,8 @@ class ModuleCache:
if filename.startswith("tmp"): if filename.startswith("tmp"):
try: try:
fname = os.path.join(self.dirname, filename, "key.pkl") fname = os.path.join(self.dirname, filename, "key.pkl")
open(fname).close() with open(fname):
has_key = True has_key = True
except OSError: except OSError:
has_key = False has_key = False
if not has_key: if not has_key:
...@@ -1599,7 +1599,8 @@ def _rmtree( ...@@ -1599,7 +1599,8 @@ def _rmtree(
if os.path.exists(parent): if os.path.exists(parent):
try: try:
_logger.info(f'placing "delete.me" in {parent}') _logger.info(f'placing "delete.me" in {parent}')
open(os.path.join(parent, "delete.me"), "w").close() with open(os.path.join(parent, "delete.me"), "w"):
pass
except Exception as ee: except Exception as ee:
_logger.warning( _logger.warning(
f"Failed to remove or mark cache directory {parent} for removal {ee}" f"Failed to remove or mark cache directory {parent} for removal {ee}"
...@@ -2641,7 +2642,8 @@ class GCC_compiler(Compiler): ...@@ -2641,7 +2642,8 @@ class GCC_compiler(Compiler):
if py_module: if py_module:
# touch the __init__ file # touch the __init__ file
open(os.path.join(location, "__init__.py"), "w").close() with open(os.path.join(location, "__init__.py"), "w"):
pass
assert os.path.isfile(lib_filename) assert os.path.isfile(lib_filename)
return dlimport(lib_filename) return dlimport(lib_filename)
......
...@@ -96,7 +96,8 @@ try: ...@@ -96,7 +96,8 @@ try:
assert e.errno == errno.EEXIST assert e.errno == errno.EEXIST
assert os.path.exists(location), location assert os.path.exists(location), location
if not os.path.exists(os.path.join(location, "__init__.py")): if not os.path.exists(os.path.join(location, "__init__.py")):
open(os.path.join(location, "__init__.py"), "w").close() with open(os.path.join(location, "__init__.py"), "w"):
pass
try: try:
from cutils_ext.cutils_ext import * # noqa from cutils_ext.cutils_ext import * # noqa
......
...@@ -59,7 +59,8 @@ try: ...@@ -59,7 +59,8 @@ try:
init_file = os.path.join(location, "__init__.py") init_file = os.path.join(location, "__init__.py")
if not os.path.exists(init_file): if not os.path.exists(init_file):
try: try:
open(init_file, "w").close() with open(init_file, "w"):
pass
except OSError as e: except OSError as e:
if os.path.exists(init_file): if os.path.exists(init_file):
pass # has already been created pass # has already been created
...@@ -126,10 +127,12 @@ except ImportError: ...@@ -126,10 +127,12 @@ except ImportError:
"code generation." "code generation."
) )
raise ImportError("The file lazylinker_c.c is not available.") raise ImportError("The file lazylinker_c.c is not available.")
code = open(cfile).read()
with open(cfile) as f:
code = f.read()
loc = os.path.join(config.compiledir, dirname) loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc): if not os.path.exists(loc):
try: try:
os.mkdir(loc) os.mkdir(loc)
except OSError as e: except OSError as e:
...@@ -140,14 +143,17 @@ except ImportError: ...@@ -140,14 +143,17 @@ except ImportError:
GCC_compiler.compile_str(dirname, code, location=loc, preargs=args) GCC_compiler.compile_str(dirname, code, location=loc, preargs=args)
# Save version into the __init__.py file. # Save version into the __init__.py file.
init_py = os.path.join(loc, "__init__.py") init_py = os.path.join(loc, "__init__.py")
with open(init_py, "w") as f: with open(init_py, "w") as f:
f.write(f"_version = {version}\n") f.write(f"_version = {version}\n")
# If we just compiled the module for the first time, then it was # If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not # imported at the same time: we need to make sure we do not
# reload the now outdated __init__.pyc below. # reload the now outdated __init__.pyc below.
init_pyc = os.path.join(loc, "__init__.pyc") init_pyc = os.path.join(loc, "__init__.pyc")
if os.path.isfile(init_pyc): if os.path.isfile(init_pyc):
os.remove(init_pyc) os.remove(init_pyc)
try_import() try_import()
try_reload() try_reload()
from lazylinker_ext import lazylinker_ext as lazy_c from lazylinker_ext import lazylinker_ext as lazy_c
......
...@@ -42,21 +42,19 @@ Pickler = pickle.Pickler ...@@ -42,21 +42,19 @@ Pickler = pickle.Pickler
class StripPickler(Pickler): class StripPickler(Pickler):
""" """Subclass of `Pickler` that strips unnecessary attributes from Aesara objects.
Subclass of Pickler that strips unnecessary attributes from Aesara objects.
.. versionadded:: 0.8
Example of use:: Example
-------
fn_args = dict(inputs=inputs, fn_args = dict(inputs=inputs,
outputs=outputs, outputs=outputs,
updates=updates) updates=updates)
dest_pkl = 'my_test.pkl' dest_pkl = 'my_test.pkl'
f = open(dest_pkl, 'wb') with open(dest_pkl, 'wb') as f:
strip_pickler = StripPickler(f, protocol=-1) strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(fn_args) strip_pickler.dump(fn_args)
f.close()
""" """
def __init__(self, file, protocol=0, extra_tag_to_remove=None): def __init__(self, file, protocol=0, extra_tag_to_remove=None):
......
...@@ -118,7 +118,8 @@ For instance, you can define functions along the lines of: ...@@ -118,7 +118,8 @@ For instance, you can define functions along the lines of:
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self.training_set = cPickle.load(open(self.training_set_file, 'rb')) with open(self.training_set_file, 'rb') as f:
self.training_set = cPickle.load(f)
Robust Serialization Robust Serialization
......
...@@ -66,6 +66,6 @@ class TestStripPickler: ...@@ -66,6 +66,6 @@ class TestStripPickler:
with open("test.pkl", "wb") as f: with open("test.pkl", "wb") as f:
m = matrix() m = matrix()
dest_pkl = "my_test.pkl" dest_pkl = "my_test.pkl"
f = open(dest_pkl, "wb") with open(dest_pkl, "wb") as f:
strip_pickler = StripPickler(f, protocol=-1) strip_pickler = StripPickler(f, protocol=-1)
strip_pickler.dump(m) strip_pickler.dump(m)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论