提交 cc674a11 authored 作者: emekaokoli19's avatar emekaokoli19 提交者: Ricardo Vieira

Faster RNG deepcopy

上级 370b172c
from collections.abc import Callable
from copy import deepcopy
from functools import singledispatch
from hashlib import sha256
from textwrap import dedent
......@@ -32,6 +31,7 @@ from pytensor.link.utils import (
)
from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.utils import custom_rng_deepcopy
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature
......@@ -42,7 +42,7 @@ def numba_deepcopy_random_generator(x):
def random_generator_deepcopy(x):
with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(x)
new_rng = custom_rng_deepcopy(x)
return new_rng
return random_generator_deepcopy
......
import abc
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, cast
import numpy as np
......@@ -23,6 +22,7 @@ from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
custom_rng_deepcopy,
explicit_expand_dims,
normalize_size_param,
)
......@@ -423,7 +423,7 @@ class RandomVariable(RNGConsumerOp):
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
if not self.inplace:
rng = deepcopy(rng)
rng = custom_rng_deepcopy(rng)
outputs[0][0] = rng
outputs[1][0] = np.asarray(
......
from collections.abc import Callable, Sequence
from copy import deepcopy
from functools import wraps
from itertools import zip_longest
from types import ModuleType
from typing import TYPE_CHECKING
import numpy as np
from numpy.random import Generator
from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Variable
......@@ -204,6 +206,16 @@ def normalize_size_param(
return shape
def custom_rng_deepcopy(rng):
# This helper exists because copying numpy.random.Generator via deepcopy is slow.
# NumPy may implement a faster clone/copy API in the future:
# https://github.com/numpy/numpy/issues/24086
old_bitgen = rng.bit_generator
new_bitgen = type(old_bitgen)(deepcopy(old_bitgen._seed_seq))
new_bitgen.state = old_bitgen.state
return Generator(new_bitgen)
class RandomStream:
"""Module component with similar interface to `numpy.random.Generator`.
......
from copy import deepcopy
import numpy as np
import pytest
......@@ -7,6 +9,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.utils import (
RandomStream,
broadcast_params,
custom_rng_deepcopy,
normalize_size_param,
supp_shape_from_ref_param_shape,
)
......@@ -348,3 +351,28 @@ def test_normalize_size_param():
sym_tensor_size = tensor(shape=(3,), dtype="int64")
assert normalize_size_param(sym_tensor_size) is sym_tensor_size
def test_custom_rng_deepcopy_matches_deepcopy():
rng = np.random.default_rng(123)
dp = deepcopy(rng).bit_generator
fc = custom_rng_deepcopy(rng).bit_generator
# Same state
assert dp.state == fc.state
# Same seed sequence
assert dp.seed_seq.state == fc.seed_seq.state
def test_custom_rng_deepcopy_output_identical():
rng = np.random.default_rng(123)
rng1 = deepcopy(rng)
rng2 = custom_rng_deepcopy(rng)
# Generate numbers from each
x1 = rng1.normal(size=10)
x2 = rng2.normal(size=10)
assert np.allclose(x1, x2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论