提交 3e42c5c5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use SeedSequence to seed RNG states in RandomStream

上级 33eaccac
from collections.abc import Sequence from collections.abc import Sequence
from functools import wraps from functools import wraps
from itertools import zip_longest from itertools import zip_longest
from typing import Optional, Union from types import ModuleType
from typing import TYPE_CHECKING, Optional, Union
import numpy as np import numpy as np
from typing_extensions import Literal
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Constant, Variable from aesara.graph.basic import Constant, Variable
...@@ -13,6 +15,11 @@ from aesara.tensor.extra_ops import broadcast_to ...@@ -13,6 +15,11 @@ from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import maximum from aesara.tensor.math import maximum
from aesara.tensor.shape import specify_shape from aesara.tensor.shape import specify_shape
from aesara.tensor.type import int_dtypes from aesara.tensor.type import int_dtypes
from aesara.tensor.var import TensorVariable
if TYPE_CHECKING:
from aesara.tensor.random.op import RandomVariable
def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True): def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
...@@ -161,7 +168,14 @@ class RandomStream: ...@@ -161,7 +168,14 @@ class RandomStream:
""" """
def __init__(self, seed=None, namespace=None, rng_ctor=np.random.default_rng): def __init__(
self,
seed: Optional[int] = None,
namespace: Optional[ModuleType] = None,
rng_ctor: Literal[
np.random.RandomState, np.random.Generator
] = np.random.default_rng,
):
if namespace is None: if namespace is None:
from aesara.tensor.random import basic # pylint: disable=import-self from aesara.tensor.random import basic # pylint: disable=import-self
...@@ -171,7 +185,14 @@ class RandomStream: ...@@ -171,7 +185,14 @@ class RandomStream:
self.default_instance_seed = seed self.default_instance_seed = seed
self.state_updates = [] self.state_updates = []
self.gen_seedgen = np.random.default_rng(seed) self.gen_seedgen = np.random.SeedSequence(seed)
if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
# The legacy state does not accept `SeedSequence`s directly
def rng_ctor(seed):
return np.random.RandomState(np.random.MT19937(seed))
self.rng_ctor = rng_ctor self.rng_ctor = rng_ctor
def __getattr__(self, obj): def __getattr__(self, obj):
...@@ -206,7 +227,7 @@ class RandomStream: ...@@ -206,7 +227,7 @@ class RandomStream:
Parameters Parameters
---------- ----------
seed : None or integer in range 0 to 2**30 seed : None or integer
Each random stream will be assigned a unique state that depends Each random stream will be assigned a unique state that depends
deterministically on this value. deterministically on this value.
...@@ -218,18 +239,18 @@ class RandomStream: ...@@ -218,18 +239,18 @@ class RandomStream:
if seed is None: if seed is None:
seed = self.default_instance_seed seed = self.default_instance_seed
self.gen_seedgen = np.random.default_rng(seed) self.gen_seedgen = np.random.SeedSequence(seed)
old_r_seeds = self.gen_seedgen.spawn(len(self.state_updates))
for old_r, new_r in self.state_updates: for (old_r, new_r), old_r_seed in zip(self.state_updates, old_r_seeds):
old_r_seed = self.gen_seedgen.integers(2**30) old_r.set_value(self.rng_ctor(old_r_seed), borrow=True)
old_r.set_value(self.rng_ctor(int(old_r_seed)), borrow=True)
def gen(self, op, *args, **kwargs): def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable:
"""Create a new random stream in this container. r"""Generate a draw from `op` seeded from this `RandomStream`.
Parameters Parameters
---------- ----------
op : RandomVariable op
A `RandomVariable` instance A `RandomVariable` instance
args args
Positional arguments passed to `op`. Positional arguments passed to `op`.
...@@ -238,10 +259,8 @@ class RandomStream: ...@@ -238,10 +259,8 @@ class RandomStream:
Returns Returns
------- -------
TensorVariable The symbolic random draw performed by `op`. This function stores
The symbolic random draw part of op()'s return value. the updated `RandomType`\s for use at compile time.
This function stores the updated `RandomGeneratorType` variable
for use at `build` time.
""" """
if "rng" in kwargs: if "rng" in kwargs:
...@@ -250,7 +269,7 @@ class RandomStream: ...@@ -250,7 +269,7 @@ class RandomStream:
) )
# Generate a new random state # Generate a new random state
seed = int(self.gen_seedgen.integers(2**30)) (seed,) = self.gen_seedgen.spawn(1)
rng = shared(self.rng_ctor(seed), borrow=True) rng = shared(self.rng_ctor(seed), borrow=True)
# Generate the sample # Generate the sample
......
...@@ -507,7 +507,7 @@ def test_normal0(): ...@@ -507,7 +507,7 @@ def test_normal0():
sys.stdout.flush() sys.stdout.flush()
RR = RandomStream(234) RR = RandomStream(235)
nn = RR.normal(avg, std, size=size) nn = RR.normal(avg, std, size=size)
ff = function(var_input, nn) ff = function(var_input, nn)
......
...@@ -888,8 +888,9 @@ class TestScan: ...@@ -888,8 +888,9 @@ class TestScan:
) )
my_f = function([], values, updates=updates, allow_input_downcast=True) my_f = function([], values, updates=updates, allow_input_downcast=True)
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30) rng_seed = np.random.SeedSequence(utt.fetch_seed())
rng = np.random.default_rng(int(rng_seed)) # int() is for 32bit (rng_seed,) = rng_seed.spawn(1)
rng = aesara_rng.rng_ctor(rng_seed)
numpy_v = np.zeros((10, 2)) numpy_v = np.zeros((10, 2))
for i in range(10): for i in range(10):
...@@ -2698,12 +2699,10 @@ class TestExamples: ...@@ -2698,12 +2699,10 @@ class TestExamples:
[vsample], aesara_vsamples[-1], updates=updates, allow_input_downcast=True [vsample], aesara_vsamples[-1], updates=updates, allow_input_downcast=True
) )
_rng = np.random.default_rng(utt.fetch_seed()) rng_seed = np.random.SeedSequence(utt.fetch_seed())
rng_seed = _rng.integers(2**30) (rng_seed_1, rng_seed_2) = rng_seed.spawn(2)
nrng1 = np.random.default_rng(int(rng_seed)) # int() is for 32bit nrng1 = trng.rng_ctor(rng_seed_1)
nrng2 = trng.rng_ctor(rng_seed_2)
rng_seed = _rng.integers(2**30)
nrng2 = np.random.default_rng(int(rng_seed)) # int() is for 32bit
def numpy_implementation(vsample): def numpy_implementation(vsample):
for idx in range(10): for idx in range(10):
......
...@@ -119,8 +119,9 @@ class TestSharedRandomStream: ...@@ -119,8 +119,9 @@ class TestSharedRandomStream:
fn_val0 = fn() fn_val0 = fn()
fn_val1 = fn() fn_val1 = fn()
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30) rng_seed = np.random.SeedSequence(utt.fetch_seed())
rng = rng_ctor(int(rng_seed)) # int() is for 32bit (rng_seed,) = rng_seed.spawn(1)
rng = random.rng_ctor(rng_seed)
numpy_val0 = rng.uniform(0, 1, size=(2, 2)) numpy_val0 = rng.uniform(0, 1, size=(2, 2))
numpy_val1 = rng.uniform(0, 1, size=(2, 2)) numpy_val1 = rng.uniform(0, 1, size=(2, 2))
...@@ -133,26 +134,18 @@ class TestSharedRandomStream: ...@@ -133,26 +134,18 @@ class TestSharedRandomStream:
init_seed = 234 init_seed = 234
random = RandomStream(init_seed, rng_ctor=rng_ctor) random = RandomStream(init_seed, rng_ctor=rng_ctor)
ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random.default_instance_seed == init_seed assert random.default_instance_seed == init_seed
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
new_seed = 43298 new_seed = 43298
random.seed(new_seed) random.seed(new_seed)
ref_state = np.random.default_rng(new_seed).__getstate__() rng_seed = np.random.SeedSequence(new_seed)
random_state = random.gen_seedgen.__getstate__() assert random.gen_seedgen.entropy == rng_seed.entropy
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
random.seed() random.seed()
ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__() rng_seed = np.random.SeedSequence(init_seed)
assert random.default_instance_seed == init_seed assert random.gen_seedgen.entropy == rng_seed.entropy
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
# Reset the seed # Reset the seed
random.seed(new_seed) random.seed(new_seed)
...@@ -163,8 +156,9 @@ class TestSharedRandomStream: ...@@ -163,8 +156,9 @@ class TestSharedRandomStream:
# Now, change the seed when there are state updates # Now, change the seed when there are state updates
random.seed(new_seed) random.seed(new_seed)
update_seed = np.random.default_rng(new_seed).integers(2**30) update_seed = np.random.SeedSequence(new_seed)
ref_rng = rng_ctor(update_seed) (update_seed,) = update_seed.spawn(1)
ref_rng = random.rng_ctor(update_seed)
state_rng = random.state_updates[0][0].get_value(borrow=True) state_rng = random.state_updates[0][0].get_value(borrow=True)
if hasattr(state_rng, "get_state"): if hasattr(state_rng, "get_state"):
...@@ -188,8 +182,10 @@ class TestSharedRandomStream: ...@@ -188,8 +182,10 @@ class TestSharedRandomStream:
fn_val0 = fn() fn_val0 = fn()
fn_val1 = fn() fn_val1 = fn()
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30) rng_seed = np.random.SeedSequence(utt.fetch_seed())
rng = rng_ctor(int(rng_seed)) # int() is for 32bit (rng_seed,) = rng_seed.spawn(1)
rng = random.rng_ctor(rng_seed)
numpy_val0 = rng.uniform(-1, 1, size=(2, 2)) numpy_val0 = rng.uniform(-1, 1, size=(2, 2))
numpy_val1 = rng.uniform(-1, 1, size=(2, 2)) numpy_val1 = rng.uniform(-1, 1, size=(2, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论