提交 e0eea331 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Add NumPy Generator support for RandomVariables

上级 adf83fa5
......@@ -105,7 +105,7 @@ def detect_nan(fgraph, i, node, fn):
for output in fn.outputs:
if (
not isinstance(output[0], np.random.RandomState)
not isinstance(output[0], (np.random.RandomState, np.random.Generator))
and np.isnan(output[0]).any()
):
print("*** NaN detected ***")
......
......@@ -44,7 +44,7 @@ def _is_numeric_value(arr, var):
"""
if isinstance(arr, aesara.graph.type._cdata_type):
return False
elif isinstance(arr, np.random.mtrand.RandomState):
elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)):
return False
elif var and getattr(var.tag, "is_rng", False):
return False
......
......@@ -1841,7 +1841,7 @@ def verify_grad(
# random_projection should not have elements too small,
# otherwise too much precision is lost in numerical gradient
def random_projection():
plain = rng.rand(*o_fn_out.shape) + 0.5
plain = rng.random(o_fn_out.shape) + 0.5
if cast_to_output_type and o_output.dtype == "float32":
return np.array(plain, o_output.dtype)
return plain
......
......@@ -736,7 +736,7 @@ class MRG_RandomStream:
def set_rstate(self, seed):
# TODO : need description for method, parameter
if isinstance(seed, int):
if isinstance(seed, (int, np.int32, np.int64)):
if seed == 0:
raise ValueError("seed should not be 0", seed)
elif seed >= M2:
......
......@@ -72,7 +72,9 @@ def check_equal_numpy(x, y):
"""
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and np.all(abs(x - y) < 1e-10)
elif isinstance(x, np.random.RandomState) and isinstance(y, np.random.RandomState):
elif isinstance(x, (np.random.Generator, np.random.RandomState)) and isinstance(
y, (np.random.Generator, np.random.RandomState)
):
return builtins.all(
np.all(a == b) for a, b in zip(x.__getstate__(), y.__getstate__())
)
......
......@@ -6,7 +6,12 @@ import scipy.stats as stats
import aesara
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from aesara.tensor.random.type import RandomGeneratorType, RandomStateType
from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
try:
......@@ -165,7 +170,7 @@ class GumbelRV(RandomVariable):
@classmethod
def rng_fn(
cls,
rng: np.random.RandomState,
rng: Union[np.random.Generator, np.random.RandomState],
loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float],
size: Optional[Union[List[int], int]],
......@@ -590,7 +595,8 @@ class PolyaGammaRV(RandomVariable):
@classmethod
def rng_fn(cls, rng, b, c, size):
pg = PyPolyaGamma(rng.randint(2 ** 16))
rand_method = rng.integers if hasattr(rng, "integers") else rng.randint
pg = PyPolyaGamma(rand_method(2 ** 16))
if not size and b.shape == c.shape == ():
return pg.pgdraw(b, c)
......@@ -627,10 +633,41 @@ class RandIntRV(RandomVariable):
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None), (RandomStateType, RandomStateSharedVariable)
):
raise TypeError("`randint` is only available for `RandomStateType`s")
return super().make_node(rng, *args, **kwargs)
randint = RandIntRV()
class IntegersRV(RandomVariable):
name = "integers"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "int64"
_print_name = ("integers", "\\operatorname{integers}")
def __call__(self, low, high=None, size=None, **kwargs):
if high is None:
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None),
(RandomGeneratorType, RandomGeneratorSharedVariable),
):
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
return super().make_node(rng, *args, **kwargs)
integers = IntegersRV()
class ChoiceRV(RandomVariable):
name = "choice"
ndim_supp = 0
......@@ -698,6 +735,7 @@ permutation = PermutationRV()
__all__ = [
"permutation",
"choice",
"integers",
"randint",
"categorical",
"multinomial",
......
......@@ -18,7 +18,7 @@ from aesara.tensor.basic import (
)
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.type import TensorType, all_dtypes
......@@ -158,7 +158,7 @@ class RandomVariable(Op):
def rng_fn(self, rng, *args, **kwargs):
"""Sample a numeric random variate."""
return getattr(np.random.RandomState, self.name)(rng, *args, **kwargs)
return getattr(rng, self.name)(*args, **kwargs)
def __str__(self):
props_str = ", ".join((f"{getattr(self, prop)}" for prop in self.__props__[1:]))
......@@ -336,8 +336,8 @@ class RandomVariable(Op):
Parameters
----------
rng: RandomStateType
Existing Aesara `RandomState` object to be used. Creates a
rng: RandomGeneratorType or RandomStateType
Existing Aesara `Generator` or `RandomState` object to be used. Creates a
new one, if `None`.
size: int or Sequence
Numpy-like size of the output (i.e. replications).
......@@ -363,9 +363,11 @@ class RandomVariable(Op):
)
if rng is None:
rng = aesara.shared(np.random.RandomState())
elif not isinstance(rng.type, RandomStateType):
raise TypeError("The type of rng should be an instance of RandomStateType")
rng = aesara.shared(np.random.default_rng())
elif not isinstance(rng.type, RandomType):
raise TypeError(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
bcast = self.compute_bcast(dist_params, size)
dtype = self.dtype or dtype
......
......@@ -6,21 +6,22 @@ import aesara
from aesara.graph.type import Type
class RandomStateType(Type):
"""A Type wrapper for `numpy.random.RandomState`.
gen_states_keys = {
"MT19937": (["state"], ["key", "pos"]),
"PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]),
"Philox": (
["state", "buffer", "buffer_pos", "has_uint32", "uinteger"],
["counter", "key"],
),
"SFC64": (["state", "has_uint32", "uinteger"], ["state"]),
}
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the `==` operator. This `Type` exists to provide an equals function
that is used by `DebugMode`.
Also works with a `dict` derived from RandomState.get_state() unless
the `strict` argument is explicitly set to `True`.
# We map bit generators to an integer index so that we can avoid using strings
numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
"""
def __repr__(self):
return "RandomStateType"
class RandomType(Type):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
@classmethod
def filter(cls, data, strict=False, allow_downcast=None):
......@@ -29,6 +30,31 @@ class RandomStateType(Type):
else:
raise TypeError()
@staticmethod
def get_shape_info(obj):
return obj.get_value(borrow=True)
@staticmethod
def may_share_memory(a, b):
return a._bit_generator is b._bit_generator
class RandomStateType(RandomType):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def __repr__(self):
return "RandomStateType"
@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.RandomState):
......@@ -73,18 +99,10 @@ class RandomStateType(Type):
return _eq(sa, sb)
@staticmethod
def get_shape_info(obj):
return obj.get_value(borrow=True)
@staticmethod
def get_size(shape_info):
return sys.getsizeof(shape_info.get_state(legacy=False))
@staticmethod
def may_share_memory(a, b):
return a._bit_generator is b._bit_generator
# Register `RandomStateType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code(
......@@ -98,3 +116,89 @@ aesara.compile.register_view_op_c_code(
)
random_state_type = RandomStateType()
class RandomGeneratorType(RandomType):
r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that
`Generator` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`Generator.__get_state__`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def __repr__(self):
return "RandomGeneratorType"
@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.Generator):
return True
if not strict and isinstance(a, dict):
if "bit_generator" not in a:
return False
else:
bit_gen_key = a["bit_generator"]
if hasattr(bit_gen_key, "_value"):
bit_gen_key = int(bit_gen_key._value)
bit_gen_key = numpy_bit_gens[bit_gen_key]
gen_keys, state_keys = gen_states_keys[bit_gen_key]
for key in gen_keys:
if key not in a:
return False
for key in state_keys:
if key not in a["state"]:
return False
return True
return False
@staticmethod
def values_eq(a, b):
sa = a if isinstance(a, dict) else a.__getstate__()
sb = b if isinstance(b, dict) else b.__getstate__()
def _eq(sa, sb):
for key in sa:
if isinstance(sa[key], dict):
if not _eq(sa[key], sb[key]):
return False
elif isinstance(sa[key], np.ndarray):
if not np.array_equal(sa[key], sb[key]):
return False
else:
if sa[key] != sb[key]:
return False
return True
return _eq(sa, sb)
@staticmethod
def get_size(shape_info):
state = shape_info.__getstate__()
return sys.getsizeof(state)
# Register `RandomGeneratorType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code(
RandomGeneratorType,
"""
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
1,
)
random_generator_type = RandomGeneratorType()
......@@ -117,25 +117,28 @@ def normalize_size_param(size):
class RandomStream:
"""Module component with similar interface to `numpy.random.RandomState`.
"""Module component with similar interface to `numpy.random.Generator`.
Attributes
----------
seed: None or int
A default seed to initialize the RandomState instances after build.
A default seed to initialize the `Generator` instances after build.
state_updates: list
A list of pairs of the form `(input_r, output_r)`. This will be
A list of pairs of the form ``(input_r, output_r)``. This will be
over-ridden by the module instance to contain stream generators.
default_instance_seed: int
Instance variable should take None or integer value. Used to seed the
random number generator that provides seeds for member streams.
gen_seedgen: numpy.random.RandomState
`RandomState` instance that `RandomStream.gen` uses to seed new
gen_seedgen: numpy.random.Generator
`Generator` instance that `RandomStream.gen` uses to seed new
streams.
rng_ctor: type
Constructor used to create the underlying RNG objects. The default
is `np.random.default_rng`.
"""
def __init__(self, seed=None, namespace=None):
def __init__(self, seed=None, namespace=None, rng_ctor=np.random.default_rng):
if namespace is None:
from aesara.tensor.random import basic # pylint: disable=import-self
......@@ -145,7 +148,8 @@ class RandomStream:
self.default_instance_seed = seed
self.state_updates = []
self.gen_seedgen = np.random.RandomState(seed)
self.gen_seedgen = np.random.default_rng(seed)
self.rng_ctor = rng_ctor
def __getattr__(self, obj):
......@@ -191,11 +195,11 @@ class RandomStream:
if seed is None:
seed = self.default_instance_seed
self.gen_seedgen.seed(seed)
self.gen_seedgen = np.random.default_rng(seed)
for old_r, new_r in self.state_updates:
old_r_seed = self.gen_seedgen.randint(2 ** 30)
old_r.set_value(np.random.RandomState(int(old_r_seed)), borrow=True)
old_r_seed = self.gen_seedgen.integers(2 ** 30)
old_r.set_value(self.rng_ctor(int(old_r_seed)), borrow=True)
def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container.
......@@ -213,18 +217,18 @@ class RandomStream:
-------
TensorVariable
The symbolic random draw part of op()'s return value.
This function stores the updated `RandomStateType` variable
This function stores the updated `RandomGeneratorType` variable
for use at `build` time.
"""
if "rng" in kwargs:
raise TypeError(
"The rng option cannot be used with a variate in a RandomStream"
raise ValueError(
"The `rng` option cannot be used with a variate in a `RandomStream`"
)
# Generate a new random state
seed = int(self.gen_seedgen.randint(2 ** 30))
random_state_variable = shared(np.random.RandomState(seed))
seed = int(self.gen_seedgen.integers(2 ** 30))
random_state_variable = shared(self.rng_ctor(seed))
# Distinguish it from other shared variables (why?)
random_state_variable.tag.is_rng = True
......
......@@ -3,7 +3,7 @@ import copy
import numpy as np
from aesara.compile.sharedvalue import SharedVariable, shared_constructor
from aesara.tensor.random.type import random_state_type
from aesara.tensor.random.type import random_generator_type, random_state_type
class RandomStateSharedVariable(SharedVariable):
......@@ -11,20 +11,30 @@ class RandomStateSharedVariable(SharedVariable):
return "RandomStateSharedVariable({})".format(repr(self.container))
class RandomGeneratorSharedVariable(SharedVariable):
def __str__(self):
return "RandomGeneratorSharedVariable({})".format(repr(self.container))
@shared_constructor
def randomstate_constructor(
def randomgen_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False
):
"""
SharedVariable Constructor for RandomState.
r"""`SharedVariable` Constructor for NumPy's `Generator` and/or `RandomState`."""
if isinstance(value, np.random.RandomState):
rng_sv_type = RandomStateSharedVariable
rng_type = random_state_type
elif isinstance(value, np.random.Generator):
rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type
else:
raise TypeError()
"""
if not isinstance(value, np.random.RandomState):
raise TypeError
if not borrow:
value = copy.deepcopy(value)
return RandomStateSharedVariable(
type=random_state_type,
return rng_sv_type(
type=rng_type,
value=value,
name=name,
strict=strict,
......
......@@ -57,7 +57,7 @@ if __name__ == "__main__":
license=LICENSE,
platforms=PLATFORMS,
packages=find_packages(exclude=["tests", "tests.*"]),
install_requires=["numpy>=1.9.1", "scipy>=0.14", "filelock"],
install_requires=["numpy>=1.17.0", "scipy>=0.14", "filelock"],
package_data={
"": [
"*.txt",
......
......@@ -29,6 +29,7 @@ from aesara.tensor.random.basic import (
halfcauchy,
halfnormal,
hypergeometric,
integers,
invgamma,
laplace,
logistic,
......@@ -58,7 +59,7 @@ def set_aesara_flags():
yield
def rv_numpy_tester(rv, *params, **kwargs):
def rv_numpy_tester(rv, *params, rng=None, **kwargs):
"""Test for correspondence between `RandomVariable` and NumPy shape and
broadcast dimensions.
"""
......@@ -70,9 +71,9 @@ def rv_numpy_tester(rv, *params, **kwargs):
if name is None:
name = rv.__name__
test_fn = getattr(np.random, name)
test_fn = getattr(rng or np.random, name)
aesara_res = rv(*params, **kwargs)
aesara_res = rv(*params, rng=shared(rng) if rng else None, **kwargs)
param_vals = [get_test_value(p) if isinstance(p, Variable) else p for p in params]
kwargs_vals = {
......@@ -738,17 +739,47 @@ def test_polyagamma_samples():
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
def test_random_integer_samples():
def test_randint_samples():
rv_numpy_tester(randint, 10, None)
rv_numpy_tester(randint, 0, 1)
rv_numpy_tester(randint, 0, 1, size=[3])
rv_numpy_tester(randint, [0, 1, 2], 5)
rv_numpy_tester(randint, [0, 1, 2], 5, size=[3, 3])
rv_numpy_tester(randint, [0], [5], size=[1])
rv_numpy_tester(randint, aet.as_tensor_variable([-1]), [1], size=[1])
with raises(TypeError):
randint(10, rng=shared(np.random.default_rng()))
rng = np.random.RandomState(2313)
rv_numpy_tester(randint, 10, None, rng=rng)
rv_numpy_tester(randint, 0, 1, rng=rng)
rv_numpy_tester(randint, 0, 1, size=[3], rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(randint, [0], [5], size=[1], rng=rng)
rv_numpy_tester(randint, aet.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester(
randint, aet.as_tensor_variable([-1]), [1], size=aet.as_tensor_variable([1])
randint,
aet.as_tensor_variable([-1]),
[1],
size=aet.as_tensor_variable([1]),
rng=rng,
)
def test_integers_samples():
with raises(TypeError):
integers(10, rng=shared(np.random.RandomState()))
rng = np.random.default_rng(2313)
rv_numpy_tester(integers, 10, None, rng=rng)
rv_numpy_tester(integers, 0, 1, rng=rng)
rv_numpy_tester(integers, 0, 1, size=[3], rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(integers, [0], [5], size=[1], rng=rng)
rv_numpy_tester(integers, aet.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester(
integers,
aet.as_tensor_variable([-1]),
[1],
size=aet.as_tensor_variable([1]),
rng=rng,
)
......
......@@ -6,7 +6,12 @@ import pytest
from aesara import shared
from aesara.compile.ops import ViewOp
from aesara.tensor.random.type import RandomStateType, random_state_type
from aesara.tensor.random.type import (
RandomGeneratorType,
RandomStateType,
random_generator_type,
random_state_type,
)
# @pytest.mark.skipif(
......@@ -24,8 +29,8 @@ def test_view_op_c_code():
# rng_view,
# mode=Mode(optimizer=None, linker=CLinker()),
# )
assert ViewOp.c_code_and_version[RandomStateType]
assert ViewOp.c_code_and_version[RandomGeneratorType]
class TestRandomStateType:
......@@ -106,9 +111,112 @@ class TestRandomStateType:
assert size == sys.getsizeof(rng.get_state(legacy=False))
def test_may_share_memory(self):
rng_a = np.random.RandomState(12)
bg = np.random.PCG64()
rng_b = np.random.RandomState(bg)
bg1 = np.random.MT19937()
bg2 = np.random.MT19937()
rng_a = np.random.RandomState(bg1)
rng_b = np.random.RandomState(bg2)
rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True)
shape_info_a = random_state_type.get_shape_info(rng_var_a)
shape_info_b = random_state_type.get_shape_info(rng_var_b)
assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False
rng_c = np.random.RandomState(bg2)
rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c)
assert random_state_type.may_share_memory(shape_info_b, shape_info_c) is True
class TestRandomGeneratorType:
def test_pickle(self):
rng_r = random_generator_type()
rng_pkl = pickle.dumps(rng_r)
rng_unpkl = pickle.loads(rng_pkl)
assert isinstance(rng_unpkl, type(rng_r))
assert isinstance(rng_unpkl.type, type(rng_r.type))
def test_repr(self):
assert repr(random_generator_type) == "RandomGeneratorType"
def test_filter(self):
rng_type = random_generator_type
rng = np.random.default_rng()
assert rng_type.filter(rng) is rng
with pytest.raises(TypeError):
rng_type.filter(1)
rng = rng.__getstate__()
assert rng_type.is_valid_value(rng, strict=False)
rng["state"] = {}
assert rng_type.is_valid_value(rng, strict=False) is False
rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
def test_values_eq(self):
rng_type = random_generator_type
bg_1 = np.random.PCG64()
bg_2 = np.random.Philox()
bg_3 = np.random.MT19937()
bg_4 = np.random.SFC64()
bitgen_a = np.random.Generator(bg_1)
bitgen_b = np.random.Generator(bg_1)
assert rng_type.values_eq(bitgen_a, bitgen_b)
bitgen_c = np.random.Generator(bg_2)
bitgen_d = np.random.Generator(bg_2)
assert rng_type.values_eq(bitgen_c, bitgen_d)
bitgen_e = np.random.Generator(bg_3)
bitgen_f = np.random.Generator(bg_3)
assert rng_type.values_eq(bitgen_e, bitgen_f)
bitgen_g = np.random.Generator(bg_4)
bitgen_h = np.random.Generator(bg_4)
assert rng_type.values_eq(bitgen_g, bitgen_h)
assert rng_type.is_valid_value(bitgen_a, strict=True)
assert rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_c, strict=True)
assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_e, strict=True)
assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_g, strict=True)
assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False)
def test_get_shape_info(self):
rng = np.random.default_rng(12)
rng_a = shared(rng)
assert isinstance(
random_generator_type.get_shape_info(rng_a), np.random.Generator
)
def test_get_size(self):
rng = np.random.Generator(np.random.PCG64(12))
rng_a = shared(rng)
shape_info = random_generator_type.get_shape_info(rng_a)
size = random_generator_type.get_size(shape_info)
assert size == sys.getsizeof(rng.__getstate__())
def test_may_share_memory(self):
bg_a = np.random.PCG64()
bg_b = np.random.PCG64()
rng_a = np.random.Generator(bg_a)
rng_b = np.random.Generator(bg_b)
rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True)
......@@ -117,7 +225,7 @@ class TestRandomStateType:
assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False
rng_c = np.random.RandomState(bg)
rng_c = np.random.Generator(bg_b)
rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c)
......
......@@ -82,9 +82,6 @@ def test_broadcast_params():
class TestSharedRandomStream:
def setup_method(self):
utt.seed_rng()
def test_tutorial(self):
srng = RandomStream(seed=234)
rv_u = srng.uniform(0, 1, size=(2, 2))
......@@ -100,19 +97,20 @@ class TestSharedRandomStream:
assert np.all(f() != f())
assert np.all(g() == g())
assert np.all(abs(nearly_zeros()) < 1e-5)
assert isinstance(rv_u.rng.get_value(borrow=True), np.random.RandomState)
assert isinstance(rv_u.rng.get_value(borrow=True), np.random.Generator)
def test_basics(self):
random = RandomStream(seed=utt.fetch_seed())
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_basics(self, rng_ctor):
random = RandomStream(seed=utt.fetch_seed(), rng_ctor=rng_ctor)
with pytest.raises(TypeError):
random.uniform(0, 1, size=(2, 2), rng=np.random.RandomState(23))
with pytest.raises(ValueError):
random.uniform(0, 1, size=(2, 2), rng=np.random.default_rng(23))
with pytest.raises(AttributeError):
random.blah
with pytest.raises(AttributeError):
np_random = RandomStream(namespace=np)
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
np_random.ndarray
fn = function([], random.uniform(0, 1, size=(2, 2)), updates=random.updates())
......@@ -120,8 +118,8 @@ class TestSharedRandomStream:
fn_val0 = fn()
fn_val1 = fn()
rng_seed = np.random.RandomState(utt.fetch_seed()).randint(2 ** 30)
rng = np.random.RandomState(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2 ** 30)
rng = rng_ctor(int(rng_seed)) # int() is for 32bit
numpy_val0 = rng.uniform(0, 1, size=(2, 2))
numpy_val1 = rng.uniform(0, 1, size=(2, 2))
......@@ -129,33 +127,31 @@ class TestSharedRandomStream:
assert np.allclose(fn_val0, numpy_val0)
assert np.allclose(fn_val1, numpy_val1)
def test_seed(self):
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_seed(self, rng_ctor):
init_seed = 234
random = RandomStream(init_seed)
random = RandomStream(init_seed, rng_ctor=rng_ctor)
ref_state = np.random.RandomState(init_seed).get_state()
random_state = random.gen_seedgen.get_state()
ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random.default_instance_seed == init_seed
assert np.array_equal(random_state[1], ref_state[1])
assert random_state[0] == ref_state[0]
assert random_state[2:] == ref_state[2:]
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
new_seed = 43298
random.seed(new_seed)
ref_state = np.random.RandomState(new_seed).get_state()
random_state = random.gen_seedgen.get_state()
assert np.array_equal(random_state[1], ref_state[1])
assert random_state[0] == ref_state[0]
assert random_state[2:] == ref_state[2:]
ref_state = np.random.default_rng(new_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
random.seed()
ref_state = np.random.RandomState(init_seed).get_state()
random_state = random.gen_seedgen.get_state()
ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random.default_instance_seed == init_seed
assert np.array_equal(random_state[1], ref_state[1])
assert random_state[0] == ref_state[0]
assert random_state[2:] == ref_state[2:]
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
# Reset the seed
random.seed(new_seed)
......@@ -166,33 +162,43 @@ class TestSharedRandomStream:
# Now, change the seed when there are state updates
random.seed(new_seed)
rng = np.random.RandomState(new_seed)
update_seed = rng.randint(2 ** 30)
ref_state = np.random.RandomState(update_seed).get_state()
random_state = random.state_updates[0][0].get_value(borrow=True).get_state()
assert np.array_equal(random_state[1], ref_state[1])
assert random_state[0] == ref_state[0]
assert random_state[2:] == ref_state[2:]
def test_uniform(self):
update_seed = np.random.default_rng(new_seed).integers(2 ** 30)
ref_rng = rng_ctor(update_seed)
state_rng = random.state_updates[0][0].get_value(borrow=True)
if hasattr(state_rng, "get_state"):
ref_state = ref_rng.get_state()
random_state = state_rng.get_state()
assert np.array_equal(random_state[1], ref_state[1])
assert random_state[0] == ref_state[0]
assert random_state[2:] == ref_state[2:]
else:
ref_state = ref_rng.__getstate__()
random_state = state_rng.__getstate__()
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_uniform(self, rng_ctor):
# Test that RandomStream.uniform generates the same results as numpy
# Check over two calls to see if the random state is correctly updated.
random = RandomStream(utt.fetch_seed())
random = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
fn = function([], random.uniform(-1, 1, size=(2, 2)))
fn_val0 = fn()
fn_val1 = fn()
rng_seed = np.random.RandomState(utt.fetch_seed()).randint(2 ** 30)
rng = np.random.RandomState(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2 ** 30)
rng = rng_ctor(int(rng_seed)) # int() is for 32bit
numpy_val0 = rng.uniform(-1, 1, size=(2, 2))
numpy_val1 = rng.uniform(-1, 1, size=(2, 2))
assert np.allclose(fn_val0, numpy_val0)
assert np.allclose(fn_val1, numpy_val1)
def test_default_updates(self):
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_default_updates(self, rng_ctor):
# Basic case: default_updates
random_a = RandomStream(utt.fetch_seed())
random_a = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_a = random_a.uniform(0, 1, size=(2, 2))
fn_a = function([], out_a)
fn_a_val0 = fn_a()
......@@ -203,7 +209,7 @@ class TestSharedRandomStream:
assert np.all(abs(nearly_zeros()) < 1e-5)
# Explicit updates #1
random_b = RandomStream(utt.fetch_seed())
random_b = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_b = random_b.uniform(0, 1, size=(2, 2))
fn_b = function([], out_b, updates=random_b.updates())
fn_b_val0 = fn_b()
......@@ -212,7 +218,7 @@ class TestSharedRandomStream:
assert np.all(fn_b_val1 == fn_a_val1)
# Explicit updates #2
random_c = RandomStream(utt.fetch_seed())
random_c = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_c = random_c.uniform(0, 1, size=(2, 2))
fn_c = function([], out_c, updates=[out_c.update])
fn_c_val0 = fn_c()
......@@ -221,7 +227,7 @@ class TestSharedRandomStream:
assert np.all(fn_c_val1 == fn_a_val1)
# No updates at all
random_d = RandomStream(utt.fetch_seed())
random_d = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_d = random_d.uniform(0, 1, size=(2, 2))
fn_d = function([], out_d, no_default_updates=True)
fn_d_val0 = fn_d()
......@@ -230,7 +236,7 @@ class TestSharedRandomStream:
assert np.all(fn_d_val1 == fn_d_val0)
# No updates for out
random_e = RandomStream(utt.fetch_seed())
random_e = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_e = random_e.uniform(0, 1, size=(2, 2))
fn_e = function([], out_e, no_default_updates=[out_e.rng])
fn_e_val0 = fn_e()
......@@ -238,24 +244,26 @@ class TestSharedRandomStream:
assert np.all(fn_e_val0 == fn_a_val0)
assert np.all(fn_e_val1 == fn_e_val0)
def test_multiple_rng_aliasing(self):
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_multiple_rng_aliasing(self, rng_ctor):
# Test that when we have multiple random number generators, we do not alias
# the state_updates member. `state_updates` can be useful when attempting to
# copy the (random) state between two similar aesara graphs. The test is
# meant to detect a previous bug where state_updates was initialized as a
# class-attribute, instead of the __init__ function.
rng1 = RandomStream(1234)
rng2 = RandomStream(2392)
rng1 = RandomStream(1234, rng_ctor=rng_ctor)
rng2 = RandomStream(2392, rng_ctor=rng_ctor)
assert rng1.state_updates is not rng2.state_updates
assert rng1.gen_seedgen is not rng2.gen_seedgen
def test_random_state_transfer(self):
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_random_state_transfer(self, rng_ctor):
# Test that random state can be transferred from one aesara graph to another.
class Graph:
def __init__(self, seed=123):
self.rng = RandomStream(seed)
self.rng = RandomStream(seed, rng_ctor=rng_ctor)
self.y = self.rng.uniform(0, 1, size=(1,))
g1 = Graph(seed=123)
......
import numpy as np
import pytest
from aesara import shared
def test_RandomStateSharedVariable():
rng = np.random.RandomState(123)
@pytest.mark.parametrize(
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_GeneratorSharedVariable(rng):
s_rng_default = shared(rng)
s_rng_True = shared(rng, borrow=True)
s_rng_False = shared(rng, borrow=False)
......@@ -17,15 +20,22 @@ def test_RandomStateSharedVariable():
assert s_rng_True.container.storage[0] is rng
# ensure that all the random number generators are in the same state
v = rng.randn()
v0 = s_rng_default.container.storage[0].randn()
v1 = s_rng_False.container.storage[0].randn()
assert v == v0 == v1
if hasattr(rng, "randn"):
v = rng.randn()
v0 = s_rng_default.container.storage[0].randn()
v1 = s_rng_False.container.storage[0].randn()
else:
v = rng.standard_normal()
v0 = s_rng_default.container.storage[0].standard_normal()
v1 = s_rng_False.container.storage[0].standard_normal()
assert v == v0 == v1
def test_get_value_borrow():
rng = np.random.RandomState(123)
@pytest.mark.parametrize(
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_get_value_borrow(rng):
s_rng = shared(rng)
r_ = s_rng.container.storage[0]
......@@ -39,11 +49,16 @@ def test_get_value_borrow():
assert r_ is r_T
# either way, the rngs should all be in the same state
assert r_.rand() == r_F.rand()
if hasattr(rng, "rand"):
assert r_.rand() == r_F.rand()
else:
assert r_.standard_normal() == r_F.standard_normal()
def test_get_value_internal_type():
rng = np.random.RandomState(123)
@pytest.mark.parametrize(
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_get_value_internal_type(rng):
s_rng = shared(rng)
# there is no special behaviour required of return_internal_type
......@@ -60,23 +75,28 @@ def test_get_value_internal_type():
assert r_ is r_T
# either way, the rngs should all be in the same state
assert r_.rand() == r_F.rand()
if hasattr(rng, "rand"):
assert r_.rand() == r_F.rand()
else:
assert r_.standard_normal() == r_F.standard_normal()
def test_set_value_borrow():
rng = np.random.RandomState(123)
s_rng = shared(rng)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_set_value_borrow(rng_ctor):
s_rng = shared(rng_ctor(123))
new_rng = np.random.RandomState(234234)
new_rng = rng_ctor(234234)
# Test the borrow contract is respected:
# assigning with borrow=False makes a copy
s_rng.set_value(new_rng, borrow=False)
assert new_rng is not s_rng.container.storage[0]
assert new_rng.randn() == s_rng.container.storage[0].randn()
if hasattr(new_rng, "randn"):
assert new_rng.randn() == s_rng.container.storage[0].randn()
else:
assert new_rng.standard_normal() == s_rng.container.storage[0].standard_normal()
# Test that the current implementation is actually borrowing when it can.
rr = np.random.RandomState(33)
rr = rng_ctor(33)
s_rng.set_value(rr, borrow=True)
assert rr is s_rng.container.storage[0]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论