提交 8e0b1560 authored 作者: kc611's avatar kc611 提交者: Thomas Wiecki

Make RandomStateType compatible with array-form state values

上级 b852bd24
...@@ -14,6 +14,9 @@ class RandomStateType(Type): ...@@ -14,6 +14,9 @@ class RandomStateType(Type):
with the `==` operator. This `Type` exists to provide an equals function with the `==` operator. This `Type` exists to provide an equals function
that is used by `DebugMode`. that is used by `DebugMode`.
Also works with a `dict` derived from RandomState.get_state() unless
the `strict` argument is explicitly set to `True`.
""" """
def __repr__(self): def __repr__(self):
...@@ -21,19 +24,38 @@ class RandomStateType(Type): ...@@ -21,19 +24,38 @@ class RandomStateType(Type):
@classmethod @classmethod
def filter(cls, data, strict=False, allow_downcast=None): def filter(cls, data, strict=False, allow_downcast=None):
if cls.is_valid_value(data): if cls.is_valid_value(data, strict):
return data return data
else: else:
raise TypeError() raise TypeError()
@staticmethod @staticmethod
def is_valid_value(a): def is_valid_value(a, strict):
return isinstance(a, np.random.RandomState) if isinstance(a, np.random.RandomState):
return True
if not strict and isinstance(a, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]
for key in gen_keys:
if key not in a:
return False
for key in state_keys:
if key not in a["state"]:
return False
state_key = a["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
return True
return False
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
sa = a.get_state(legacy=False) sa = a if isinstance(a, dict) else a.get_state(legacy=False)
sb = b.get_state(legacy=False) sb = b if isinstance(b, dict) else b.get_state(legacy=False)
def _eq(sa, sb): def _eq(sa, sb):
for key in sa: for key in sa:
......
...@@ -51,6 +51,16 @@ class TestRandomStateType: ...@@ -51,6 +51,16 @@ class TestRandomStateType:
with pytest.raises(TypeError): with pytest.raises(TypeError):
rng_type.filter(1) rng_type.filter(1)
rng = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng, strict=False)
rng["state"] = {}
assert rng_type.is_valid_value(rng, strict=False) is False
rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
def test_values_eq(self): def test_values_eq(self):
rng_type = random_state_type rng_type = random_state_type
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论