提交 5d4b0c4b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove RandomState type in remaining backends

上级 14da898c
......@@ -2,7 +2,7 @@ from functools import singledispatch
import jax
import numpy as np
from numpy.random import Generator, RandomState
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array,
)
......@@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node):
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state
@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
......@@ -214,7 +205,6 @@ def jax_sample_fn_categorical(op, node):
return sample_fn
@jax_sample_fn.register(ptr.RandIntRV)
@jax_sample_fn.register(ptr.IntegersRV)
@jax_sample_fn.register(ptr.UniformRV)
def jax_sample_fn_uniform(op, node):
......
......@@ -25,7 +25,6 @@ from pytensor.link.utils import (
)
from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature
......@@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
[rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op
rng_param = rv_op.rng_param(rv_node)
if isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `RandomStateType`s")
size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
......
......@@ -2,5 +2,5 @@
import pytensor.tensor.random.rewriting
import pytensor.tensor.random.utils
from pytensor.tensor.random.basic import *
from pytensor.tensor.random.op import RandomState, default_rng
from pytensor.tensor.random.op import default_rng
from pytensor.tensor.random.utils import RandomStream
......@@ -9,15 +9,10 @@ from pytensor.tensor import get_vector_length, specify_shape
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import sqrt
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import (
broadcast_params,
normalize_size_param,
)
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
try:
......@@ -645,7 +640,7 @@ class GumbelRV(ScipyRandomVariable):
@classmethod
def rng_fn_scipy(
cls,
rng: np.random.Generator | np.random.RandomState,
rng: np.random.Generator,
loc: np.ndarray | float,
scale: np.ndarray | float,
size: list[int] | int | None,
......@@ -1880,58 +1875,6 @@ class CategoricalRV(RandomVariable):
categorical = CategoricalRV()
class RandIntRV(RandomVariable):
r"""A discrete uniform random variable.
Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s.
"""
name = "randint"
signature = "(),()->()"
dtype = "int64"
_print_name = ("randint", "\\operatorname{randint}")
def __call__(self, low, high=None, size=None, **kwargs):
r"""Draw samples from a discrete uniform distribution.
Signature
---------
`() -> ()`
Parameters
----------
low
Lower boundary of the output interval. All values generated will
be greater than or equal to `low`, unless `high=None`, in which case
all values generated are greater than or equal to `0` and
smaller than `low` (exclusive).
high
Upper boundary of the output interval. All values generated
will be smaller than `high` (exclusive).
size
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
independent, identically distributed samples are
returned. Default is `None`, in which case a single
sample is returned.
"""
if high is None:
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable
):
raise TypeError("`randint` is only available for `RandomStateType`s")
return super().make_node(rng, *args, **kwargs)
randint = RandIntRV()
class IntegersRV(RandomVariable):
r"""A discrete uniform random variable.
......@@ -1971,14 +1914,6 @@ class IntegersRV(RandomVariable):
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None),
RandomGeneratorType | RandomGeneratorSharedVariable,
):
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
return super().make_node(rng, *args, **kwargs)
integers = IntegersRV()
......@@ -2201,7 +2136,6 @@ __all__ = [
"permutation",
"choice",
"integers",
"randint",
"categorical",
"multinomial",
"betabinom",
......
......@@ -20,7 +20,7 @@ from pytensor.tensor.basic import (
infer_static_shape,
)
from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
explicit_expand_dims,
......@@ -324,9 +324,8 @@ class RandomVariable(Op):
Parameters
----------
rng: RandomGeneratorType or RandomStateType
Existing PyTensor `Generator` or `RandomState` object to be used. Creates a
new one, if `None`.
rng: RandomGeneratorType
Existing PyTensor `Generator` object to be used. Creates a new one, if `None`.
size: int or Sequence
NumPy-like size parameter.
dtype: str
......@@ -354,7 +353,7 @@ class RandomVariable(Op):
rng = pytensor.shared(np.random.default_rng())
elif not isinstance(rng.type, RandomType):
raise TypeError(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
"The type of rng should be an instance of RandomGeneratorType "
)
inferred_shape = self._infer_shape(size, dist_params)
......@@ -436,14 +435,6 @@ class AbstractRNGConstructor(Op):
output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed)
class RandomStateConstructor(AbstractRNGConstructor):
random_type = RandomStateType()
random_constructor = "RandomState"
RandomState = RandomStateConstructor()
class DefaultGeneratorMakerOp(AbstractRNGConstructor):
random_type = RandomGeneratorType()
random_constructor = "default_rng"
......
......@@ -31,97 +31,6 @@ class RandomType(Type[T]):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def __repr__(self):
return "RandomStateType"
def filter(self, data, strict: bool = False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data
if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]
for key in gen_keys:
if key not in data:
raise TypeError()
for key in state_keys:
if key not in data["state"]:
raise TypeError()
state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
# TODO: Add an option to convert to a `RandomState` instance?
return data
raise TypeError()
@staticmethod
def values_eq(a, b):
sa = a if isinstance(a, dict) else a.get_state(legacy=False)
sb = b if isinstance(b, dict) else b.get_state(legacy=False)
def _eq(sa, sb):
for key in sa:
if isinstance(sa[key], dict):
if not _eq(sa[key], sb[key]):
return False
elif isinstance(sa[key], np.ndarray):
if not np.array_equal(sa[key], sb[key]):
return False
else:
if sa[key] != sb[key]:
return False
return True
return _eq(sa, sb)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
# Register `RandomStateType`'s C code for `ViewOp`.
pytensor.compile.register_view_op_c_code(
RandomStateType,
"""
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
1,
)
random_state_type = RandomStateType()
class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`.
......
......@@ -209,9 +209,7 @@ class RandomStream:
self,
seed: int | None = None,
namespace: ModuleType | None = None,
rng_ctor: Literal[
np.random.RandomState, np.random.Generator
] = np.random.default_rng,
rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
):
if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self
......@@ -223,12 +221,6 @@ class RandomStream:
self.default_instance_seed = seed
self.state_updates = []
self.gen_seedgen = np.random.SeedSequence(seed)
if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
# The legacy state does not accept `SeedSequence`s directly
def rng_ctor(seed):
return np.random.RandomState(np.random.MT19937(seed))
self.rng_ctor = rng_ctor
def __getattr__(self, obj):
......
......@@ -3,17 +3,12 @@ import copy
import numpy as np
from pytensor.compile.sharedvalue import SharedVariable, shared_constructor
from pytensor.tensor.random.type import random_generator_type, random_state_type
class RandomStateSharedVariable(SharedVariable):
def __str__(self):
return self.name or f"RandomStateSharedVariable({self.container!r})"
from pytensor.tensor.random.type import random_generator_type
class RandomGeneratorSharedVariable(SharedVariable):
def __str__(self):
return self.name or f"RandomGeneratorSharedVariable({self.container!r})"
return self.name or f"RNG({self.container!r})"
@shared_constructor.register(np.random.RandomState)
......@@ -23,11 +18,12 @@ def randomgen_constructor(
):
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
if isinstance(value, np.random.RandomState):
rng_sv_type = RandomStateSharedVariable
rng_type = random_state_type
elif isinstance(value, np.random.Generator):
rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type
raise TypeError(
"`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead."
)
rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type
if not borrow:
value = copy.deepcopy(value)
......
......@@ -49,7 +49,7 @@ def test_random_RandomStream():
assert not np.array_equal(jax_res_1, jax_res_2)
@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng))
@pytest.mark.parametrize("rng_ctor", (np.random.default_rng,))
def test_random_updates(rng_ctor):
original_value = rng_ctor(seed=98)
rng = shared(original_value, name="original_rng", borrow=False)
......@@ -299,22 +299,6 @@ def test_replaced_shared_rng_storage_ordering_equality():
"poisson",
lambda *args: args,
),
(
ptr.randint,
[
set_test_value(
pt.lscalar(),
np.array(0, dtype=np.int64),
),
set_test_value( # high-value necessary since test on cdf
pt.lscalar(),
np.array(1000, dtype=np.int64),
),
],
(),
"randint",
lambda *args: args,
),
(
ptr.integers,
[
......@@ -489,11 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
The parameters passed to the op.
"""
if rv_op is ptr.integers:
# Integers only accepts Generator, not RandomState
rng = shared(np.random.default_rng(29402))
else:
rng = shared(np.random.RandomState(29402))
rng = shared(np.random.default_rng(29403))
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
g_fn = compile_random_function(dist_params, g, mode=jax_mode)
samples = g_fn(
......@@ -545,7 +525,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
@pytest.mark.parametrize("size", [(), (4,)])
def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
......@@ -553,7 +533,7 @@ def test_random_bernoulli(size):
def test_random_mvnormal():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
mu = np.ones(4)
cov = np.eye(4)
......@@ -571,7 +551,7 @@ def test_random_mvnormal():
],
)
def test_random_dirichlet(parameter, size):
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
......@@ -598,7 +578,7 @@ def test_random_choice():
assert np.all(samples % 2 == 1)
# `replace=False` and `p is None`
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
......@@ -607,7 +587,7 @@ def test_random_choice():
assert len(np.unique(samples)) == 98
# `replace=False` and `p is not None`
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
g = pt.random.choice(
8,
p=np.array([0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0]),
......@@ -625,7 +605,7 @@ def test_random_choice():
def test_random_categorical():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
......@@ -642,7 +622,7 @@ def test_random_categorical():
def test_random_permutation():
array = np.arange(4)
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
g = pt.random.permutation(array, rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
permuted = g_fn()
......@@ -664,7 +644,7 @@ def test_unnatural_batched_dims(batch_dims_tester):
def test_random_geometric():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
p = np.array([0.3, 0.7])
g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
......@@ -674,7 +654,7 @@ def test_random_geometric():
def test_negative_binomial():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
n = np.array([10, 40])
p = np.array([0.3, 0.7])
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
......@@ -688,7 +668,7 @@ def test_negative_binomial():
@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro")
def test_binomial():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
n = np.array([10, 40])
p = np.array([0.3, 0.7])
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
......@@ -702,7 +682,7 @@ def test_binomial():
not numpyro_available, reason="BetaBinomial dispatch requires numpyro"
)
def test_beta_binomial():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
n = np.array([10, 40])
a = np.array([1.5, 13])
b = np.array([0.5, 9])
......@@ -721,7 +701,7 @@ def test_beta_binomial():
not numpyro_available, reason="Multinomial dispatch requires numpyro"
)
def test_multinomial():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
......@@ -737,7 +717,7 @@ def test_multinomial():
def test_vonmises_mu_outside_circle():
# Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle
# We test that the random draws from the JAX dispatch work as expected in these cases
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
mu = np.array([-30, 40])
kappa = np.array([100, 10])
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
......@@ -781,7 +761,7 @@ def test_random_unimplemented():
return 0
nonexistentrv = NonExistentRV()
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
......@@ -816,7 +796,7 @@ def test_random_custom_implementation():
return sample_fn
nonexistentrv = CustomRV()
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.warns(
......@@ -836,7 +816,7 @@ def test_random_concrete_shape():
`size` parameter satisfies either of these criteria.
"""
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
......@@ -844,7 +824,7 @@ def test_random_concrete_shape():
def test_random_concrete_shape_from_param():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
......@@ -863,7 +843,7 @@ def test_random_concrete_shape_subtensor():
slight improvement over their API.
"""
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
......@@ -879,7 +859,7 @@ def test_random_concrete_shape_subtensor_tuple():
`jax_size_parameter_as_tuple` rewrite.
"""
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
......@@ -890,7 +870,7 @@ def test_random_concrete_shape_subtensor_tuple():
reason="`size_pt` should be specified as a static argument", strict=True
)
def test_random_concrete_shape_graph_input():
rng = shared(np.random.RandomState(123))
rng = shared(np.random.default_rng(123))
size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
......
......@@ -244,10 +244,7 @@ def scan_nodes_from_fct(fct):
class TestScan:
@pytest.mark.parametrize(
"rng_type",
[
np.random.default_rng,
np.random.RandomState,
],
[np.random.default_rng],
)
def test_inner_graph_cloning(self, rng_type):
r"""Scan should remove the updates-providing special properties on `RandomType`\s."""
......
......@@ -51,7 +51,6 @@ from pytensor.tensor.random.basic import (
pareto,
permutation,
poisson,
randint,
rayleigh,
standard_normal,
t,
......@@ -1355,27 +1354,6 @@ def test_categorical_basic():
categorical.rng_fn(rng, p[None], size=(3,))
def test_randint_samples():
with pytest.raises(TypeError):
randint(10, rng=shared(np.random.default_rng()))
rng = np.random.RandomState(2313)
compare_sample_values(randint, 10, None, rng=rng)
compare_sample_values(randint, 0, 1, rng=rng)
compare_sample_values(randint, 0, 1, size=[3], rng=rng)
compare_sample_values(randint, [0, 1, 2], 5, rng=rng)
compare_sample_values(randint, [0, 1, 2], 5, size=[3, 3], rng=rng)
compare_sample_values(randint, [0], [5], size=[1], rng=rng)
compare_sample_values(randint, pt.as_tensor_variable([-1]), [1], size=[1], rng=rng)
compare_sample_values(
randint,
pt.as_tensor_variable([-1]),
[1],
size=pt.as_tensor_variable([1]),
rng=rng,
)
def test_integers_samples():
with pytest.raises(TypeError):
integers(10, rng=shared(np.random.RandomState()))
......
......@@ -8,7 +8,7 @@ from pytensor.raise_op import Assert
from pytensor.tensor.math import eq
from pytensor.tensor.random import normal
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
from pytensor.tensor.random.op import RandomVariable, default_rng
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import iscalar, tensor
......@@ -159,7 +159,6 @@ def test_RandomVariable_floatX(strict_test_value_flags):
@pytest.mark.parametrize(
"seed, maker_op, numpy_res",
[
(3, RandomState, np.random.RandomState(3)),
(3, default_rng, np.random.default_rng(3)),
],
)
......@@ -174,10 +173,6 @@ def test_random_maker_ops_no_seed(strict_test_value_flags):
# Testing the initialization when seed=None
# Since internal states randomly generated,
# we just check the output classes
z = function(inputs=[], outputs=[RandomState()])()
aes_res = z[0]
assert isinstance(aes_res, np.random.RandomState)
z = function(inputs=[], outputs=[default_rng()])()
aes_res = z[0]
assert isinstance(aes_res, np.random.Generator)
......
......@@ -7,9 +7,7 @@ from pytensor import shared
from pytensor.compile.ops import ViewOp
from pytensor.tensor.random.type import (
RandomGeneratorType,
RandomStateType,
random_generator_type,
random_state_type,
)
......@@ -28,101 +26,9 @@ def test_view_op_c_code():
# rng_view,
# mode=Mode(optimizer=None, linker=CLinker()),
# )
assert ViewOp.c_code_and_version[RandomStateType]
assert ViewOp.c_code_and_version[RandomGeneratorType]
class TestRandomStateType:
def test_pickle(self):
rng_r = random_state_type()
rng_pkl = pickle.dumps(rng_r)
rng_unpkl = pickle.loads(rng_pkl)
assert rng_r != rng_unpkl
assert rng_r.type == rng_unpkl.type
assert hash(rng_r.type) == hash(rng_unpkl.type)
def test_repr(self):
assert repr(random_state_type) == "RandomStateType"
def test_filter(self):
rng_type = random_state_type
rng = np.random.RandomState()
assert rng_type.filter(rng) is rng
with pytest.raises(TypeError):
rng_type.filter(1)
rng_dict = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)
rng_dict["state"] = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False
rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False
def test_values_eq(self):
rng_type = random_state_type
rng_a = np.random.RandomState(12)
rng_b = np.random.RandomState(12)
rng_c = np.random.RandomState(123)
bg = np.random.PCG64()
rng_d = np.random.RandomState(bg)
rng_e = np.random.RandomState(bg)
bg_2 = np.random.Philox()
rng_f = np.random.RandomState(bg_2)
rng_g = np.random.RandomState(bg_2)
assert rng_type.values_eq(rng_a, rng_b)
assert not rng_type.values_eq(rng_a, rng_c)
assert not rng_type.values_eq(rng_a, rng_d)
assert not rng_type.values_eq(rng_d, rng_a)
assert not rng_type.values_eq(rng_a, rng_d)
assert rng_type.values_eq(rng_d, rng_e)
assert rng_type.values_eq(rng_f, rng_g)
assert not rng_type.values_eq(rng_g, rng_a)
assert not rng_type.values_eq(rng_e, rng_g)
def test_may_share_memory(self):
bg1 = np.random.MT19937()
bg2 = np.random.MT19937()
rng_a = np.random.RandomState(bg1)
rng_b = np.random.RandomState(bg2)
rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True)
assert (
random_state_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
)
is False
)
rng_c = np.random.RandomState(bg2)
rng_var_c = shared(rng_c, borrow=True)
assert (
random_state_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
)
is True
)
class TestRandomGeneratorType:
def test_pickle(self):
rng_r = random_generator_type()
......@@ -200,7 +106,7 @@ class TestRandomGeneratorType:
rng_var_b = shared(rng_b, borrow=True)
assert (
random_state_type.may_share_memory(
random_generator_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
)
is False
......@@ -210,7 +116,7 @@ class TestRandomGeneratorType:
rng_var_c = shared(rng_c, borrow=True)
assert (
random_state_type.may_share_memory(
random_generator_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
)
is True
......
......@@ -101,7 +101,7 @@ class TestSharedRandomStream:
assert np.all(g() == g())
assert np.all(abs(nearly_zeros()) < 1e-5)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
@pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_basics(self, rng_ctor):
random = RandomStream(seed=utt.fetch_seed(), rng_ctor=rng_ctor)
......@@ -132,7 +132,7 @@ class TestSharedRandomStream:
assert np.allclose(fn_val0, numpy_val0)
assert np.allclose(fn_val1, numpy_val1)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
@pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_seed(self, rng_ctor):
init_seed = 234
random = RandomStream(init_seed, rng_ctor=rng_ctor)
......@@ -176,7 +176,7 @@ class TestSharedRandomStream:
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
@pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_uniform(self, rng_ctor):
# Test that RandomStream.uniform generates the same results as numpy
# Check over two calls to see if the random state is correctly updated.
......@@ -195,7 +195,7 @@ class TestSharedRandomStream:
assert np.allclose(fn_val0, numpy_val0)
assert np.allclose(fn_val1, numpy_val1)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
@pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_default_updates(self, rng_ctor):
# Basic case: default_updates
random_a = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
......@@ -244,7 +244,7 @@ class TestSharedRandomStream:
assert np.all(fn_e_val0 == fn_a_val0)
assert np.all(fn_e_val1 == fn_e_val0)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
@pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_multiple_rng_aliasing(self, rng_ctor):
# Test that when we have multiple random number generators, we do not alias
# the state_updates member. `state_updates` can be useful when attempting to
......@@ -257,7 +257,7 @@ class TestSharedRandomStream:
assert rng1.state_updates is not rng2.state_updates
assert rng1.gen_seedgen is not rng2.gen_seedgen
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
@pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_random_state_transfer(self, rng_ctor):
# Test that random state can be transferred from one pytensor graph to another.
......
......@@ -4,9 +4,7 @@ import pytest
from pytensor import shared
@pytest.mark.parametrize(
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
@pytest.mark.parametrize("rng", [np.random.default_rng(123)])
def test_GeneratorSharedVariable(rng):
s_rng_default = shared(rng)
s_rng_True = shared(rng, borrow=True)
......@@ -32,9 +30,7 @@ def test_GeneratorSharedVariable(rng):
assert v == v0 == v1
@pytest.mark.parametrize(
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
@pytest.mark.parametrize("rng", [np.random.default_rng(123)])
def test_get_value_borrow(rng):
s_rng = shared(rng)
......@@ -55,9 +51,7 @@ def test_get_value_borrow(rng):
assert r_.standard_normal() == r_F.standard_normal()
@pytest.mark.parametrize(
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
@pytest.mark.parametrize("rng", [np.random.default_rng(123)])
def test_get_value_internal_type(rng):
s_rng = shared(rng)
......@@ -81,7 +75,7 @@ def test_get_value_internal_type(rng):
assert r_.standard_normal() == r_F.standard_normal()
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
@pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_set_value_borrow(rng_ctor):
s_rng = shared(rng_ctor(123))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论