提交 5d4b0c4b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove RandomState type in remaining backends

上级 14da898c
...@@ -2,7 +2,7 @@ from functools import singledispatch ...@@ -2,7 +2,7 @@ from functools import singledispatch
import jax import jax
import numpy as np import numpy as np
from numpy.random import Generator, RandomState from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined] from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array, _coerce_to_uint32_array,
) )
...@@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node): ...@@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node):
raise NotImplementedError(SIZE_NOT_COMPATIBLE) raise NotImplementedError(SIZE_NOT_COMPATIBLE)
@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state
@jax_typify.register(Generator) @jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs): def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__() state = rng.__getstate__()
...@@ -214,7 +205,6 @@ def jax_sample_fn_categorical(op, node): ...@@ -214,7 +205,6 @@ def jax_sample_fn_categorical(op, node):
return sample_fn return sample_fn
@jax_sample_fn.register(ptr.RandIntRV)
@jax_sample_fn.register(ptr.IntegersRV) @jax_sample_fn.register(ptr.IntegersRV)
@jax_sample_fn.register(ptr.UniformRV) @jax_sample_fn.register(ptr.UniformRV)
def jax_sample_fn_uniform(op, node): def jax_sample_fn_uniform(op, node):
......
...@@ -25,7 +25,6 @@ from pytensor.link.utils import ( ...@@ -25,7 +25,6 @@ from pytensor.link.utils import (
) )
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
...@@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs ...@@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
[rv_node] = op.fgraph.apply_nodes [rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op rv_op: RandomVariable = rv_node.op
rng_param = rv_op.rng_param(rv_node)
if isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `RandomStateType`s")
size = rv_op.size_param(rv_node) size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node) dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size) size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
......
...@@ -2,5 +2,5 @@ ...@@ -2,5 +2,5 @@
import pytensor.tensor.random.rewriting import pytensor.tensor.random.rewriting
import pytensor.tensor.random.utils import pytensor.tensor.random.utils
from pytensor.tensor.random.basic import * from pytensor.tensor.random.basic import *
from pytensor.tensor.random.op import RandomState, default_rng from pytensor.tensor.random.op import default_rng
from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.random.utils import RandomStream
...@@ -9,15 +9,10 @@ from pytensor.tensor import get_vector_length, specify_shape ...@@ -9,15 +9,10 @@ from pytensor.tensor import get_vector_length, specify_shape
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import sqrt from pytensor.tensor.math import sqrt
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
broadcast_params, broadcast_params,
normalize_size_param, normalize_size_param,
) )
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
)
try: try:
...@@ -645,7 +640,7 @@ class GumbelRV(ScipyRandomVariable): ...@@ -645,7 +640,7 @@ class GumbelRV(ScipyRandomVariable):
@classmethod @classmethod
def rng_fn_scipy( def rng_fn_scipy(
cls, cls,
rng: np.random.Generator | np.random.RandomState, rng: np.random.Generator,
loc: np.ndarray | float, loc: np.ndarray | float,
scale: np.ndarray | float, scale: np.ndarray | float,
size: list[int] | int | None, size: list[int] | int | None,
...@@ -1880,58 +1875,6 @@ class CategoricalRV(RandomVariable): ...@@ -1880,58 +1875,6 @@ class CategoricalRV(RandomVariable):
categorical = CategoricalRV() categorical = CategoricalRV()
class RandIntRV(RandomVariable):
r"""A discrete uniform random variable.
Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s.
"""
name = "randint"
signature = "(),()->()"
dtype = "int64"
_print_name = ("randint", "\\operatorname{randint}")
def __call__(self, low, high=None, size=None, **kwargs):
r"""Draw samples from a discrete uniform distribution.
Signature
---------
`() -> ()`
Parameters
----------
low
Lower boundary of the output interval. All values generated will
be greater than or equal to `low`, unless `high=None`, in which case
all values generated are greater than or equal to `0` and
smaller than `low` (exclusive).
high
Upper boundary of the output interval. All values generated
will be smaller than `high` (exclusive).
size
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
independent, identically distributed samples are
returned. Default is `None`, in which case a single
sample is returned.
"""
if high is None:
low, high = 0, low
return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable
):
raise TypeError("`randint` is only available for `RandomStateType`s")
return super().make_node(rng, *args, **kwargs)
randint = RandIntRV()
class IntegersRV(RandomVariable): class IntegersRV(RandomVariable):
r"""A discrete uniform random variable. r"""A discrete uniform random variable.
...@@ -1971,14 +1914,6 @@ class IntegersRV(RandomVariable): ...@@ -1971,14 +1914,6 @@ class IntegersRV(RandomVariable):
low, high = 0, low low, high = 0, low
return super().__call__(low, high, size=size, **kwargs) return super().__call__(low, high, size=size, **kwargs)
def make_node(self, rng, *args, **kwargs):
if not isinstance(
getattr(rng, "type", None),
RandomGeneratorType | RandomGeneratorSharedVariable,
):
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
return super().make_node(rng, *args, **kwargs)
integers = IntegersRV() integers = IntegersRV()
...@@ -2201,7 +2136,6 @@ __all__ = [ ...@@ -2201,7 +2136,6 @@ __all__ = [
"permutation", "permutation",
"choice", "choice",
"integers", "integers",
"randint",
"categorical", "categorical",
"multinomial", "multinomial",
"betabinom", "betabinom",
......
...@@ -20,7 +20,7 @@ from pytensor.tensor.basic import ( ...@@ -20,7 +20,7 @@ from pytensor.tensor.basic import (
infer_static_shape, infer_static_shape,
) )
from pytensor.tensor.blockwise import OpWithCoreShape from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
compute_batch_shape, compute_batch_shape,
explicit_expand_dims, explicit_expand_dims,
...@@ -324,9 +324,8 @@ class RandomVariable(Op): ...@@ -324,9 +324,8 @@ class RandomVariable(Op):
Parameters Parameters
---------- ----------
rng: RandomGeneratorType or RandomStateType rng: RandomGeneratorType
Existing PyTensor `Generator` or `RandomState` object to be used. Creates a Existing PyTensor `Generator` object to be used. Creates a new one, if `None`.
new one, if `None`.
size: int or Sequence size: int or Sequence
NumPy-like size parameter. NumPy-like size parameter.
dtype: str dtype: str
...@@ -354,7 +353,7 @@ class RandomVariable(Op): ...@@ -354,7 +353,7 @@ class RandomVariable(Op):
rng = pytensor.shared(np.random.default_rng()) rng = pytensor.shared(np.random.default_rng())
elif not isinstance(rng.type, RandomType): elif not isinstance(rng.type, RandomType):
raise TypeError( raise TypeError(
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType" "The type of rng should be an instance of RandomGeneratorType "
) )
inferred_shape = self._infer_shape(size, dist_params) inferred_shape = self._infer_shape(size, dist_params)
...@@ -436,14 +435,6 @@ class AbstractRNGConstructor(Op): ...@@ -436,14 +435,6 @@ class AbstractRNGConstructor(Op):
output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed) output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed)
class RandomStateConstructor(AbstractRNGConstructor):
random_type = RandomStateType()
random_constructor = "RandomState"
RandomState = RandomStateConstructor()
class DefaultGeneratorMakerOp(AbstractRNGConstructor): class DefaultGeneratorMakerOp(AbstractRNGConstructor):
random_type = RandomGeneratorType() random_type = RandomGeneratorType()
random_constructor = "default_rng" random_constructor = "default_rng"
......
...@@ -31,97 +31,6 @@ class RandomType(Type[T]): ...@@ -31,97 +31,6 @@ class RandomType(Type[T]):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined] return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
`RandomState` objects that would appear to be equal do not compare equal
with the ``==`` operator.
This `Type` also works with a ``dict`` derived from
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
is explicitly set to ``True``.
"""
def __repr__(self):
return "RandomStateType"
def filter(self, data, strict: bool = False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data
if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]
for key in gen_keys:
if key not in data:
raise TypeError()
for key in state_keys:
if key not in data["state"]:
raise TypeError()
state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
# TODO: Add an option to convert to a `RandomState` instance?
return data
raise TypeError()
@staticmethod
def values_eq(a, b):
sa = a if isinstance(a, dict) else a.get_state(legacy=False)
sb = b if isinstance(b, dict) else b.get_state(legacy=False)
def _eq(sa, sb):
for key in sa:
if isinstance(sa[key], dict):
if not _eq(sa[key], sb[key]):
return False
elif isinstance(sa[key], np.ndarray):
if not np.array_equal(sa[key], sb[key]):
return False
else:
if sa[key] != sb[key]:
return False
return True
return _eq(sa, sb)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
# Register `RandomStateType`'s C code for `ViewOp`.
pytensor.compile.register_view_op_c_code(
RandomStateType,
"""
Py_XDECREF(%(oname)s);
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""",
1,
)
random_state_type = RandomStateType()
class RandomGeneratorType(RandomType[np.random.Generator]): class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`. r"""A Type wrapper for `numpy.random.Generator`.
......
...@@ -209,9 +209,7 @@ class RandomStream: ...@@ -209,9 +209,7 @@ class RandomStream:
self, self,
seed: int | None = None, seed: int | None = None,
namespace: ModuleType | None = None, namespace: ModuleType | None = None,
rng_ctor: Literal[ rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
np.random.RandomState, np.random.Generator
] = np.random.default_rng,
): ):
if namespace is None: if namespace is None:
from pytensor.tensor.random import basic # pylint: disable=import-self from pytensor.tensor.random import basic # pylint: disable=import-self
...@@ -223,12 +221,6 @@ class RandomStream: ...@@ -223,12 +221,6 @@ class RandomStream:
self.default_instance_seed = seed self.default_instance_seed = seed
self.state_updates = [] self.state_updates = []
self.gen_seedgen = np.random.SeedSequence(seed) self.gen_seedgen = np.random.SeedSequence(seed)
if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
# The legacy state does not accept `SeedSequence`s directly
def rng_ctor(seed):
return np.random.RandomState(np.random.MT19937(seed))
self.rng_ctor = rng_ctor self.rng_ctor = rng_ctor
def __getattr__(self, obj): def __getattr__(self, obj):
......
...@@ -3,17 +3,12 @@ import copy ...@@ -3,17 +3,12 @@ import copy
import numpy as np import numpy as np
from pytensor.compile.sharedvalue import SharedVariable, shared_constructor from pytensor.compile.sharedvalue import SharedVariable, shared_constructor
from pytensor.tensor.random.type import random_generator_type, random_state_type from pytensor.tensor.random.type import random_generator_type
class RandomStateSharedVariable(SharedVariable):
def __str__(self):
return self.name or f"RandomStateSharedVariable({self.container!r})"
class RandomGeneratorSharedVariable(SharedVariable): class RandomGeneratorSharedVariable(SharedVariable):
def __str__(self): def __str__(self):
return self.name or f"RandomGeneratorSharedVariable({self.container!r})" return self.name or f"RNG({self.container!r})"
@shared_constructor.register(np.random.RandomState) @shared_constructor.register(np.random.RandomState)
...@@ -23,9 +18,10 @@ def randomgen_constructor( ...@@ -23,9 +18,10 @@ def randomgen_constructor(
): ):
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`.""" r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
if isinstance(value, np.random.RandomState): if isinstance(value, np.random.RandomState):
rng_sv_type = RandomStateSharedVariable raise TypeError(
rng_type = random_state_type "`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead."
elif isinstance(value, np.random.Generator): )
rng_sv_type = RandomGeneratorSharedVariable rng_sv_type = RandomGeneratorSharedVariable
rng_type = random_generator_type rng_type = random_generator_type
......
...@@ -49,7 +49,7 @@ def test_random_RandomStream(): ...@@ -49,7 +49,7 @@ def test_random_RandomStream():
assert not np.array_equal(jax_res_1, jax_res_2) assert not np.array_equal(jax_res_1, jax_res_2)
@pytest.mark.parametrize("rng_ctor", (np.random.RandomState, np.random.default_rng)) @pytest.mark.parametrize("rng_ctor", (np.random.default_rng,))
def test_random_updates(rng_ctor): def test_random_updates(rng_ctor):
original_value = rng_ctor(seed=98) original_value = rng_ctor(seed=98)
rng = shared(original_value, name="original_rng", borrow=False) rng = shared(original_value, name="original_rng", borrow=False)
...@@ -299,22 +299,6 @@ def test_replaced_shared_rng_storage_ordering_equality(): ...@@ -299,22 +299,6 @@ def test_replaced_shared_rng_storage_ordering_equality():
"poisson", "poisson",
lambda *args: args, lambda *args: args,
), ),
(
ptr.randint,
[
set_test_value(
pt.lscalar(),
np.array(0, dtype=np.int64),
),
set_test_value( # high-value necessary since test on cdf
pt.lscalar(),
np.array(1000, dtype=np.int64),
),
],
(),
"randint",
lambda *args: args,
),
( (
ptr.integers, ptr.integers,
[ [
...@@ -489,11 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c ...@@ -489,11 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
The parameters passed to the op. The parameters passed to the op.
""" """
if rv_op is ptr.integers: rng = shared(np.random.default_rng(29403))
# Integers only accepts Generator, not RandomState
rng = shared(np.random.default_rng(29402))
else:
rng = shared(np.random.RandomState(29402))
g = rv_op(*dist_params, size=(10000, *base_size), rng=rng) g = rv_op(*dist_params, size=(10000, *base_size), rng=rng)
g_fn = compile_random_function(dist_params, g, mode=jax_mode) g_fn = compile_random_function(dist_params, g, mode=jax_mode)
samples = g_fn( samples = g_fn(
...@@ -545,7 +525,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn): ...@@ -545,7 +525,7 @@ def test_size_implied_by_broadcasted_parameters(rv_fn):
@pytest.mark.parametrize("size", [(), (4,)]) @pytest.mark.parametrize("size", [(), (4,)])
def test_random_bernoulli(size): def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng) g = pt.random.bernoulli(0.5, size=(1000, *size), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
...@@ -553,7 +533,7 @@ def test_random_bernoulli(size): ...@@ -553,7 +533,7 @@ def test_random_bernoulli(size):
def test_random_mvnormal(): def test_random_mvnormal():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
mu = np.ones(4) mu = np.ones(4)
cov = np.eye(4) cov = np.eye(4)
...@@ -571,7 +551,7 @@ def test_random_mvnormal(): ...@@ -571,7 +551,7 @@ def test_random_mvnormal():
], ],
) )
def test_random_dirichlet(parameter, size): def test_random_dirichlet(parameter, size):
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng) g = pt.random.dirichlet(parameter, size=(1000, *size), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
...@@ -598,7 +578,7 @@ def test_random_choice(): ...@@ -598,7 +578,7 @@ def test_random_choice():
assert np.all(samples % 2 == 1) assert np.all(samples % 2 == 1)
# `replace=False` and `p is None` # `replace=False` and `p is None`
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng) g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
...@@ -607,7 +587,7 @@ def test_random_choice(): ...@@ -607,7 +587,7 @@ def test_random_choice():
assert len(np.unique(samples)) == 98 assert len(np.unique(samples)) == 98
# `replace=False` and `p is not None` # `replace=False` and `p is not None`
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
g = pt.random.choice( g = pt.random.choice(
8, 8,
p=np.array([0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0]), p=np.array([0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0]),
...@@ -625,7 +605,7 @@ def test_random_choice(): ...@@ -625,7 +605,7 @@ def test_random_choice():
def test_random_categorical(): def test_random_categorical():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng) g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn() samples = g_fn()
...@@ -642,7 +622,7 @@ def test_random_categorical(): ...@@ -642,7 +622,7 @@ def test_random_categorical():
def test_random_permutation(): def test_random_permutation():
array = np.arange(4) array = np.arange(4)
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
g = pt.random.permutation(array, rng=rng) g = pt.random.permutation(array, rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
permuted = g_fn() permuted = g_fn()
...@@ -664,7 +644,7 @@ def test_unnatural_batched_dims(batch_dims_tester): ...@@ -664,7 +644,7 @@ def test_unnatural_batched_dims(batch_dims_tester):
def test_random_geometric(): def test_random_geometric():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = pt.random.geometric(p, size=(10_000, 2), rng=rng) g = pt.random.geometric(p, size=(10_000, 2), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode) g_fn = compile_random_function([], g, mode=jax_mode)
...@@ -674,7 +654,7 @@ def test_random_geometric(): ...@@ -674,7 +654,7 @@ def test_random_geometric():
def test_negative_binomial(): def test_negative_binomial():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng) g = pt.random.negative_binomial(n, p, size=(10_000, 2), rng=rng)
...@@ -688,7 +668,7 @@ def test_negative_binomial(): ...@@ -688,7 +668,7 @@ def test_negative_binomial():
@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro") @pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro")
def test_binomial(): def test_binomial():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([0.3, 0.7]) p = np.array([0.3, 0.7])
g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng) g = pt.random.binomial(n, p, size=(10_000, 2), rng=rng)
...@@ -702,7 +682,7 @@ def test_binomial(): ...@@ -702,7 +682,7 @@ def test_binomial():
not numpyro_available, reason="BetaBinomial dispatch requires numpyro" not numpyro_available, reason="BetaBinomial dispatch requires numpyro"
) )
def test_beta_binomial(): def test_beta_binomial():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
n = np.array([10, 40]) n = np.array([10, 40])
a = np.array([1.5, 13]) a = np.array([1.5, 13])
b = np.array([0.5, 9]) b = np.array([0.5, 9])
...@@ -721,7 +701,7 @@ def test_beta_binomial(): ...@@ -721,7 +701,7 @@ def test_beta_binomial():
not numpyro_available, reason="Multinomial dispatch requires numpyro" not numpyro_available, reason="Multinomial dispatch requires numpyro"
) )
def test_multinomial(): def test_multinomial():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
n = np.array([10, 40]) n = np.array([10, 40])
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
...@@ -737,7 +717,7 @@ def test_multinomial(): ...@@ -737,7 +717,7 @@ def test_multinomial():
def test_vonmises_mu_outside_circle(): def test_vonmises_mu_outside_circle():
# Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle # Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle
# We test that the random draws from the JAX dispatch work as expected in these cases # We test that the random draws from the JAX dispatch work as expected in these cases
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
mu = np.array([-30, 40]) mu = np.array([-30, 40])
kappa = np.array([100, 10]) kappa = np.array([100, 10])
g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng) g = pt.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng)
...@@ -781,7 +761,7 @@ def test_random_unimplemented(): ...@@ -781,7 +761,7 @@ def test_random_unimplemented():
return 0 return 0
nonexistentrv = NonExistentRV() nonexistentrv = NonExistentRV()
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng) out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
...@@ -816,7 +796,7 @@ def test_random_custom_implementation(): ...@@ -816,7 +796,7 @@ def test_random_custom_implementation():
return sample_fn return sample_fn
nonexistentrv = CustomRV() nonexistentrv = CustomRV()
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
out = nonexistentrv(rng=rng) out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False) fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.warns( with pytest.warns(
...@@ -836,7 +816,7 @@ def test_random_concrete_shape(): ...@@ -836,7 +816,7 @@ def test_random_concrete_shape():
`size` parameter satisfies either of these criteria. `size` parameter satisfies either of these criteria.
""" """
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng) out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
...@@ -844,7 +824,7 @@ def test_random_concrete_shape(): ...@@ -844,7 +824,7 @@ def test_random_concrete_shape():
def test_random_concrete_shape_from_param(): def test_random_concrete_shape_from_param():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(x_pt, 1, rng=rng) out = pt.random.normal(x_pt, 1, rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
...@@ -863,7 +843,7 @@ def test_random_concrete_shape_subtensor(): ...@@ -863,7 +843,7 @@ def test_random_concrete_shape_subtensor():
slight improvement over their API. slight improvement over their API.
""" """
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng) out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
...@@ -879,7 +859,7 @@ def test_random_concrete_shape_subtensor_tuple(): ...@@ -879,7 +859,7 @@ def test_random_concrete_shape_subtensor_tuple():
`jax_size_parameter_as_tuple` rewrite. `jax_size_parameter_as_tuple` rewrite.
""" """
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
x_pt = pt.dmatrix() x_pt = pt.dmatrix()
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng) out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
jax_fn = compile_random_function([x_pt], out, mode=jax_mode) jax_fn = compile_random_function([x_pt], out, mode=jax_mode)
...@@ -890,7 +870,7 @@ def test_random_concrete_shape_subtensor_tuple(): ...@@ -890,7 +870,7 @@ def test_random_concrete_shape_subtensor_tuple():
reason="`size_pt` should be specified as a static argument", strict=True reason="`size_pt` should be specified as a static argument", strict=True
) )
def test_random_concrete_shape_graph_input(): def test_random_concrete_shape_graph_input():
rng = shared(np.random.RandomState(123)) rng = shared(np.random.default_rng(123))
size_pt = pt.scalar() size_pt = pt.scalar()
out = pt.random.normal(0, 1, size=size_pt, rng=rng) out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out, mode=jax_mode) jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
......
...@@ -244,10 +244,7 @@ def scan_nodes_from_fct(fct): ...@@ -244,10 +244,7 @@ def scan_nodes_from_fct(fct):
class TestScan: class TestScan:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rng_type", "rng_type",
[ [np.random.default_rng],
np.random.default_rng,
np.random.RandomState,
],
) )
def test_inner_graph_cloning(self, rng_type): def test_inner_graph_cloning(self, rng_type):
r"""Scan should remove the updates-providing special properties on `RandomType`\s.""" r"""Scan should remove the updates-providing special properties on `RandomType`\s."""
......
...@@ -51,7 +51,6 @@ from pytensor.tensor.random.basic import ( ...@@ -51,7 +51,6 @@ from pytensor.tensor.random.basic import (
pareto, pareto,
permutation, permutation,
poisson, poisson,
randint,
rayleigh, rayleigh,
standard_normal, standard_normal,
t, t,
...@@ -1355,27 +1354,6 @@ def test_categorical_basic(): ...@@ -1355,27 +1354,6 @@ def test_categorical_basic():
categorical.rng_fn(rng, p[None], size=(3,)) categorical.rng_fn(rng, p[None], size=(3,))
def test_randint_samples():
with pytest.raises(TypeError):
randint(10, rng=shared(np.random.default_rng()))
rng = np.random.RandomState(2313)
compare_sample_values(randint, 10, None, rng=rng)
compare_sample_values(randint, 0, 1, rng=rng)
compare_sample_values(randint, 0, 1, size=[3], rng=rng)
compare_sample_values(randint, [0, 1, 2], 5, rng=rng)
compare_sample_values(randint, [0, 1, 2], 5, size=[3, 3], rng=rng)
compare_sample_values(randint, [0], [5], size=[1], rng=rng)
compare_sample_values(randint, pt.as_tensor_variable([-1]), [1], size=[1], rng=rng)
compare_sample_values(
randint,
pt.as_tensor_variable([-1]),
[1],
size=pt.as_tensor_variable([1]),
rng=rng,
)
def test_integers_samples(): def test_integers_samples():
with pytest.raises(TypeError): with pytest.raises(TypeError):
integers(10, rng=shared(np.random.RandomState())) integers(10, rng=shared(np.random.RandomState()))
......
...@@ -8,7 +8,7 @@ from pytensor.raise_op import Assert ...@@ -8,7 +8,7 @@ from pytensor.raise_op import Assert
from pytensor.tensor.math import eq from pytensor.tensor.math import eq
from pytensor.tensor.random import normal from pytensor.tensor.random import normal
from pytensor.tensor.random.basic import NormalRV from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng from pytensor.tensor.random.op import RandomVariable, default_rng
from pytensor.tensor.shape import specify_shape from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import iscalar, tensor from pytensor.tensor.type import iscalar, tensor
...@@ -159,7 +159,6 @@ def test_RandomVariable_floatX(strict_test_value_flags): ...@@ -159,7 +159,6 @@ def test_RandomVariable_floatX(strict_test_value_flags):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seed, maker_op, numpy_res", "seed, maker_op, numpy_res",
[ [
(3, RandomState, np.random.RandomState(3)),
(3, default_rng, np.random.default_rng(3)), (3, default_rng, np.random.default_rng(3)),
], ],
) )
...@@ -174,10 +173,6 @@ def test_random_maker_ops_no_seed(strict_test_value_flags): ...@@ -174,10 +173,6 @@ def test_random_maker_ops_no_seed(strict_test_value_flags):
# Testing the initialization when seed=None # Testing the initialization when seed=None
# Since internal states randomly generated, # Since internal states randomly generated,
# we just check the output classes # we just check the output classes
z = function(inputs=[], outputs=[RandomState()])()
aes_res = z[0]
assert isinstance(aes_res, np.random.RandomState)
z = function(inputs=[], outputs=[default_rng()])() z = function(inputs=[], outputs=[default_rng()])()
aes_res = z[0] aes_res = z[0]
assert isinstance(aes_res, np.random.Generator) assert isinstance(aes_res, np.random.Generator)
......
...@@ -7,9 +7,7 @@ from pytensor import shared ...@@ -7,9 +7,7 @@ from pytensor import shared
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
from pytensor.tensor.random.type import ( from pytensor.tensor.random.type import (
RandomGeneratorType, RandomGeneratorType,
RandomStateType,
random_generator_type, random_generator_type,
random_state_type,
) )
...@@ -28,101 +26,9 @@ def test_view_op_c_code(): ...@@ -28,101 +26,9 @@ def test_view_op_c_code():
# rng_view, # rng_view,
# mode=Mode(optimizer=None, linker=CLinker()), # mode=Mode(optimizer=None, linker=CLinker()),
# ) # )
assert ViewOp.c_code_and_version[RandomStateType]
assert ViewOp.c_code_and_version[RandomGeneratorType] assert ViewOp.c_code_and_version[RandomGeneratorType]
class TestRandomStateType:
def test_pickle(self):
rng_r = random_state_type()
rng_pkl = pickle.dumps(rng_r)
rng_unpkl = pickle.loads(rng_pkl)
assert rng_r != rng_unpkl
assert rng_r.type == rng_unpkl.type
assert hash(rng_r.type) == hash(rng_unpkl.type)
def test_repr(self):
assert repr(random_state_type) == "RandomStateType"
def test_filter(self):
rng_type = random_state_type
rng = np.random.RandomState()
assert rng_type.filter(rng) is rng
with pytest.raises(TypeError):
rng_type.filter(1)
rng_dict = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)
rng_dict["state"] = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False
rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False
def test_values_eq(self):
rng_type = random_state_type
rng_a = np.random.RandomState(12)
rng_b = np.random.RandomState(12)
rng_c = np.random.RandomState(123)
bg = np.random.PCG64()
rng_d = np.random.RandomState(bg)
rng_e = np.random.RandomState(bg)
bg_2 = np.random.Philox()
rng_f = np.random.RandomState(bg_2)
rng_g = np.random.RandomState(bg_2)
assert rng_type.values_eq(rng_a, rng_b)
assert not rng_type.values_eq(rng_a, rng_c)
assert not rng_type.values_eq(rng_a, rng_d)
assert not rng_type.values_eq(rng_d, rng_a)
assert not rng_type.values_eq(rng_a, rng_d)
assert rng_type.values_eq(rng_d, rng_e)
assert rng_type.values_eq(rng_f, rng_g)
assert not rng_type.values_eq(rng_g, rng_a)
assert not rng_type.values_eq(rng_e, rng_g)
def test_may_share_memory(self):
bg1 = np.random.MT19937()
bg2 = np.random.MT19937()
rng_a = np.random.RandomState(bg1)
rng_b = np.random.RandomState(bg2)
rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True)
assert (
random_state_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
)
is False
)
rng_c = np.random.RandomState(bg2)
rng_var_c = shared(rng_c, borrow=True)
assert (
random_state_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
)
is True
)
class TestRandomGeneratorType: class TestRandomGeneratorType:
def test_pickle(self): def test_pickle(self):
rng_r = random_generator_type() rng_r = random_generator_type()
...@@ -200,7 +106,7 @@ class TestRandomGeneratorType: ...@@ -200,7 +106,7 @@ class TestRandomGeneratorType:
rng_var_b = shared(rng_b, borrow=True) rng_var_b = shared(rng_b, borrow=True)
assert ( assert (
random_state_type.may_share_memory( random_generator_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True) rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
) )
is False is False
...@@ -210,7 +116,7 @@ class TestRandomGeneratorType: ...@@ -210,7 +116,7 @@ class TestRandomGeneratorType:
rng_var_c = shared(rng_c, borrow=True) rng_var_c = shared(rng_c, borrow=True)
assert ( assert (
random_state_type.may_share_memory( random_generator_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True) rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
) )
is True is True
......
...@@ -101,7 +101,7 @@ class TestSharedRandomStream: ...@@ -101,7 +101,7 @@ class TestSharedRandomStream:
assert np.all(g() == g()) assert np.all(g() == g())
assert np.all(abs(nearly_zeros()) < 1e-5) assert np.all(abs(nearly_zeros()) < 1e-5)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_basics(self, rng_ctor): def test_basics(self, rng_ctor):
random = RandomStream(seed=utt.fetch_seed(), rng_ctor=rng_ctor) random = RandomStream(seed=utt.fetch_seed(), rng_ctor=rng_ctor)
...@@ -132,7 +132,7 @@ class TestSharedRandomStream: ...@@ -132,7 +132,7 @@ class TestSharedRandomStream:
assert np.allclose(fn_val0, numpy_val0) assert np.allclose(fn_val0, numpy_val0)
assert np.allclose(fn_val1, numpy_val1) assert np.allclose(fn_val1, numpy_val1)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_seed(self, rng_ctor): def test_seed(self, rng_ctor):
init_seed = 234 init_seed = 234
random = RandomStream(init_seed, rng_ctor=rng_ctor) random = RandomStream(init_seed, rng_ctor=rng_ctor)
...@@ -176,7 +176,7 @@ class TestSharedRandomStream: ...@@ -176,7 +176,7 @@ class TestSharedRandomStream:
assert random_state["bit_generator"] == ref_state["bit_generator"] assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"] assert random_state["state"] == ref_state["state"]
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_uniform(self, rng_ctor): def test_uniform(self, rng_ctor):
# Test that RandomStream.uniform generates the same results as numpy # Test that RandomStream.uniform generates the same results as numpy
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
...@@ -195,7 +195,7 @@ class TestSharedRandomStream: ...@@ -195,7 +195,7 @@ class TestSharedRandomStream:
assert np.allclose(fn_val0, numpy_val0) assert np.allclose(fn_val0, numpy_val0)
assert np.allclose(fn_val1, numpy_val1) assert np.allclose(fn_val1, numpy_val1)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_default_updates(self, rng_ctor): def test_default_updates(self, rng_ctor):
# Basic case: default_updates # Basic case: default_updates
random_a = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor) random_a = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
...@@ -244,7 +244,7 @@ class TestSharedRandomStream: ...@@ -244,7 +244,7 @@ class TestSharedRandomStream:
assert np.all(fn_e_val0 == fn_a_val0) assert np.all(fn_e_val0 == fn_a_val0)
assert np.all(fn_e_val1 == fn_e_val0) assert np.all(fn_e_val1 == fn_e_val0)
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_multiple_rng_aliasing(self, rng_ctor): def test_multiple_rng_aliasing(self, rng_ctor):
# Test that when we have multiple random number generators, we do not alias # Test that when we have multiple random number generators, we do not alias
# the state_updates member. `state_updates` can be useful when attempting to # the state_updates member. `state_updates` can be useful when attempting to
...@@ -257,7 +257,7 @@ class TestSharedRandomStream: ...@@ -257,7 +257,7 @@ class TestSharedRandomStream:
assert rng1.state_updates is not rng2.state_updates assert rng1.state_updates is not rng2.state_updates
assert rng1.gen_seedgen is not rng2.gen_seedgen assert rng1.gen_seedgen is not rng2.gen_seedgen
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_random_state_transfer(self, rng_ctor): def test_random_state_transfer(self, rng_ctor):
# Test that random state can be transferred from one pytensor graph to another. # Test that random state can be transferred from one pytensor graph to another.
......
...@@ -4,9 +4,7 @@ import pytest ...@@ -4,9 +4,7 @@ import pytest
from pytensor import shared from pytensor import shared
@pytest.mark.parametrize( @pytest.mark.parametrize("rng", [np.random.default_rng(123)])
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_GeneratorSharedVariable(rng): def test_GeneratorSharedVariable(rng):
s_rng_default = shared(rng) s_rng_default = shared(rng)
s_rng_True = shared(rng, borrow=True) s_rng_True = shared(rng, borrow=True)
...@@ -32,9 +30,7 @@ def test_GeneratorSharedVariable(rng): ...@@ -32,9 +30,7 @@ def test_GeneratorSharedVariable(rng):
assert v == v0 == v1 assert v == v0 == v1
@pytest.mark.parametrize( @pytest.mark.parametrize("rng", [np.random.default_rng(123)])
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_get_value_borrow(rng): def test_get_value_borrow(rng):
s_rng = shared(rng) s_rng = shared(rng)
...@@ -55,9 +51,7 @@ def test_get_value_borrow(rng): ...@@ -55,9 +51,7 @@ def test_get_value_borrow(rng):
assert r_.standard_normal() == r_F.standard_normal() assert r_.standard_normal() == r_F.standard_normal()
@pytest.mark.parametrize( @pytest.mark.parametrize("rng", [np.random.default_rng(123)])
"rng", [np.random.RandomState(123), np.random.default_rng(123)]
)
def test_get_value_internal_type(rng): def test_get_value_internal_type(rng):
s_rng = shared(rng) s_rng = shared(rng)
...@@ -81,7 +75,7 @@ def test_get_value_internal_type(rng): ...@@ -81,7 +75,7 @@ def test_get_value_internal_type(rng):
assert r_.standard_normal() == r_F.standard_normal() assert r_.standard_normal() == r_F.standard_normal()
@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) @pytest.mark.parametrize("rng_ctor", [np.random.default_rng])
def test_set_value_borrow(rng_ctor): def test_set_value_borrow(rng_ctor):
s_rng = shared(rng_ctor(123)) s_rng = shared(rng_ctor(123))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论