提交 81462c66 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added Numba Type for RandomStates

上级 97269458
......@@ -430,7 +430,9 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
# Scalars are apparently returned as actual Python scalar types and not
# NumPy scalars, so we need two separate Numba functions for each case.
if node.outputs[0].type.ndim == 0:
# The type can also be RandomType with no ndims
if not hasattr(node.outputs[0].type, "ndim") or node.outputs[0].type.ndim == 0:
# TODO: Do we really need to compile a pass-through function like this?
@numba.njit(inline="always")
def deepcopyop(x):
......
......@@ -4,7 +4,9 @@ from typing import Any, Callable, Dict, Optional
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import _helperlib
from numba import _helperlib, types
from numba.core import cgutils
from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox
from numpy.random import RandomState
import aesara.tensor.random.basic as aer
......@@ -22,12 +24,71 @@ from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.var import RandomStateSharedVariable
class RandomStateNumbaType(types.Type):
def __init__(self):
super(RandomStateNumbaType, self).__init__(name="RandomState")
random_state_numba_type = RandomStateNumbaType()
@typeof_impl.register(RandomState)
def typeof_index(val, c):
return random_state_numba_type
@register_model(RandomStateNumbaType)
class RandomStateNumbaModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
# TODO: We can add support for boxing and unboxing
# the attributes that describe a RandomState so that
# they can be accessed inside njit functions, if required.
("state_key", types.Array(types.uint32, 1, "C")),
]
models.StructModel.__init__(self, dmm, fe_type, members)
@unbox(RandomStateNumbaType)
def unbox_random_state(typ, obj, c):
"""Convert a `RandomState` object to a native `RandomStateNumbaModel` structure.
Note that this will create a 'fake' structure which will just get the
`RandomState` objects accepted in Numba functions but the actual information
of the Numba's random state is stored internally and can be accessed
anytime using ``numba._helperlib.rnd_get_np_state_ptr()``.
"""
interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(interval._getvalue(), is_error=is_error)
@box(RandomStateNumbaType)
def box_random_state(typ, val, c):
"""Convert a native `RandomStateNumbaModel` structure to an `RandomState` object
using Numba's internal state array.
Note that `RandomStateNumbaModel` is just a placeholder structure with no
inherent information about Numba internal random state, all that information
is instead retrieved from Numba using ``_helperlib.rnd_get_state()`` and a new
`RandomState` is constructed using the Numba's current internal state.
"""
pos, state_list = _helperlib.rnd_get_state(_helperlib.rnd_get_np_state_ptr())
rng = RandomState()
rng.set_state(("MT19937", state_list, pos))
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(rng))
return class_obj
@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
# The numba_typify in this case is just an passthrough function
# that synchronizes Numba's internal random state with the current
# RandomState object
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
return ints
return state
def make_numba_random_fn(node, np_random_func):
......
......@@ -27,6 +27,7 @@ from aesara.graph.optdb import OptimizationQuery
from aesara.graph.type import Type
from aesara.ifelse import ifelse
from aesara.link.numba.dispatch import basic as numba_basic
from aesara.link.numba.dispatch import numba_typify
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.scan.basic import scan
......@@ -293,6 +294,26 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected
@pytest.mark.parametrize(
"input, wrapper_fn, check_fn",
[
(
np.random.RandomState(1),
numba_typify,
lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]),
)
],
)
def test_numba_box_unbox(input, wrapper_fn, check_fn):
input = wrapper_fn(input)
pass_through = numba.njit(lambda x: x)
res = pass_through(input)
assert isinstance(res, type(input))
assert check_fn(res, input)
@pytest.mark.parametrize(
"inputs, input_vals, output_fn, exc",
[
......@@ -2925,6 +2946,17 @@ def test_RandomVariable(rv_op, dist_args, size):
)
def test_RandomState_updates():
rng = shared(np.random.RandomState(1))
rng_new = shared(np.random.RandomState(2))
x = aet.random.normal(size=10, rng=rng)
res = function([], x, updates={rng: rng_new}, mode=numba_mode)()
ref = np.random.RandomState(2).normal(size=10)
assert np.allclose(res, ref)
def test_random_Generator():
rng = shared(np.random.default_rng(29402))
g = aer.normal(rng=rng)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论