提交 0850be0d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make change_flags also work as a context manager.

上级 1f01eab6
......@@ -10,7 +10,7 @@ import sys
import warnings
from functools import wraps
from six import StringIO, PY3
from six import StringIO, PY3, iteritems
import theano
from theano.compat import configparser as ConfigParser
......@@ -90,38 +90,43 @@ theano_cfg.read(config_files)
theano_raw_cfg = ConfigParser.RawConfigParser()
theano_raw_cfg.read(config_files)
def change_flags(**kwargs):
class change_flags(object):
"""
Use this as a decorator to change the value of Theano config variable.
Use this as a decorator or context manager to change the value of
Theano config variable.
Useful during tests.
"""
def change_flags_exec(f):
def __init__(self, **kwargs):
confs = dict()
for k in kwargs:
l = [v for v in theano.configparser._config_var_list
if v.fullname == k]
assert len(l) == 1
confs[k] = l[0]
self.confs = confs
self.new_vals = kwargs
def __call__(self, f):
@wraps(f)
def inner(*args, **kwargs_):
old_val = {}
for k in kwargs:
l = [v for v in theano.configparser._config_var_list
if v.fullname == k]
assert len(l) == 1
old_val[k] = l[0].__get__(True, None)
try:
for k in kwargs:
l = [v for v in theano.configparser._config_var_list
if v.fullname == k]
assert len(l) == 1
l[0].__set__(None, kwargs[k])
return f(*args, **kwargs_)
finally:
for k in kwargs:
l = [v for v in theano.configparser._config_var_list
if v.fullname == k]
assert len(l) == 1
l[0].__set__(None, old_val[k])
return inner
return change_flags_exec
def res(*args, **kwargs):
with self:
return f(*args, **kwargs)
def __enter__(self):
self.old_vals = {}
for k, v in iteritems(self.confs):
self.old_vals[k] = v.__get__(True, None)
try:
for k, v in iteritems(self.confs):
v.__set__(None, self.new_vals[k])
except:
self.__exit__()
raise
def __exit__(self, *args):
for k, v in iteritems(self.confs):
v.__set__(None, self.old_vals[k])
def fetch_val_for_key(key, delete_key=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论