提交 abc28761 authored 作者: Piotr Frankowski

#3429 - python 'with' statement in reset modules

上级 54b194d5
...@@ -1207,22 +1207,17 @@ class FunctionMaker(object): ...@@ -1207,22 +1207,17 @@ class FunctionMaker(object):
print('graph_db already exists') print('graph_db already exists')
else: else:
# create graph_db # create graph_db
f = open(graph_db_file, 'wb') with open(graph_db_file, 'wb') as f:
print('create new graph_db in %s' % graph_db_file) print('create new graph_db in %s' % graph_db_file)
# file needs to be open and closed for every pickle
f.close()
# load the graph_db dictionary # load the graph_db dictionary
try: try:
f = open(graph_db_file, 'rb') with open(graph_db_file, 'rb') as f:
# Temporary hack to allow # Temporary hack to allow
# theano.scan_module.tests.test_scan.T_Scan to # theano.scan_module.tests.test_scan.T_Scan to
# finish. Should be changed in definitive version. # finish. Should be changed in definitive version.
tmp = theano.config.unpickle_function tmp = theano.config.unpickle_function
theano.config.unpickle_function = False theano.config.unpickle_function = False
graph_db = pickle.load(f) graph_db = pickle.load(f)
# hack end
f.close()
print('graph_db loaded and it is not empty') print('graph_db loaded and it is not empty')
except EOFError as e: except EOFError as e:
# the file has nothing in it # the file has nothing in it
...@@ -1351,9 +1346,8 @@ class FunctionMaker(object): ...@@ -1351,9 +1346,8 @@ class FunctionMaker(object):
before_opt = self.fgraph.clone(check_integrity=False) before_opt = self.fgraph.clone(check_integrity=False)
optimizer_profile = optimizer(self.fgraph) optimizer_profile = optimizer(self.fgraph)
graph_db.update({before_opt: self.fgraph}) graph_db.update({before_opt: self.fgraph})
f = open(graph_db_file, 'wb') with open(graph_db_file, 'wb') as f:
pickle.dump(graph_db, f, -1) pickle.dump(graph_db, f, -1)
f.close()
print('new graph saved into graph_db') print('new graph saved into graph_db')
release_lock() release_lock()
return optimizer_profile return optimizer_profile
......
...@@ -18,9 +18,8 @@ def test_function_dump(): ...@@ -18,9 +18,8 @@ def test_function_dump():
tmpdir = tempfile.mkdtemp() tmpdir = tempfile.mkdtemp()
fname = os.path.join(tmpdir, 'test_function_dump.pkl') fname = os.path.join(tmpdir, 'test_function_dump.pkl')
theano.function_dump(fname, [v], v + 1) theano.function_dump(fname, [v], v + 1)
f = open(fname, 'rb') with open(fname, 'rb') as f:
l = pickle.load(f) l = pickle.load(f)
f.close()
finally: finally:
if tmpdir is not None: if tmpdir is not None:
shutil.rmtree(tmpdir) shutil.rmtree(tmpdir)
......
...@@ -84,9 +84,8 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs): ...@@ -84,9 +84,8 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
# Read template HTML file # Read template HTML file
template_file = os.path.join(__path__, 'html', 'template.html') template_file = os.path.join(__path__, 'html', 'template.html')
f = open(template_file) with open(template_file) as f:
template = f.read() template = f.read()
f.close()
# Copy dependencies to output directory # Copy dependencies to output directory
src_deps = __path__ src_deps = __path__
......
...@@ -10,18 +10,16 @@ class CallCache(object): ...@@ -10,18 +10,16 @@ class CallCache(object):
try: try:
if filename is None: if filename is None:
raise IOError('bad filename') # just goes to except raise IOError('bad filename') # just goes to except
f = open(filename, 'r') with open(filename, 'r') as f:
self.cache = pickle.load(f) self.cache = pickle.load(f)
f.close()
except IOError: except IOError:
self.cache = {} self.cache = {}
def persist(self, filename=None): def persist(self, filename=None):
if filename is None: if filename is None:
filename = self.filename filename = self.filename
f = open(filename, 'w') with open(filename, 'w') as f:
pickle.dump(self.cache, f) pickle.dump(self.cache, f)
f.close()
def call(self, fn, args=(), key=None): def call(self, fn, args=(), key=None):
if key is None: if key is None:
......
...@@ -2112,15 +2112,14 @@ class GCC_compiler(Compiler): ...@@ -2112,15 +2112,14 @@ class GCC_compiler(Compiler):
lib_dirs.append(python_lib) lib_dirs.append(python_lib)
cppfilename = os.path.join(location, 'mod.cpp') cppfilename = os.path.join(location, 'mod.cpp')
cppfile = open(cppfilename, 'w') with open(cppfilename, 'w') as cppfile:
_logger.debug('Writing module C++ code to %s', cppfilename) _logger.debug('Writing module C++ code to %s', cppfilename)
cppfile.write(src_code) cppfile.write(src_code)
# Avoid gcc warning "no newline at end of file". # Avoid gcc warning "no newline at end of file".
if not src_code.endswith('\n'): if not src_code.endswith('\n'):
cppfile.write('\n') cppfile.write('\n')
cppfile.close()
lib_filename = os.path.join( lib_filename = os.path.join(
location, location,
......
...@@ -349,35 +349,27 @@ def print_compiledir_content(): ...@@ -349,35 +349,27 @@ def print_compiledir_content():
total_key_sizes = 0 total_key_sizes = 0
nb_keys = {} nb_keys = {}
for dir in os.listdir(compiledir): for dir in os.listdir(compiledir):
file = None filename = os.path.join(compiledir, dir, "key.pkl")
try: with open(filename, 'rb') as file:
try: keydata = pickle.load(file)
filename = os.path.join(compiledir, dir, "key.pkl") ops = list(set([x for x in flatten(keydata.keys)
file = open(filename, 'rb') if isinstance(x, theano.gof.Op)]))
keydata = pickle.load(file) if len(ops) == 0:
ops = list(set([x for x in flatten(keydata.keys) zeros_op += 1
if isinstance(x, theano.gof.Op)])) elif len(ops) > 1:
if len(ops) == 0: more_than_one_ops += 1
zeros_op += 1 else:
elif len(ops) > 1: types = list(set([x for x in flatten(keydata.keys)
more_than_one_ops += 1 if isinstance(x, theano.gof.Type)]))
else: table.append((dir, ops[0], types))
types = list(set([x for x in flatten(keydata.keys)
if isinstance(x, theano.gof.Type)])) size = os.path.getsize(filename)
table.append((dir, ops[0], types)) total_key_sizes += size
if size > max_key_file_size:
size = os.path.getsize(filename) big_key_files.append((dir, size, ops))
total_key_sizes += size
if size > max_key_file_size: nb_keys.setdefault(len(keydata.keys), 0)
big_key_files.append((dir, size, ops)) nb_keys[len(keydata.keys)] += 1
nb_keys.setdefault(len(keydata.keys), 0)
nb_keys[len(keydata.keys)] += 1
except IOError:
pass
finally:
if file is not None:
file.close()
print("List of %d compiled individual ops in this theano cache %s:" % ( print("List of %d compiled individual ops in this theano cache %s:" % (
len(table), compiledir)) len(table), compiledir))
......
...@@ -339,9 +339,8 @@ def refresh_lock(lock_file): ...@@ -339,9 +339,8 @@ def refresh_lock(lock_file):
''.join([str(random.randint(0, 9)) for i in range(10)]), ''.join([str(random.randint(0, 9)) for i in range(10)]),
hostname) hostname)
try: try:
lock_write = open(lock_file, 'w') with open(lock_file, 'w') as lock_write:
lock_write.write(unique_id + '\n') lock_write.write(unique_id + '\n')
lock_write.close()
except Exception: except Exception:
# In some strange case, this happen. To prevent all tests # In some strange case, this happen. To prevent all tests
# from failing, we release the lock, but as there is a # from failing, we release the lock, but as there is a
......
...@@ -24,6 +24,7 @@ if __name__ == "__main__": ...@@ -24,6 +24,7 @@ if __name__ == "__main__":
import pdb import pdb
pdb.set_trace() pdb.set_trace()
if len(sys.argv) > 1: if len(sys.argv) > 1:
print(filter_output(open(sys.argv[1]))) with open(sys.argv[1]) as f:
print(filter_output(f))
else: else:
print(filter_output(sys.stdin)) print(filter_output(sys.stdin))
...@@ -23,9 +23,8 @@ for dir in dirs: ...@@ -23,9 +23,8 @@ for dir in dirs:
key = None key = None
try: try:
f = open(os.path.join(dir, "key.pkl")) with open(os.path.join(dir, "key.pkl")) as f:
key = f.read() key = f.read()
f.close()
keys.setdefault(key, 0) keys.setdefault(key, 0)
keys[key] += 1 keys[key] += 1
del f del f
...@@ -36,9 +35,8 @@ for dir in dirs: ...@@ -36,9 +35,8 @@ for dir in dirs:
path = os.path.join(dir, "mod.cpp") path = os.path.join(dir, "mod.cpp")
if not os.path.exists(path): if not os.path.exists(path):
path = os.path.join(dir, "mod.cu") path = os.path.join(dir, "mod.cu")
f = open(path) with open(path) as f:
mod = f.read() mod = f.read()
f.close()
mods.setdefault(mod, ()) mods.setdefault(mod, ())
mods[mod] += (key,) mods[mod] += (key,)
del mod del mod
......
...@@ -145,31 +145,27 @@ def get_file_contents(filename, revision="tip"): ...@@ -145,31 +145,27 @@ def get_file_contents(filename, revision="tip"):
def save_commit_message(filename): def save_commit_message(filename):
commit_message = run_mercurial_command("tip --template '{desc}'") commit_message = run_mercurial_command("tip --template '{desc}'")
save_file = open(filename, "w") with open(filename, "w") as save_file:
save_file.write(commit_message) save_file.write(commit_message)
save_file.close()
def save_diffs(diffs, filename): def save_diffs(diffs, filename):
diff = "\n\n".join(diffs) diff = "\n\n".join(diffs)
diff_file = open(filename, "w") with open(filename, "w") as diff_file:
diff_file.write(diff) diff_file.write(diff)
diff_file.close()
def should_skip_commit(): def should_skip_commit():
if not os.path.exists(SKIP_WHITESPACE_CHECK_FILENAME): if not os.path.exists(SKIP_WHITESPACE_CHECK_FILENAME):
return False return False
whitespace_check_file = open(SKIP_WHITESPACE_CHECK_FILENAME, "r") with open(SKIP_WHITESPACE_CHECK_FILENAME, "r") as whitespace_check_file:
whitespace_check_changeset = whitespace_check_file.read() whitespace_check_changeset = whitespace_check_file.read()
whitespace_check_file.close()
return whitespace_check_changeset == parent_commit() return whitespace_check_changeset == parent_commit()
def save_skip_next_commit(): def save_skip_next_commit():
whitespace_check_file = open(SKIP_WHITESPACE_CHECK_FILENAME, "w") with open(SKIP_WHITESPACE_CHECK_FILENAME, "w") as whitespace_check_file:
whitespace_check_file.write(parent_commit()) whitespace_check_file.write(parent_commit())
whitespace_check_file.close()
def main(argv=None): def main(argv=None):
......
...@@ -132,9 +132,8 @@ def check(file): ...@@ -132,9 +132,8 @@ def check(file):
shutil.copyfile(file, bak) shutil.copyfile(file, bak)
if verbose: if verbose:
print("backed up", file, "to", bak) print("backed up", file, "to", bak)
f = open(file, "w") with open(file, "w") as f:
r.write(f) r.write(f)
f.close()
if verbose: if verbose:
print("wrote new", file) print("wrote new", file)
return True return True
......
...@@ -301,12 +301,11 @@ class NVCC_compiler(Compiler): ...@@ -301,12 +301,11 @@ class NVCC_compiler(Compiler):
lib_dirs.append(python_lib) lib_dirs.append(python_lib)
cppfilename = os.path.join(location, 'mod.cu') cppfilename = os.path.join(location, 'mod.cu')
cppfile = open(cppfilename, 'w') with open(cppfilename, 'w') as cppfile:
_logger.debug('Writing module C++ code to %s', cppfilename) _logger.debug('Writing module C++ code to %s', cppfilename)
cppfile.write(src_code)
cppfile.write(src_code)
cppfile.close()
lib_filename = os.path.join(location, '%s.%s' % lib_filename = os.path.join(location, '%s.%s' %
(module_name, get_lib_extension())) (module_name, get_lib_extension()))
......
...@@ -76,7 +76,8 @@ except ImportError: ...@@ -76,7 +76,8 @@ except ImportError:
) )
raise ImportError("The file lazylinker_c.c is not available.") raise ImportError("The file lazylinker_c.c is not available.")
code = open(cfile).read() with open(cfile) as f:
code = f.read()
loc = os.path.join(config.compiledir, dirname) loc = os.path.join(config.compiledir, dirname)
if not os.path.exists(loc): if not os.path.exists(loc):
try: try:
...@@ -114,7 +115,8 @@ except ImportError: ...@@ -114,7 +115,8 @@ except ImportError:
hide_symbols=False) hide_symbols=False)
# Save version into the __init__.py file. # Save version into the __init__.py file.
init_py = os.path.join(loc, '__init__.py') init_py = os.path.join(loc, '__init__.py')
open(init_py, 'w').write('_version = %s\n' % version) with open(init_py, 'w') as f:
f.write('_version = %s\n' % version)
# If we just compiled the module for the first time, then it was # If we just compiled the module for the first time, then it was
# imported at the same time: we need to make sure we do not # imported at the same time: we need to make sure we do not
# reload the now outdated __init__.pyc below. # reload the now outdated __init__.pyc below.
......
...@@ -250,16 +250,10 @@ class T_Scan(unittest.TestCase): ...@@ -250,16 +250,10 @@ class T_Scan(unittest.TestCase):
tmpdir = mkdtemp() tmpdir = mkdtemp()
os.chdir(tmpdir) os.chdir(tmpdir)
f_out = open('tmp_scan_test_pickle.pkl', 'wb') with open('tmp_scan_test_pickle.pkl', 'wb') as f_out:
try:
pickle.dump(_my_f, f_out, protocol=-1) pickle.dump(_my_f, f_out, protocol=-1)
finally: with open('tmp_scan_test_pickle.pkl', 'rb') as f_in:
f_out.close()
f_in = open('tmp_scan_test_pickle.pkl', 'rb')
try:
my_f = pickle.load(f_in) my_f = pickle.load(f_in)
finally:
f_in.close()
finally: finally:
# Get back to the original dir, and delete the temporary one. # Get back to the original dir, and delete the temporary one.
os.chdir(origdir) os.chdir(origdir)
......
...@@ -3674,16 +3674,13 @@ class test_shapeoptimizer(unittest.TestCase): ...@@ -3674,16 +3674,13 @@ class test_shapeoptimizer(unittest.TestCase):
# Due to incompatibilities between python 2 and 3 in the format # Due to incompatibilities between python 2 and 3 in the format
# of pickled numpy ndarray, we have to force an encoding # of pickled numpy ndarray, we have to force an encoding
from theano.misc.pkl_utils import CompatUnpickler from theano.misc.pkl_utils import CompatUnpickler
pkl_file = open(pkl_filename, "rb") with open(pkl_filename, "rb") as pkl_file:
try:
if PY3: if PY3:
u = CompatUnpickler(pkl_file, encoding="latin1") u = CompatUnpickler(pkl_file, encoding="latin1")
else: else:
u = CompatUnpickler(pkl_file) u = CompatUnpickler(pkl_file)
fn_args = u.load() fn_args = u.load()
theano.function(**fn_args) theano.function(**fn_args)
finally:
pkl_file.close()
class test_assert(utt.InferShapeTester): class test_assert(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论