提交 7a0175af authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Simplify _ChangeFlagDecorator

上级 158a7d01
...@@ -32,11 +32,7 @@ class ConfigAccessViolation(AttributeError): ...@@ -32,11 +32,7 @@ class ConfigAccessViolation(AttributeError):
class _ChangeFlagsDecorator: class _ChangeFlagsDecorator:
def __init__(self, *args, _root=None, **kwargs): def __init__(self, _root=None, **kwargs):
# the old API supported passing a dict as the first argument:
if args:
assert len(args) == 1 and isinstance(args[0], dict)
kwargs = dict(**args[0], **kwargs)
self.confs = {k: _root._config_var_dict[k] for k in kwargs} self.confs = {k: _root._config_var_dict[k] for k in kwargs}
self.new_vals = kwargs self.new_vals = kwargs
self._root = _root self._root = _root
...@@ -310,14 +306,14 @@ class PyTensorConfigParser: ...@@ -310,14 +306,14 @@ class PyTensorConfigParser:
except (NoOptionError, NoSectionError): except (NoOptionError, NoSectionError):
raise KeyError(key) raise KeyError(key)
def change_flags(self, *args, **kwargs) -> _ChangeFlagsDecorator: def change_flags(self, **kwargs) -> _ChangeFlagsDecorator:
""" """
Use this as a decorator or context manager to change the value of Use this as a decorator or context manager to change the value of
PyTensor config variables. PyTensor config variables.
Useful during tests. Useful during tests.
""" """
return _ChangeFlagsDecorator(*args, _root=self, **kwargs) return _ChangeFlagsDecorator(_root=self, **kwargs)
def warn_unused_flags(self): def warn_unused_flags(self):
for key in self._flags_dict: for key in self._flags_dict:
......
...@@ -287,6 +287,6 @@ class TestEnumTypes: ...@@ -287,6 +287,6 @@ class TestEnumTypes:
assert val_billion == val_million * 1000 assert val_billion == val_million * 1000
assert val_two_billions == val_billion * 2 assert val_two_billions == val_billion * 2
@pytensor.config.change_flags(**{"cmodule__debug": True}) @pytensor.config.change_flags(cmodule__debug=True)
def test_op_with_cenumtype_debug(self): def test_op_with_cenumtype_debug(self):
self.test_op_with_cenumtype() self.test_op_with_cenumtype()
...@@ -514,7 +514,7 @@ class TestGemmNoFlags: ...@@ -514,7 +514,7 @@ class TestGemmNoFlags:
C = self.get_value(C, transpose_C, slice_C) C = self.get_value(C, transpose_C, slice_C)
return alpha * np.dot(A, B) + beta * C return alpha * np.dot(A, B) + beta * C
@config.change_flags({"blas__ldflags": ""}) @config.change_flags(blas__ldflags="")
def run_gemm( def run_gemm(
self, self,
dtype, dtype,
......
...@@ -168,7 +168,7 @@ def test_config_context(): ...@@ -168,7 +168,7 @@ def test_config_context():
with root.change_flags(test__config_context="new_value"): with root.change_flags(test__config_context="new_value"):
assert root.test__config_context == "new_value" assert root.test__config_context == "new_value"
with root.change_flags({"test__config_context": "new_value2"}): with root.change_flags(test__config_context="new_value2"):
assert root.test__config_context == "new_value2" assert root.test__config_context == "new_value2"
assert root.test__config_context == "new_value" assert root.test__config_context == "new_value"
assert root.test__config_context == "test_default" assert root.test__config_context == "test_default"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论