提交 46a235ab authored 作者: Brendan Murphy's avatar Brendan Murphy 提交者: Ricardo Vieira

Change rng.__getstate__ to rng.bit_generator.state

numpy.random.Generator.__getstate__() now returns none; to see the state of the bit generator, you need to use Generator.bit_generator.state. This change affects `RandomGeneratorType`, and several of the random tests (including some for Jax.)
上级 fb20e58a
......@@ -56,7 +56,7 @@ def assert_size_argument_jax_compatible(node):
@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
state = rng.bit_generator.state
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
......
......@@ -87,8 +87,8 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
@staticmethod
def values_eq(a, b):
sa = a if isinstance(a, dict) else a.__getstate__()
sb = b if isinstance(b, dict) else b.__getstate__()
sa = a if isinstance(a, dict) else a.bit_generator.state
sb = b if isinstance(b, dict) else b.bit_generator.state
def _eq(sa, sb):
for key in sa:
......
......@@ -63,7 +63,9 @@ def test_random_updates(rng_ctor):
assert all(
a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b)
for a, b in zip(
rng.get_value().__getstate__(), original_value.__getstate__(), strict=True
rng.get_value().bit_generator.state,
original_value.bit_generator.state,
strict=True,
)
)
......
......@@ -52,7 +52,7 @@ class TestRandomGeneratorType:
with pytest.raises(TypeError):
rng_type.filter(1)
rng_dict = rng.__getstate__()
rng_dict = rng.bit_generator.state
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)
......@@ -88,13 +88,13 @@ class TestRandomGeneratorType:
assert rng_type.values_eq(bitgen_g, bitgen_h)
assert rng_type.is_valid_value(bitgen_a, strict=True)
assert rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_b.bit_generator.state, strict=False)
assert rng_type.is_valid_value(bitgen_c, strict=True)
assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_d.bit_generator.state, strict=False)
assert rng_type.is_valid_value(bitgen_e, strict=True)
assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_f.bit_generator.state, strict=False)
assert rng_type.is_valid_value(bitgen_g, strict=True)
assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_h.bit_generator.state, strict=False)
def test_may_share_memory(self):
bg_a = np.random.PCG64()
......
......@@ -165,14 +165,20 @@ class TestSharedRandomStream:
state_rng = random.state_updates[0][0].get_value(borrow=True)
if hasattr(state_rng, "get_state"):
ref_state = ref_rng.get_state()
random_state = state_rng.get_state()
# hack to try to get something reasonable for ref_rng
try:
ref_state = ref_rng.get_state()
except AttributeError:
ref_state = list(ref_rng.bit_generator.state.values())
assert np.array_equal(random_state[1], ref_state[1])
assert random_state[0] == ref_state[0]
assert random_state[2:] == ref_state[2:]
else:
ref_state = ref_rng.__getstate__()
random_state = state_rng.__getstate__()
ref_state = ref_rng.bit_generator.state
random_state = state_rng.bit_generator.state
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论