提交 cc674a11 authored 作者: emekaokoli19's avatar emekaokoli19 提交者: Ricardo Vieira

Faster RNG deepcopy

上级 370b172c
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy
from functools import singledispatch from functools import singledispatch
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent from textwrap import dedent
...@@ -32,6 +31,7 @@ from pytensor.link.utils import ( ...@@ -32,6 +31,7 @@ from pytensor.link.utils import (
) )
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.utils import custom_rng_deepcopy
from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature from pytensor.tensor.utils import _parse_gufunc_signature
...@@ -42,7 +42,7 @@ def numba_deepcopy_random_generator(x): ...@@ -42,7 +42,7 @@ def numba_deepcopy_random_generator(x):
def random_generator_deepcopy(x): def random_generator_deepcopy(x):
with numba.objmode(new_rng=types.npy_rng): with numba.objmode(new_rng=types.npy_rng):
new_rng = deepcopy(x) new_rng = custom_rng_deepcopy(x)
return new_rng return new_rng
return random_generator_deepcopy return random_generator_deepcopy
......
import abc import abc
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from copy import deepcopy
from typing import Any, cast from typing import Any, cast
import numpy as np import numpy as np
...@@ -23,6 +22,7 @@ from pytensor.tensor.blockwise import OpWithCoreShape ...@@ -23,6 +22,7 @@ from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomType from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
compute_batch_shape, compute_batch_shape,
custom_rng_deepcopy,
explicit_expand_dims, explicit_expand_dims,
normalize_size_param, normalize_size_param,
) )
...@@ -423,7 +423,7 @@ class RandomVariable(RNGConsumerOp): ...@@ -423,7 +423,7 @@ class RandomVariable(RNGConsumerOp):
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
if not self.inplace: if not self.inplace:
rng = deepcopy(rng) rng = custom_rng_deepcopy(rng)
outputs[0][0] = rng outputs[0][0] = rng
outputs[1][0] = np.asarray( outputs[1][0] = np.asarray(
......
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from copy import deepcopy
from functools import wraps from functools import wraps
from itertools import zip_longest from itertools import zip_longest
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import numpy as np import numpy as np
from numpy.random import Generator
from pytensor.compile.sharedvalue import shared from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Variable from pytensor.graph.basic import Variable
...@@ -204,6 +206,16 @@ def normalize_size_param( ...@@ -204,6 +206,16 @@ def normalize_size_param(
return shape return shape
def custom_rng_deepcopy(rng):
    """Return a copy of ``rng`` built without calling ``deepcopy`` on the generator.

    Deep-copying a ``numpy.random.Generator`` directly is slow, so the copy is
    assembled by hand: a new bit generator of the same class is created from a
    deep copy of the source seed sequence, the source state is assigned to it,
    and the result is wrapped in a new ``Generator``.

    NumPy may provide a faster clone/copy API in the future:
    https://github.com/numpy/numpy/issues/24086

    Parameters
    ----------
    rng
        Object exposing a ``bit_generator`` attribute (a
        ``numpy.random.Generator``).

    Returns
    -------
    numpy.random.Generator
        A new generator whose bit generator has the same state as ``rng``'s.
    """
    source_bitgen = rng.bit_generator
    bitgen_cls = type(source_bitgen)

    # Only the (small) seed sequence goes through `deepcopy`.
    seed_seq_copy = deepcopy(source_bitgen._seed_seq)
    cloned_bitgen = bitgen_cls(seed_seq_copy)

    # Overwrite the freshly seeded state with the source's current state.
    cloned_bitgen.state = source_bitgen.state

    return Generator(cloned_bitgen)
class RandomStream: class RandomStream:
"""Module component with similar interface to `numpy.random.Generator`. """Module component with similar interface to `numpy.random.Generator`.
......
from copy import deepcopy
import numpy as np import numpy as np
import pytest import pytest
...@@ -7,6 +9,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -7,6 +9,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
RandomStream, RandomStream,
broadcast_params, broadcast_params,
custom_rng_deepcopy,
normalize_size_param, normalize_size_param,
supp_shape_from_ref_param_shape, supp_shape_from_ref_param_shape,
) )
...@@ -348,3 +351,28 @@ def test_normalize_size_param(): ...@@ -348,3 +351,28 @@ def test_normalize_size_param():
sym_tensor_size = tensor(shape=(3,), dtype="int64") sym_tensor_size = tensor(shape=(3,), dtype="int64")
assert normalize_size_param(sym_tensor_size) is sym_tensor_size assert normalize_size_param(sym_tensor_size) is sym_tensor_size
def test_custom_rng_deepcopy_matches_deepcopy():
    """`custom_rng_deepcopy` yields the same bit-generator data as `deepcopy`."""
    source_rng = np.random.default_rng(123)

    reference_bitgen = deepcopy(source_rng).bit_generator
    fast_bitgen = custom_rng_deepcopy(source_rng).bit_generator

    # The bit-generator state must be equal between the two copies
    assert reference_bitgen.state == fast_bitgen.state

    # The seed-sequence state must be equal as well
    assert reference_bitgen.seed_seq.state == fast_bitgen.seed_seq.state
def test_custom_rng_deepcopy_output_identical():
    """Copies made by `deepcopy` and `custom_rng_deepcopy` produce matching draws."""
    source_rng = np.random.default_rng(123)

    reference_copy = deepcopy(source_rng)
    fast_copy = custom_rng_deepcopy(source_rng)

    # Draw the same number of normal samples from each copy
    reference_draws = reference_copy.normal(size=10)
    fast_draws = fast_copy.normal(size=10)

    assert np.allclose(reference_draws, fast_draws)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论