提交 8e0b1560 authored 作者: kc611's avatar kc611 提交者: Thomas Wiecki

Make RandomStateType compatible with array-form state values

上级 b852bd24
......@@ -14,6 +14,9 @@ class RandomStateType(Type):
with the `==` operator. This `Type` exists to provide an equals function
that is used by `DebugMode`.
Also works with a `dict` derived from RandomState.get_state() unless
the `strict` argument is explicitly set to `True`.
"""
def __repr__(self):
......@@ -21,19 +24,38 @@ class RandomStateType(Type):
@classmethod
def filter(cls, data, strict=False, allow_downcast=None):
if cls.is_valid_value(data):
if cls.is_valid_value(data, strict):
return data
else:
raise TypeError()
@staticmethod
def is_valid_value(a):
return isinstance(a, np.random.RandomState)
def is_valid_value(a, strict):
if isinstance(a, np.random.RandomState):
return True
if not strict and isinstance(a, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]
for key in gen_keys:
if key not in a:
return False
for key in state_keys:
if key not in a["state"]:
return False
state_key = a["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
return True
return False
@staticmethod
def values_eq(a, b):
sa = a.get_state(legacy=False)
sb = b.get_state(legacy=False)
sa = a if isinstance(a, dict) else a.get_state(legacy=False)
sb = b if isinstance(b, dict) else b.get_state(legacy=False)
def _eq(sa, sb):
for key in sa:
......
......@@ -51,6 +51,16 @@ class TestRandomStateType:
with pytest.raises(TypeError):
rng_type.filter(1)
rng = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng, strict=False)
rng["state"] = {}
assert rng_type.is_valid_value(rng, strict=False) is False
rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
def test_values_eq(self):
rng_type = random_state_type
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论