提交 3e42c5c5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use SeedSequence to seed RNG states in RandomStream

上级 33eaccac
from collections.abc import Sequence
from functools import wraps
from itertools import zip_longest
from typing import Optional, Union
from types import ModuleType
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
from typing_extensions import Literal
from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Constant, Variable
......@@ -13,6 +15,11 @@ from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import maximum
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import int_dtypes
from aesara.tensor.var import TensorVariable
if TYPE_CHECKING:
from aesara.tensor.random.op import RandomVariable
def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
......@@ -161,7 +168,14 @@ class RandomStream:
"""
def __init__(self, seed=None, namespace=None, rng_ctor=np.random.default_rng):
def __init__(
self,
seed: Optional[int] = None,
namespace: Optional[ModuleType] = None,
rng_ctor: Literal[
np.random.RandomState, np.random.Generator
] = np.random.default_rng,
):
if namespace is None:
from aesara.tensor.random import basic # pylint: disable=import-self
......@@ -171,7 +185,14 @@ class RandomStream:
self.default_instance_seed = seed
self.state_updates = []
self.gen_seedgen = np.random.default_rng(seed)
self.gen_seedgen = np.random.SeedSequence(seed)
if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
# The legacy state does not accept `SeedSequence`s directly
def rng_ctor(seed):
return np.random.RandomState(np.random.MT19937(seed))
self.rng_ctor = rng_ctor
def __getattr__(self, obj):
......@@ -206,7 +227,7 @@ class RandomStream:
Parameters
----------
seed : None or integer in range 0 to 2**30
seed : None or integer
Each random stream will be assigned a unique state that depends
deterministically on this value.
......@@ -218,18 +239,18 @@ class RandomStream:
if seed is None:
seed = self.default_instance_seed
self.gen_seedgen = np.random.default_rng(seed)
self.gen_seedgen = np.random.SeedSequence(seed)
old_r_seeds = self.gen_seedgen.spawn(len(self.state_updates))
for old_r, new_r in self.state_updates:
old_r_seed = self.gen_seedgen.integers(2**30)
old_r.set_value(self.rng_ctor(int(old_r_seed)), borrow=True)
for (old_r, new_r), old_r_seed in zip(self.state_updates, old_r_seeds):
old_r.set_value(self.rng_ctor(old_r_seed), borrow=True)
def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container.
def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable:
r"""Generate a draw from `op` seeded from this `RandomStream`.
Parameters
----------
op : RandomVariable
op
A `RandomVariable` instance
args
Positional arguments passed to `op`.
......@@ -238,10 +259,8 @@ class RandomStream:
Returns
-------
TensorVariable
The symbolic random draw part of op()'s return value.
This function stores the updated `RandomGeneratorType` variable
for use at `build` time.
The symbolic random draw performed by `op`. This function stores
the updated `RandomType`\s for use at compile time.
"""
if "rng" in kwargs:
......@@ -250,7 +269,7 @@ class RandomStream:
)
# Generate a new random state
seed = int(self.gen_seedgen.integers(2**30))
(seed,) = self.gen_seedgen.spawn(1)
rng = shared(self.rng_ctor(seed), borrow=True)
# Generate the sample
......
......@@ -507,7 +507,7 @@ def test_normal0():
sys.stdout.flush()
RR = RandomStream(234)
RR = RandomStream(235)
nn = RR.normal(avg, std, size=size)
ff = function(var_input, nn)
......
......@@ -888,8 +888,9 @@ class TestScan:
)
my_f = function([], values, updates=updates, allow_input_downcast=True)
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30)
rng = np.random.default_rng(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed,) = rng_seed.spawn(1)
rng = aesara_rng.rng_ctor(rng_seed)
numpy_v = np.zeros((10, 2))
for i in range(10):
......@@ -2698,12 +2699,10 @@ class TestExamples:
[vsample], aesara_vsamples[-1], updates=updates, allow_input_downcast=True
)
_rng = np.random.default_rng(utt.fetch_seed())
rng_seed = _rng.integers(2**30)
nrng1 = np.random.default_rng(int(rng_seed)) # int() is for 32bit
rng_seed = _rng.integers(2**30)
nrng2 = np.random.default_rng(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed_1, rng_seed_2) = rng_seed.spawn(2)
nrng1 = trng.rng_ctor(rng_seed_1)
nrng2 = trng.rng_ctor(rng_seed_2)
def numpy_implementation(vsample):
for idx in range(10):
......
......@@ -119,8 +119,9 @@ class TestSharedRandomStream:
fn_val0 = fn()
fn_val1 = fn()
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30)
rng = rng_ctor(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed,) = rng_seed.spawn(1)
rng = random.rng_ctor(rng_seed)
numpy_val0 = rng.uniform(0, 1, size=(2, 2))
numpy_val1 = rng.uniform(0, 1, size=(2, 2))
......@@ -133,26 +134,18 @@ class TestSharedRandomStream:
init_seed = 234
random = RandomStream(init_seed, rng_ctor=rng_ctor)
ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random.default_instance_seed == init_seed
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
new_seed = 43298
random.seed(new_seed)
ref_state = np.random.default_rng(new_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
rng_seed = np.random.SeedSequence(new_seed)
assert random.gen_seedgen.entropy == rng_seed.entropy
random.seed()
ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random.default_instance_seed == init_seed
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
rng_seed = np.random.SeedSequence(init_seed)
assert random.gen_seedgen.entropy == rng_seed.entropy
# Reset the seed
random.seed(new_seed)
......@@ -163,8 +156,9 @@ class TestSharedRandomStream:
# Now, change the seed when there are state updates
random.seed(new_seed)
update_seed = np.random.default_rng(new_seed).integers(2**30)
ref_rng = rng_ctor(update_seed)
update_seed = np.random.SeedSequence(new_seed)
(update_seed,) = update_seed.spawn(1)
ref_rng = random.rng_ctor(update_seed)
state_rng = random.state_updates[0][0].get_value(borrow=True)
if hasattr(state_rng, "get_state"):
......@@ -188,8 +182,10 @@ class TestSharedRandomStream:
fn_val0 = fn()
fn_val1 = fn()
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30)
rng = rng_ctor(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed,) = rng_seed.spawn(1)
rng = random.rng_ctor(rng_seed)
numpy_val0 = rng.uniform(-1, 1, size=(2, 2))
numpy_val1 = rng.uniform(-1, 1, size=(2, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论