提交 e0eea331 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Add NumPy Generator support for RandomVariables

上级 adf83fa5
...@@ -105,7 +105,7 @@ def detect_nan(fgraph, i, node, fn): ...@@ -105,7 +105,7 @@ def detect_nan(fgraph, i, node, fn):
for output in fn.outputs: for output in fn.outputs:
if ( if (
not isinstance(output[0], np.random.RandomState) not isinstance(output[0], (np.random.RandomState, np.random.Generator))
and np.isnan(output[0]).any() and np.isnan(output[0]).any()
): ):
print("*** NaN detected ***") print("*** NaN detected ***")
......
...@@ -44,7 +44,7 @@ def _is_numeric_value(arr, var): ...@@ -44,7 +44,7 @@ def _is_numeric_value(arr, var):
""" """
if isinstance(arr, aesara.graph.type._cdata_type): if isinstance(arr, aesara.graph.type._cdata_type):
return False return False
elif isinstance(arr, np.random.mtrand.RandomState): elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)):
return False return False
elif var and getattr(var.tag, "is_rng", False): elif var and getattr(var.tag, "is_rng", False):
return False return False
......
...@@ -1841,7 +1841,7 @@ def verify_grad( ...@@ -1841,7 +1841,7 @@ def verify_grad(
# random_projection should not have elements too small, # random_projection should not have elements too small,
# otherwise too much precision is lost in numerical gradient # otherwise too much precision is lost in numerical gradient
def random_projection(): def random_projection():
plain = rng.rand(*o_fn_out.shape) + 0.5 plain = rng.random(o_fn_out.shape) + 0.5
if cast_to_output_type and o_output.dtype == "float32": if cast_to_output_type and o_output.dtype == "float32":
return np.array(plain, o_output.dtype) return np.array(plain, o_output.dtype)
return plain return plain
......
...@@ -736,7 +736,7 @@ class MRG_RandomStream: ...@@ -736,7 +736,7 @@ class MRG_RandomStream:
def set_rstate(self, seed): def set_rstate(self, seed):
# TODO : need description for method, parameter # TODO : need description for method, parameter
if isinstance(seed, int): if isinstance(seed, (int, np.int32, np.int64)):
if seed == 0: if seed == 0:
raise ValueError("seed should not be 0", seed) raise ValueError("seed should not be 0", seed)
elif seed >= M2: elif seed >= M2:
......
...@@ -72,7 +72,9 @@ def check_equal_numpy(x, y): ...@@ -72,7 +72,9 @@ def check_equal_numpy(x, y):
""" """
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray): if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and np.all(abs(x - y) < 1e-10) return x.dtype == y.dtype and x.shape == y.shape and np.all(abs(x - y) < 1e-10)
elif isinstance(x, np.random.RandomState) and isinstance(y, np.random.RandomState): elif isinstance(x, (np.random.Generator, np.random.RandomState)) and isinstance(
y, (np.random.Generator, np.random.RandomState)
):
return builtins.all( return builtins.all(
np.all(a == b) for a, b in zip(x.__getstate__(), y.__getstate__()) np.all(a == b) for a, b in zip(x.__getstate__(), y.__getstate__())
) )
......
...@@ -6,7 +6,12 @@ import scipy.stats as stats ...@@ -6,7 +6,12 @@ import scipy.stats as stats
import aesara import aesara
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.random.op import RandomVariable, default_shape_from_params from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from aesara.tensor.random.type import RandomGeneratorType, RandomStateType
from aesara.tensor.random.utils import broadcast_params from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
try: try:
...@@ -165,7 +170,7 @@ class GumbelRV(RandomVariable): ...@@ -165,7 +170,7 @@ class GumbelRV(RandomVariable):
@classmethod @classmethod
def rng_fn( def rng_fn(
cls, cls,
rng: np.random.RandomState, rng: Union[np.random.Generator, np.random.RandomState],
loc: Union[np.ndarray, float], loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float], scale: Union[np.ndarray, float],
size: Optional[Union[List[int], int]], size: Optional[Union[List[int], int]],
...@@ -590,7 +595,8 @@ class PolyaGammaRV(RandomVariable): ...@@ -590,7 +595,8 @@ class PolyaGammaRV(RandomVariable):
@classmethod @classmethod
def rng_fn(cls, rng, b, c, size): def rng_fn(cls, rng, b, c, size):
pg = PyPolyaGamma(rng.randint(2 ** 16)) rand_method = rng.integers if hasattr(rng, "integers") else rng.randint
pg = PyPolyaGamma(rand_method(2 ** 16))
if not size and b.shape == c.shape == (): if not size and b.shape == c.shape == ():
return pg.pgdraw(b, c) return pg.pgdraw(b, c)
...@@ -627,10 +633,41 @@ class RandIntRV(RandomVariable): ...@@ -627,10 +633,41 @@ class RandIntRV(RandomVariable):
low, high = 0, low low, high = 0, low
return super().__call__(low, high, size=size, **kwargs) return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None), (RandomStateType, RandomStateSharedVariable)
):
raise TypeError("`randint` is only available for `RandomStateType`s")
return super().make_node(rng, *args, **kwargs)
randint = RandIntRV() randint = RandIntRV()
class IntegersRV(RandomVariable):
name = "integers"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "int64"
_print_name = ("integers", "\\operatorname{integers}")
def __call__(self, low, high=None, size=None, **kwargs):
if high is None:
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None),
(RandomGeneratorType, RandomGeneratorSharedVariable),
):
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
return super().make_node(rng, *args, **kwargs)
integers = IntegersRV()
class ChoiceRV(RandomVariable): class ChoiceRV(RandomVariable):
name = "choice" name = "choice"
ndim_supp = 0 ndim_supp = 0
...@@ -698,6 +735,7 @@ permutation = PermutationRV() ...@@ -698,6 +735,7 @@ permutation = PermutationRV()
__all__ = [ __all__ = [
"permutation", "permutation",
"choice", "choice",
"integers",
"randint", "randint",
"categorical", "categorical",
"multinomial", "multinomial",
......
...@@ -18,7 +18,7 @@ from aesara.tensor.basic import ( ...@@ -18,7 +18,7 @@ from aesara.tensor.basic import (
) )
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.random.type import RandomStateType from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.type import TensorType, all_dtypes from aesara.tensor.type import TensorType, all_dtypes
...@@ -158,7 +158,7 @@ class RandomVariable(Op): ...@@ -158,7 +158,7 @@ class RandomVariable(Op):
def rng_fn(self, rng, *args, **kwargs): def rng_fn(self, rng, *args, **kwargs):
"""Sample a numeric random variate.""" """Sample a numeric random variate."""
return getattr(np.random.RandomState, self.name)(rng, *args, **kwargs) return getattr(rng, self.name)(*args, **kwargs)
def __str__(self): def __str__(self):
props_str = ", ".join((f"{getattr(self, prop)}" for prop in self.__props__[1:])) props_str = ", ".join((f"{getattr(self, prop)}" for prop in self.__props__[1:]))
...@@ -336,8 +336,8 @@ class RandomVariable(Op): ...@@ -336,8 +336,8 @@ class RandomVariable(Op):
Parameters Parameters
---------- ----------
rng: RandomStateType rng: RandomGeneratorType or RandomStateType
Existing Aesara `RandomState` object to be used. Creates a Existing Aesara `Generator` or `RandomState` object to be used. Creates a
new one, if `None`. new one, if `None`.
size: int or Sequence size: int or Sequence
Numpy-like size of the output (i.e. replications). Numpy-like size of the output (i.e. replications).
...@@ -363,9 +363,11 @@ class RandomVariable(Op): ...@@ -363,9 +363,11 @@ class RandomVariable(Op):
) )
if rng is None: if rng is None:
rng = aesara.shared(np.random.RandomState()) rng = aesara.shared(np.random.default_rng())
elif not isinstance(rng.type, RandomStateType): elif not isinstance(rng.type, RandomType):
raise TypeError("The type of rng should be an instance of RandomStateType") raise TypeError(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
bcast = self.compute_bcast(dist_params, size) bcast = self.compute_bcast(dist_params, size)
dtype = self.dtype or dtype dtype = self.dtype or dtype
......
...@@ -6,21 +6,22 @@ import aesara ...@@ -6,21 +6,22 @@ import aesara
from aesara.graph.type import Type from aesara.graph.type import Type
class RandomStateType(Type): gen_states_keys = {
"""A Type wrapper for `numpy.random.RandomState`. "MT19937": (["state"], ["key", "pos"]),
"PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]),
"Philox": (
["state", "buffer", "buffer_pos", "has_uint32", "uinteger"],
["counter", "key"],
),
"SFC64": (["state", "has_uint32", "uinteger"], ["state"]),
}
The reason this exists (and `Generic` doesn't suffice) is that # We map bit generators to an integer index so that we can avoid using strings
`RandomState` objects that would appear to be equal do not compare equal numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
with the `==` operator. This `Type` exists to provide an equals function
that is used by `DebugMode`.
Also works with a `dict` derived from RandomState.get_state() unless
the `strict` argument is explicitly set to `True`.
"""
def __repr__(self): class RandomType(Type):
return "RandomStateType" r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
@classmethod @classmethod
def filter(cls, data, strict=False, allow_downcast=None): def filter(cls, data, strict=False, allow_downcast=None):
...@@ -29,6 +30,31 @@ class RandomStateType(Type): ...@@ -29,6 +30,31 @@ class RandomStateType(Type):
else: else:
raise TypeError() raise TypeError()
@staticmethod
def get_shape_info(obj):
return obj.get_value(borrow=True)
@staticmethod
def may_share_memory(a, b):
return a._bit_generator is b._bit_generator
class RandomStateType(RandomType):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def __repr__(self):
return "RandomStateType"
@staticmethod @staticmethod
def is_valid_value(a, strict): def is_valid_value(a, strict):
if isinstance(a, np.random.RandomState): if isinstance(a, np.random.RandomState):
...@@ -73,18 +99,10 @@ class RandomStateType(Type): ...@@ -73,18 +99,10 @@ class RandomStateType(Type):
return _eq(sa, sb) return _eq(sa, sb)
@staticmethod
def get_shape_info(obj):
return obj.get_value(borrow=True)
@staticmethod @staticmethod
def get_size(shape_info): def get_size(shape_info):
return sys.getsizeof(shape_info.get_state(legacy=False)) return sys.getsizeof(shape_info.get_state(legacy=False))
@staticmethod
def may_share_memory(a, b):
return a._bit_generator is b._bit_generator
# Register `RandomStateType`'s C code for `ViewOp`. # Register `RandomStateType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code( aesara.compile.register_view_op_c_code(
...@@ -98,3 +116,89 @@ aesara.compile.register_view_op_c_code( ...@@ -98,3 +116,89 @@ aesara.compile.register_view_op_c_code(
) )
random_state_type = RandomStateType() random_state_type = RandomStateType()
class RandomGeneratorType(RandomType):
r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that
`Generator` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`Generator.__get_state__`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def __repr__(self):
return "RandomGeneratorType"
@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.Generator):
return True
if not strict and isinstance(a, dict):
if "bit_generator" not in a:
return False
else:
bit_gen_key = a["bit_generator"]
if hasattr(bit_gen_key, "_value"):
bit_gen_key = int(bit_gen_key._value)
bit_gen_key = numpy_bit_gens[bit_gen_key]
gen_keys, state_keys = gen_states_keys[bit_gen_key]
for key in gen_keys:
if key not in a:
return False
for key in state_keys:
if key not in a["state"]:
return False
return True
return False
@staticmethod
def values_eq(a, b):
sa = a if isinstance(a, dict) else a.__getstate__()
sb = b if isinstance(b, dict) else b.__getstate__()
def _eq(sa, sb):
for key in sa:
if isinstance(sa[key], dict):
if not _eq(sa[key], sb[key]):
return False
elif isinstance(sa[key], np.ndarray):
if not np.array_equal(sa[key], sb[key]):
return False
else:
if sa[key] != sb[key]:
return False
return True
return _eq(sa, sb)
@staticmethod
def get_size(shape_info):
state = shape_info.__getstate__()
return sys.getsizeof(state)
# Register `RandomGeneratorType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code(
RandomGeneratorType,
"""
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
1,
)
random_generator_type = RandomGeneratorType()
...@@ -117,25 +117,28 @@ def normalize_size_param(size): ...@@ -117,25 +117,28 @@ def normalize_size_param(size):
class RandomStream: class RandomStream:
"""Module component with similar interface to `numpy.random.RandomState`. """Module component with similar interface to `numpy.random.Generator`.
Attributes Attributes
---------- ----------
seed: None or int seed: None or int
A default seed to initialize the RandomState instances after build. A default seed to initialize the `Generator` instances after build.
state_updates: list state_updates: list
A list of pairs of the form `(input_r, output_r)`. This will be A list of pairs of the form ``(input_r, output_r)``. This will be
over-ridden by the module instance to contain stream generators. over-ridden by the module instance to contain stream generators.
default_instance_seed: int default_instance_seed: int
Instance variable should take None or integer value. Used to seed the Instance variable should take None or integer value. Used to seed the
random number generator that provides seeds for member streams. random number generator that provides seeds for member streams.
gen_seedgen: numpy.random.RandomState gen_seedgen: numpy.random.Generator
`RandomState` instance that `RandomStream.gen` uses to seed new `Generator` instance that `RandomStream.gen` uses to seed new
streams. streams.
rng_ctor: type
Constructor used to create the underlying RNG objects. The default
is `np.random.default_rng`.
""" """
def __init__(self, seed=None, namespace=None): def __init__(self, seed=None, namespace=None, rng_ctor=np.random.default_rng):
if namespace is None: if namespace is None:
from aesara.tensor.random import basic # pylint: disable=import-self from aesara.tensor.random import basic # pylint: disable=import-self
...@@ -145,7 +148,8 @@ class RandomStream: ...@@ -145,7 +148,8 @@ class RandomStream:
self.default_instance_seed = seed self.default_instance_seed = seed
self.state_updates = [] self.state_updates = []
self.gen_seedgen = np.random.RandomState(seed) self.gen_seedgen = np.random.default_rng(seed)
self.rng_ctor = rng_ctor
def __getattr__(self, obj): def __getattr__(self, obj):
...@@ -191,11 +195,11 @@ class RandomStream: ...@@ -191,11 +195,11 @@ class RandomStream:
if seed is None: if seed is None:
seed = self.default_instance_seed seed = self.default_instance_seed
self.gen_seedgen.seed(seed) self.gen_seedgen = np.random.default_rng(seed)
for old_r, new_r in self.state_updates: for old_r, new_r in self.state_updates:
old_r_seed = self.gen_seedgen.randint(2 ** 30) old_r_seed = self.gen_seedgen.integers(2 ** 30)
old_r.set_value(np.random.RandomState(int(old_r_seed)), borrow=True) old_r.set_value(self.rng_ctor(int(old_r_seed)), borrow=True)
def gen(self, op, *args, **kwargs): def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container. """Create a new random stream in this container.
...@@ -213,18 +217,18 @@ class RandomStream: ...@@ -213,18 +217,18 @@ class RandomStream:
------- -------
TensorVariable TensorVariable
The symbolic random draw part of op()'s return value. The symbolic random draw part of op()'s return value.
This function stores the updated `RandomStateType` variable This function stores the updated `RandomGeneratorType` variable
for use at `build` time. for use at `build` time.
""" """
if "rng" in kwargs: if "rng" in kwargs:
raise TypeError( raise ValueError(
"The rng option cannot be used with a variate in a RandomStream" "The `rng` option cannot be used with a variate in a `RandomStream`"
) )
# Generate a new random state # Generate a new random state
seed = int(self.gen_seedgen.randint(2 ** 30)) seed = int(self.gen_seedgen.integers(2 ** 30))
random_state_variable = shared(np.random.RandomState(seed)) random_state_variable = shared(self.rng_ctor(seed))
# Distinguish it from other shared variables (why?) # Distinguish it from other shared variables (why?)
random_state_variable.tag.is_rng = True random_state_variable.tag.is_rng = True
......
...@@ -3,7 +3,7 @@ import copy ...@@ -3,7 +3,7 @@ import copy
import numpy as np import numpy as np
from aesara.compile.sharedvalue import SharedVariable, shared_constructor from aesara.compile.sharedvalue import SharedVariable, shared_constructor
from aesara.tensor.random.type import random_state_type from aesara.tensor.random.type import random_generator_type, random_state_type
class RandomStateSharedVariable(SharedVariable): class RandomStateSharedVariable(SharedVariable):
...@@ -11,20 +11,30 @@ class RandomStateSharedVariable(SharedVariable): ...@@ -11,20 +11,30 @@ class RandomStateSharedVariable(SharedVariable):
return "RandomStateSharedVariable({})".format(repr(self.container)) return "RandomStateSharedVariable({})".format(repr(self.container))
class RandomGeneratorSharedVariable(SharedVariable):
def __str__(self):
return "RandomGeneratorSharedVariable({})".format(repr(self.container))
@shared_constructor @shared_constructor
def randomstate_constructor( def randomgen_constructor(
value, name=None, strict=False, allow_downcast=None, borrow=False value, name=None, strict=False, allow_downcast=None, borrow=False
): ):
""" r"""`SharedVariable` Constructor for NumPy's `Generator` and/or `RandomState`."""
SharedVariable Constructor for RandomState. if isinstance(value, np.random.RandomState):
rng_sv_type = RandomStateSharedVariable
rng_type = random_state_type
elif isinstance(value, np.random.Generator):
rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type
else:
raise TypeError()
"""
if not isinstance(value, np.random.RandomState):
raise TypeError
if not borrow: if not borrow:
value = copy.deepcopy(value) value = copy.deepcopy(value)
return RandomStateSharedVariable(
type=random_state_type, return rng_sv_type(
type=rng_type,
value=value, value=value,
name=name, name=name,
strict=strict, strict=strict,
......
...@@ -57,7 +57,7 @@ if __name__ == "__main__": ...@@ -57,7 +57,7 @@ if __name__ == "__main__":
license=LICENSE, license=LICENSE,
platforms=PLATFORMS, platforms=PLATFORMS,
packages=find_packages(exclude=["tests", "tests.*"]), packages=find_packages(exclude=["tests", "tests.*"]),
install_requires=["numpy>=1.9.1", "scipy>=0.14", "filelock"], install_requires=["numpy>=1.17.0", "scipy>=0.14", "filelock"],
package_data={ package_data={
"": [ "": [
"*.txt", "*.txt",
......
...@@ -29,6 +29,7 @@ from aesara.tensor.random.basic import ( ...@@ -29,6 +29,7 @@ from aesara.tensor.random.basic import (
halfcauchy, halfcauchy,
halfnormal, halfnormal,
hypergeometric, hypergeometric,
integers,
invgamma, invgamma,
laplace, laplace,
logistic, logistic,
...@@ -58,7 +59,7 @@ def set_aesara_flags(): ...@@ -58,7 +59,7 @@ def set_aesara_flags():
yield yield
def rv_numpy_tester(rv, *params, **kwargs): def rv_numpy_tester(rv, *params, rng=None, **kwargs):
"""Test for correspondence between `RandomVariable` and NumPy shape and """Test for correspondence between `RandomVariable` and NumPy shape and
broadcast dimensions. broadcast dimensions.
""" """
...@@ -70,9 +71,9 @@ def rv_numpy_tester(rv, *params, **kwargs): ...@@ -70,9 +71,9 @@ def rv_numpy_tester(rv, *params, **kwargs):
if name is None: if name is None:
name = rv.__name__ name = rv.__name__
test_fn = getattr(np.random, name) test_fn = getattr(rng or np.random, name)
aesara_res = rv(*params, **kwargs) aesara_res = rv(*params, rng=shared(rng) if rng else None, **kwargs)
param_vals = [get_test_value(p) if isinstance(p, Variable) else p for p in params] param_vals = [get_test_value(p) if isinstance(p, Variable) else p for p in params]
kwargs_vals = { kwargs_vals = {
...@@ -738,17 +739,47 @@ def test_polyagamma_samples(): ...@@ -738,17 +739,47 @@ def test_polyagamma_samples():
assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0) assert np.all(np.abs(np.diff(bcast_smpl.flat)) > 0.0)
def test_random_integer_samples(): def test_randint_samples():
rv_numpy_tester(randint, 10, None) with raises(TypeError):
rv_numpy_tester(randint, 0, 1) randint(10, rng=shared(np.random.default_rng()))
rv_numpy_tester(randint, 0, 1, size=[3])
rv_numpy_tester(randint, [0, 1, 2], 5) rng = np.random.RandomState(2313)
rv_numpy_tester(randint, [0, 1, 2], 5, size=[3, 3]) rv_numpy_tester(randint, 10, None, rng=rng)
rv_numpy_tester(randint, [0], [5], size=[1]) rv_numpy_tester(randint, 0, 1, rng=rng)
rv_numpy_tester(randint, aet.as_tensor_variable([-1]), [1], size=[1]) rv_numpy_tester(randint, 0, 1, size=[3], rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(randint, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(randint, [0], [5], size=[1], rng=rng)
rv_numpy_tester(randint, aet.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester( rv_numpy_tester(
randint, aet.as_tensor_variable([-1]), [1], size=aet.as_tensor_variable([1]) randint,
aet.as_tensor_variable([-1]),
[1],
size=aet.as_tensor_variable([1]),
rng=rng,
)
def test_integers_samples():
with raises(TypeError):
integers(10, rng=shared(np.random.RandomState()))
rng = np.random.default_rng(2313)
rv_numpy_tester(integers, 10, None, rng=rng)
rv_numpy_tester(integers, 0, 1, rng=rng)
rv_numpy_tester(integers, 0, 1, size=[3], rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, rng=rng)
rv_numpy_tester(integers, [0, 1, 2], 5, size=[3, 3], rng=rng)
rv_numpy_tester(integers, [0], [5], size=[1], rng=rng)
rv_numpy_tester(integers, aet.as_tensor_variable([-1]), [1], size=[1], rng=rng)
rv_numpy_tester(
integers,
aet.as_tensor_variable([-1]),
[1],
size=aet.as_tensor_variable([1]),
rng=rng,
) )
......
...@@ -6,7 +6,12 @@ import pytest ...@@ -6,7 +6,12 @@ import pytest
from aesara import shared from aesara import shared
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp
from aesara.tensor.random.type import RandomStateType, random_state_type from aesara.tensor.random.type import (
RandomGeneratorType,
RandomStateType,
random_generator_type,
random_state_type,
)
# @pytest.mark.skipif( # @pytest.mark.skipif(
...@@ -24,8 +29,8 @@ def test_view_op_c_code(): ...@@ -24,8 +29,8 @@ def test_view_op_c_code():
# rng_view, # rng_view,
# mode=Mode(optimizer=None, linker=CLinker()), # mode=Mode(optimizer=None, linker=CLinker()),
# ) # )
assert ViewOp.c_code_and_version[RandomStateType] assert ViewOp.c_code_and_version[RandomStateType]
assert ViewOp.c_code_and_version[RandomGeneratorType]
class TestRandomStateType: class TestRandomStateType:
...@@ -106,9 +111,112 @@ class TestRandomStateType: ...@@ -106,9 +111,112 @@ class TestRandomStateType:
assert size == sys.getsizeof(rng.get_state(legacy=False)) assert size == sys.getsizeof(rng.get_state(legacy=False))
def test_may_share_memory(self): def test_may_share_memory(self):
rng_a = np.random.RandomState(12) bg1 = np.random.MT19937()
bg = np.random.PCG64() bg2 = np.random.MT19937()
rng_b = np.random.RandomState(bg)
rng_a = np.random.RandomState(bg1)
rng_b = np.random.RandomState(bg2)
rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True)
shape_info_a = random_state_type.get_shape_info(rng_var_a)
shape_info_b = random_state_type.get_shape_info(rng_var_b)
assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False
rng_c = np.random.RandomState(bg2)
rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c)
assert random_state_type.may_share_memory(shape_info_b, shape_info_c) is True
class TestRandomGeneratorType:
def test_pickle(self):
rng_r = random_generator_type()
rng_pkl = pickle.dumps(rng_r)
rng_unpkl = pickle.loads(rng_pkl)
assert isinstance(rng_unpkl, type(rng_r))
assert isinstance(rng_unpkl.type, type(rng_r.type))
def test_repr(self):
assert repr(random_generator_type) == "RandomGeneratorType"
def test_filter(self):
rng_type = random_generator_type
rng = np.random.default_rng()
assert rng_type.filter(rng) is rng
with pytest.raises(TypeError):
rng_type.filter(1)
rng = rng.__getstate__()
assert rng_type.is_valid_value(rng, strict=False)
rng["state"] = {}
assert rng_type.is_valid_value(rng, strict=False) is False
rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
def test_values_eq(self):
rng_type = random_generator_type
bg_1 = np.random.PCG64()
bg_2 = np.random.Philox()
bg_3 = np.random.MT19937()
bg_4 = np.random.SFC64()
bitgen_a = np.random.Generator(bg_1)
bitgen_b = np.random.Generator(bg_1)
assert rng_type.values_eq(bitgen_a, bitgen_b)
bitgen_c = np.random.Generator(bg_2)
bitgen_d = np.random.Generator(bg_2)
assert rng_type.values_eq(bitgen_c, bitgen_d)
bitgen_e = np.random.Generator(bg_3)
bitgen_f = np.random.Generator(bg_3)
assert rng_type.values_eq(bitgen_e, bitgen_f)
bitgen_g = np.random.Generator(bg_4)
bitgen_h = np.random.Generator(bg_4)
assert rng_type.values_eq(bitgen_g, bitgen_h)
assert rng_type.is_valid_value(bitgen_a, strict=True)
assert rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_c, strict=True)
assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_e, strict=True)
assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False)
assert rng_type.is_valid_value(bitgen_g, strict=True)
assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False)
def test_get_shape_info(self):
rng = np.random.default_rng(12)
rng_a = shared(rng)
assert isinstance(
random_generator_type.get_shape_info(rng_a), np.random.Generator
)
def test_get_size(self):
rng = np.random.Generator(np.random.PCG64(12))
rng_a = shared(rng)
shape_info = random_generator_type.get_shape_info(rng_a)
size = random_generator_type.get_size(shape_info)
assert size == sys.getsizeof(rng.__getstate__())
def test_may_share_memory(self):
bg_a = np.random.PCG64()
bg_b = np.random.PCG64()
rng_a = np.random.Generator(bg_a)
rng_b = np.random.Generator(bg_b)
rng_var_a = shared(rng_a, borrow=True) rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True) rng_var_b = shared(rng_b, borrow=True)
...@@ -117,7 +225,7 @@ class TestRandomStateType: ...@@ -117,7 +225,7 @@ class TestRandomStateType:
assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False
rng_c = np.random.RandomState(bg) rng_c = np.random.Generator(bg_b)
rng_var_c = shared(rng_c, borrow=True) rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c) shape_info_c = random_state_type.get_shape_info(rng_var_c)
......
import numpy as np import numpy as np
import pytest
from aesara import shared from aesara import shared
def test_RandomStateSharedVariable(): @pytest.mark.parametrize(
rng = np.random.RandomState(123) "rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_GeneratorSharedVariable(rng):
s_rng_default = shared(rng) s_rng_default = shared(rng)
s_rng_True = shared(rng, borrow=True) s_rng_True = shared(rng, borrow=True)
s_rng_False = shared(rng, borrow=False) s_rng_False = shared(rng, borrow=False)
...@@ -17,15 +20,22 @@ def test_RandomStateSharedVariable(): ...@@ -17,15 +20,22 @@ def test_RandomStateSharedVariable():
assert s_rng_True.container.storage[0] is rng assert s_rng_True.container.storage[0] is rng
# ensure that all the random number generators are in the same state # ensure that all the random number generators are in the same state
v = rng.randn() if hasattr(rng, "randn"):
v0 = s_rng_default.container.storage[0].randn() v = rng.randn()
v1 = s_rng_False.container.storage[0].randn() v0 = s_rng_default.container.storage[0].randn()
assert v == v0 == v1 v1 = s_rng_False.container.storage[0].randn()
else:
v = rng.standard_normal()
v0 = s_rng_default.container.storage[0].standard_normal()
v1 = s_rng_False.container.storage[0].standard_normal()
assert v == v0 == v1
def test_get_value_borrow():
rng = np.random.RandomState(123) @pytest.mark.parametrize(
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_get_value_borrow(rng):
s_rng = shared(rng) s_rng = shared(rng)
r_ = s_rng.container.storage[0] r_ = s_rng.container.storage[0]
...@@ -39,11 +49,16 @@ def test_get_value_borrow(): ...@@ -39,11 +49,16 @@ def test_get_value_borrow():
assert r_ is r_T assert r_ is r_T
# either way, the rngs should all be in the same state # either way, the rngs should all be in the same state
assert r_.rand() == r_F.rand() if hasattr(rng, "rand"):
assert r_.rand() == r_F.rand()
else:
assert r_.standard_normal() == r_F.standard_normal()
def test_get_value_internal_type(): @pytest.mark.parametrize(
rng = np.random.RandomState(123) "rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_get_value_internal_type(rng):
s_rng = shared(rng) s_rng = shared(rng)
# there is no special behaviour required of return_internal_type # there is no special behaviour required of return_internal_type
...@@ -60,23 +75,28 @@ def test_get_value_internal_type(): ...@@ -60,23 +75,28 @@ def test_get_value_internal_type():
assert r_ is r_T assert r_ is r_T
# either way, the rngs should all be in the same state # either way, the rngs should all be in the same state
assert r_.rand() == r_F.rand() if hasattr(rng, "rand"):
assert r_.rand() == r_F.rand()
else:
assert r_.standard_normal() == r_F.standard_normal()
def test_set_value_borrow(): @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
rng = np.random.RandomState(123) def test_set_value_borrow(rng_ctor):
s_rng = shared(rng_ctor(123))
s_rng = shared(rng)
new_rng = np.random.RandomState(234234) new_rng = rng_ctor(234234)
# Test the borrow contract is respected: # Test the borrow contract is respected:
# assigning with borrow=False makes a copy # assigning with borrow=False makes a copy
s_rng.set_value(new_rng, borrow=False) s_rng.set_value(new_rng, borrow=False)
assert new_rng is not s_rng.container.storage[0] assert new_rng is not s_rng.container.storage[0]
assert new_rng.randn() == s_rng.container.storage[0].randn() if hasattr(new_rng, "randn"):
assert new_rng.randn() == s_rng.container.storage[0].randn()
else:
assert new_rng.standard_normal() == s_rng.container.storage[0].standard_normal()
# Test that the current implementation is actually borrowing when it can. # Test that the current implementation is actually borrowing when it can.
rr = np.random.RandomState(33) rr = rng_ctor(33)
s_rng.set_value(rr, borrow=True) s_rng.set_value(rr, borrow=True)
assert rr is s_rng.container.storage[0] assert rr is s_rng.container.storage[0]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论