提交 b2365e0e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove unnecessary handling of no longer supported RandomState

上级 a2b79859
...@@ -118,7 +118,7 @@ class TestSumDiffOp(utt.InferShapeTester): ...@@ -118,7 +118,7 @@ class TestSumDiffOp(utt.InferShapeTester):
self.op_class = SumDiffOp self.op_class = SumDiffOp
def test_perform(self): def test_perform(self):
rng = np.random.RandomState(43) rng = np.random.default_rng(43)
x = matrix() x = matrix()
y = matrix() y = matrix()
f = pytensor.function([x, y], self.op_class()(x, y)) f = pytensor.function([x, y], self.op_class()(x, y))
...@@ -128,7 +128,7 @@ class TestSumDiffOp(utt.InferShapeTester): ...@@ -128,7 +128,7 @@ class TestSumDiffOp(utt.InferShapeTester):
assert np.allclose([x_val + y_val, x_val - y_val], out) assert np.allclose([x_val + y_val, x_val - y_val], out)
def test_gradient(self): def test_gradient(self):
rng = np.random.RandomState(43) rng = np.random.default_rng(43)
def output_0(x, y): def output_0(x, y):
return self.op_class()(x, y)[0] return self.op_class()(x, y)[0]
...@@ -150,7 +150,7 @@ class TestSumDiffOp(utt.InferShapeTester): ...@@ -150,7 +150,7 @@ class TestSumDiffOp(utt.InferShapeTester):
) )
def test_infer_shape(self): def test_infer_shape(self):
rng = np.random.RandomState(43) rng = np.random.default_rng(43)
x = dmatrix() x = dmatrix()
y = dmatrix() y = dmatrix()
......
...@@ -95,7 +95,7 @@ ...@@ -95,7 +95,7 @@
"noutputs = 10\n", "noutputs = 10\n",
"nhiddens = 50\n", "nhiddens = 50\n",
"\n", "\n",
"rng = np.random.RandomState(0)\n", "rng = np.random.default_rng(0)\n",
"x = pt.dmatrix('x')\n", "x = pt.dmatrix('x')\n",
"wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)\n", "wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)\n",
"bh = pytensor.shared(np.zeros(nhiddens), borrow=True)\n", "bh = pytensor.shared(np.zeros(nhiddens), borrow=True)\n",
......
...@@ -58,7 +58,7 @@ hidden layer and a softmax output layer. ...@@ -58,7 +58,7 @@ hidden layer and a softmax output layer.
noutputs = 10 noutputs = 10
nhiddens = 50 nhiddens = 50
rng = np.random.RandomState(0) rng = np.random.default_rng(0)
x = pt.dmatrix('x') x = pt.dmatrix('x')
wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True) wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
bh = pytensor.shared(np.zeros(nhiddens), borrow=True) bh = pytensor.shared(np.zeros(nhiddens), borrow=True)
......
...@@ -239,7 +239,7 @@ Optimization o4 o3 o2 ...@@ -239,7 +239,7 @@ Optimization o4 o3 o2
See :func:`insert_inplace_optimizer` See :func:`insert_inplace_optimizer`
inplace_random inplace_random
Typically when a graph uses random numbers, the RandomState is stored Typically when a graph uses random numbers, the random Generator is stored
in a shared variable, used once per call and, updated after each function in a shared variable, used once per call and, updated after each function
call. In this common case, it makes sense to update the random number generator in-place. call. In this common case, it makes sense to update the random number generator in-place.
......
...@@ -104,10 +104,7 @@ def detect_nan(fgraph, i, node, fn): ...@@ -104,10 +104,7 @@ def detect_nan(fgraph, i, node, fn):
from pytensor.printing import debugprint from pytensor.printing import debugprint
for output in fn.outputs: for output in fn.outputs:
if ( if not isinstance(output[0], np.random.Generator) and np.isnan(output[0]).any():
not isinstance(output[0], np.random.RandomState | np.random.Generator)
and np.isnan(output[0]).any()
):
print("*** NaN detected ***") # noqa: T201 print("*** NaN detected ***") # noqa: T201
debugprint(node) debugprint(node)
print(f"Inputs : {[input[0] for input in fn.inputs]}") # noqa: T201 print(f"Inputs : {[input[0] for input in fn.inputs]}") # noqa: T201
......
...@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var): ...@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var):
if isinstance(arr, _cdata_type): if isinstance(arr, _cdata_type):
return False return False
elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator): elif isinstance(arr, np.random.Generator):
return False return False
elif var is not None and isinstance(var.type, RandomType): elif var is not None and isinstance(var.type, RandomType):
return False return False
......
import warnings import warnings
from numpy.random import Generator, RandomState from numpy.random import Generator
from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.link.basic import JITLinker from pytensor.link.basic import JITLinker
...@@ -21,7 +21,7 @@ class JAXLinker(JITLinker): ...@@ -21,7 +21,7 @@ class JAXLinker(JITLinker):
# Replace any shared RNG inputs so that their values can be updated in place # Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because # without affecting the original RNG container. This is necessary because
# JAX does not accept RandomState/Generators as inputs, and they will have to # JAX does not accept Generators as inputs, and they will have to
# be typified # be typified
if shared_rng_inputs: if shared_rng_inputs:
warnings.warn( warnings.warn(
...@@ -79,7 +79,7 @@ class JAXLinker(JITLinker): ...@@ -79,7 +79,7 @@ class JAXLinker(JITLinker):
thunk_inputs = [] thunk_inputs = []
for n in self.fgraph.inputs: for n in self.fgraph.inputs:
sinput = storage_map[n] sinput = storage_map[n]
if isinstance(sinput[0], RandomState | Generator): if isinstance(sinput[0], Generator):
new_value = jax_typify( new_value = jax_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None) sinput[0], dtype=getattr(sinput[0], "dtype", None)
) )
......
...@@ -16,22 +16,4 @@ class NumbaLinker(JITLinker): ...@@ -16,22 +16,4 @@ class NumbaLinker(JITLinker):
return jitted_fn return jitted_fn
def create_thunk_inputs(self, storage_map): def create_thunk_inputs(self, storage_map):
from numpy.random import RandomState return [storage_map[n] for n in self.fgraph.inputs]
from pytensor.link.numba.dispatch import numba_typify
thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], RandomState):
new_value = numba_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
# We need to remove the reference-based connection to the
# original `RandomState`/shared variable's storage, because
# subsequent attempts to use the same shared variable within
# other non-Numba-fied graphs will have problems.
sinput = [new_value]
thunk_inputs.append(sinput)
return thunk_inputs
from typing import TypeVar from typing import TypeVar
import numpy as np import numpy as np
from numpy.random import Generator
import pytensor import pytensor
from pytensor.graph.type import Type from pytensor.graph.type import Type
T = TypeVar("T", np.random.RandomState, np.random.Generator) T = TypeVar("T")
gen_states_keys = { gen_states_keys = {
...@@ -24,14 +25,10 @@ numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"} ...@@ -24,14 +25,10 @@ numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
class RandomType(Type[T]): class RandomType(Type[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`.""" r"""A Type wrapper for `numpy.random.Generator`."""
@staticmethod
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
class RandomGeneratorType(RandomType[np.random.Generator]): class RandomGeneratorType(RandomType[Generator]):
r"""A Type wrapper for `numpy.random.Generator`. r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that The reason this exists (and `Generic` doesn't suffice) is that
...@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[np.random.Generator]): ...@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
def __repr__(self): def __repr__(self):
return "RandomGeneratorType" return "RandomGeneratorType"
@staticmethod
def may_share_memory(a: Generator, b: Generator):
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
""" """
XXX: This doesn't convert `data` to the same type of underlying RNG type XXX: This doesn't convert `data` to the same type of underlying RNG type
...@@ -58,7 +59,7 @@ class RandomGeneratorType(RandomType[np.random.Generator]): ...@@ -58,7 +59,7 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
`Type.filter`, we need to have it here to avoid surprising circular `Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes. dependencies in sub-classes.
""" """
if isinstance(data, np.random.Generator): if isinstance(data, Generator):
return data return data
if not strict and isinstance(data, dict): if not strict and isinstance(data, dict):
......
...@@ -27,8 +27,7 @@ def fetch_seed(pseed=None): ...@@ -27,8 +27,7 @@ def fetch_seed(pseed=None):
If config.unittest.rseed is set to "random", it will seed the rng with If config.unittest.rseed is set to "random", it will seed the rng with
None, which is equivalent to seeding with a random seed. None, which is equivalent to seeding with a random seed.
Useful for seeding RandomState or Generator objects. Useful for seeding Generator objects.
>>> rng = np.random.RandomState(fetch_seed())
>>> rng = np.random.default_rng(fetch_seed()) >>> rng = np.random.default_rng(fetch_seed())
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论