提交 b2365e0e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove unnecessary handling of no longer supported RandomState

上级 a2b79859
......@@ -118,7 +118,7 @@ class TestSumDiffOp(utt.InferShapeTester):
self.op_class = SumDiffOp
def test_perform(self):
rng = np.random.RandomState(43)
rng = np.random.default_rng(43)
x = matrix()
y = matrix()
f = pytensor.function([x, y], self.op_class()(x, y))
......@@ -128,7 +128,7 @@ class TestSumDiffOp(utt.InferShapeTester):
assert np.allclose([x_val + y_val, x_val - y_val], out)
def test_gradient(self):
rng = np.random.RandomState(43)
rng = np.random.default_rng(43)
def output_0(x, y):
return self.op_class()(x, y)[0]
......@@ -150,7 +150,7 @@ class TestSumDiffOp(utt.InferShapeTester):
)
def test_infer_shape(self):
rng = np.random.RandomState(43)
rng = np.random.default_rng(43)
x = dmatrix()
y = dmatrix()
......
......@@ -95,7 +95,7 @@
"noutputs = 10\n",
"nhiddens = 50\n",
"\n",
"rng = np.random.RandomState(0)\n",
"rng = np.random.default_rng(0)\n",
"x = pt.dmatrix('x')\n",
"wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)\n",
"bh = pytensor.shared(np.zeros(nhiddens), borrow=True)\n",
......
......@@ -58,7 +58,7 @@ hidden layer and a softmax output layer.
noutputs = 10
nhiddens = 50
rng = np.random.RandomState(0)
rng = np.random.default_rng(0)
x = pt.dmatrix('x')
wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
bh = pytensor.shared(np.zeros(nhiddens), borrow=True)
......
......@@ -239,7 +239,7 @@ Optimization o4 o3 o2
See :func:`insert_inplace_optimizer`
inplace_random
Typically when a graph uses random numbers, the RandomState is stored
Typically when a graph uses random numbers, the random Generator is stored
in a shared variable, used once per call and, updated after each function
call. In this common case, it makes sense to update the random number generator in-place.
......
......@@ -104,10 +104,7 @@ def detect_nan(fgraph, i, node, fn):
from pytensor.printing import debugprint
for output in fn.outputs:
if (
not isinstance(output[0], np.random.RandomState | np.random.Generator)
and np.isnan(output[0]).any()
):
if not isinstance(output[0], np.random.Generator) and np.isnan(output[0]).any():
print("*** NaN detected ***") # noqa: T201
debugprint(node)
print(f"Inputs : {[input[0] for input in fn.inputs]}") # noqa: T201
......
......@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var):
if isinstance(arr, _cdata_type):
return False
elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
elif isinstance(arr, np.random.Generator):
return False
elif var is not None and isinstance(var.type, RandomType):
return False
......
import warnings
from numpy.random import Generator, RandomState
from numpy.random import Generator
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.link.basic import JITLinker
......@@ -21,7 +21,7 @@ class JAXLinker(JITLinker):
# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
# JAX does not accept RandomState/Generators as inputs, and they will have to
# JAX does not accept Generators as inputs, and they will have to
# be tipyfied
if shared_rng_inputs:
warnings.warn(
......@@ -79,7 +79,7 @@ class JAXLinker(JITLinker):
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], RandomState | Generator):
if isinstance(sinput[0], Generator):
new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
......
......@@ -16,22 +16,4 @@ class NumbaLinker(JITLinker):
return jitted_fn
def create_thunk_inputs(self, storage_map):
from numpy.random import RandomState
from pytensor.link.numba.dispatch import numba_typify
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], RandomState):
new_value = numba_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-Numba-fied graphs will have problems.
sinput = [new_value]
thunk_inputs.append(sinput)
return thunk_inputs
return [storage_map[n] for n in self.fgraph.inputs]
from typing import TypeVar
import numpy as np
from numpy.random import Generator
import pytensor
from pytensor.graph.type import Type
T = TypeVar("T", np.random.RandomState, np.random.Generator)
T = TypeVar("T")
gen_states_keys = {
......@@ -24,14 +25,10 @@ numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
class RandomType(Type[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
@staticmethod
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
r"""A Type wrapper for `numpy.random.Generator."""
class RandomGeneratorType(RandomType[np.random.Generator]):
class RandomGeneratorType(RandomType[Generator]):
r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that
......@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
def __repr__(self):
return "RandomGeneratorType"
@staticmethod
def may_share_memory(a: Generator, b: Generator):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
def filter(self, data, strict=False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
......@@ -58,7 +59,7 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.Generator):
if isinstance(data, Generator):
return data
if not strict and isinstance(data, dict):
......
......@@ -27,8 +27,7 @@ def fetch_seed(pseed=None):
If config.unittest.rseed is set to "random", it will seed the rng with
None, which is equivalent to seeding with a random seed.
Useful for seeding RandomState or Generator objects.
>>> rng = np.random.RandomState(fetch_seed())
Useful for seeding Generator objects.
>>> rng = np.random.default_rng(fetch_seed())
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论