提交 0850be0d authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make change_flags also work as a context manager.

上级 1f01eab6
...@@ -10,7 +10,7 @@ import sys ...@@ -10,7 +10,7 @@ import sys
import warnings import warnings
from functools import wraps from functools import wraps
from six import StringIO, PY3 from six import StringIO, PY3, iteritems
import theano import theano
from theano.compat import configparser as ConfigParser from theano.compat import configparser as ConfigParser
...@@ -90,38 +90,43 @@ theano_cfg.read(config_files) ...@@ -90,38 +90,43 @@ theano_cfg.read(config_files)
theano_raw_cfg = ConfigParser.RawConfigParser() theano_raw_cfg = ConfigParser.RawConfigParser()
theano_raw_cfg.read(config_files) theano_raw_cfg.read(config_files)
class change_flags(object):
def change_flags(**kwargs):
""" """
Use this as a decorator to change the value of Theano config variable. Use this as a decorator or context manager to change the value of
Theano config variable.
Useful during tests. Useful during tests.
""" """
def change_flags_exec(f): def __init__(self, **kwargs):
@wraps(f) confs = dict()
def inner(*args, **kwargs_):
old_val = {}
for k in kwargs:
l = [v for v in theano.configparser._config_var_list
if v.fullname == k]
assert len(l) == 1
old_val[k] = l[0].__get__(True, None)
try:
for k in kwargs: for k in kwargs:
l = [v for v in theano.configparser._config_var_list l = [v for v in theano.configparser._config_var_list
if v.fullname == k] if v.fullname == k]
assert len(l) == 1 assert len(l) == 1
l[0].__set__(None, kwargs[k]) confs[k] = l[0]
return f(*args, **kwargs_) self.confs = confs
finally: self.new_vals = kwargs
for k in kwargs:
l = [v for v in theano.configparser._config_var_list
if v.fullname == k]
assert len(l) == 1
l[0].__set__(None, old_val[k])
return inner def __call__(self, f):
return change_flags_exec @wraps(f)
def res(*args, **kwargs):
with self:
return f(*args, **kwargs)
def __enter__(self):
self.old_vals = {}
for k, v in iteritems(self.confs):
self.old_vals[k] = v.__get__(True, None)
try:
for k, v in iteritems(self.confs):
v.__set__(None, self.new_vals[k])
except:
self.__exit__()
raise
def __exit__(self, *args):
for k, v in iteritems(self.confs):
v.__set__(None, self.old_vals[k])
def fetch_val_for_key(key, delete_key=False): def fetch_val_for_key(key, delete_key=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论