Commit 14da898c, authored by Ricardo Vieira, committed by Ricardo Vieira

Add support for RandomVariable with Generators in Numba backend and drop support for RandomState

Parent 47874eb9
......@@ -27,7 +27,6 @@ from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature
def infer_shape(outs, inputs, input_shapes):
......@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually
# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature
for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim
......@@ -307,6 +310,7 @@ class OpFromGraph(Op, HasInnerGraph):
connection_pattern: list[list[bool]] | None = None,
strict: bool = False,
name: str | None = None,
destroy_map: dict[int, tuple[int, ...]] | None = None,
**kwargs,
):
"""
......@@ -464,6 +468,7 @@ class OpFromGraph(Op, HasInnerGraph):
if name is not None:
assert isinstance(name, str), "name must be None or string object"
self.name = name
self.destroy_map = destroy_map if destroy_map is not None else {}
def __eq__(self, other):
# TODO: recognize a copy
......@@ -862,6 +867,7 @@ class OpFromGraph(Op, HasInnerGraph):
rop_overrides=self.rop_overrides,
connection_pattern=self._connection_pattern,
name=self.name,
destroy_map=self.destroy_map,
**self.kwargs,
)
new_inputs = (
......
......@@ -463,7 +463,7 @@ JAX = Mode(
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
......
......@@ -18,6 +18,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload
from pytensor import config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply
......@@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1:
......
from collections.abc import Callable
from textwrap import dedent, indent
from typing import Any
from copy import copy
from functools import singledispatch
from textwrap import dedent
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import _helperlib, types
from numba.core import cgutils
from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox
from numpy.random import RandomState
from numba import types
from numba.core.extending import overload
import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Apply
from pytensor.graph import Apply
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
encode_literals,
store_core_outputs,
)
from pytensor.link.utils import (
compile_function_src,
get_name_for_object,
unique_name_generator,
)
from pytensor.tensor.basic import get_vector_length
from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT
class RandomStateNumbaType(types.Type):
    """Numba type standing in for a NumPy `RandomState` object.

    The type carries no payload of its own; it only gives Numba a name to
    dispatch on (see the model/unbox/box registrations below).
    """

    def __init__(self):
        super().__init__(name="RandomState")


# Singleton instance used by the `typeof_impl` registration below.
random_state_numba_type = RandomStateNumbaType()
@typeof_impl.register(RandomState)
def typeof_index(val, c):
    # Every `numpy.random.RandomState` instance maps to the single
    # `RandomStateNumbaType` instance defined above.
    return random_state_numba_type
@register_model(RandomStateNumbaType)
class RandomStateNumbaModel(models.StructModel):
    """Placeholder data model for `RandomStateNumbaType` values.

    The struct does not mirror the actual RNG state; Numba keeps its own
    internal random state (see the unbox/box functions in this module).
    """

    def __init__(self, dmm, fe_type):
        members = [
            # TODO: We can add support for boxing and unboxing
            # the attributes that describe a RandomState so that
            # they can be accessed inside njit functions, if required.
            ("state_key", types.Array(types.uint32, 1, "C")),
        ]
        models.StructModel.__init__(self, dmm, fe_type, members)
@unbox(RandomStateNumbaType)
def unbox_random_state(typ, obj, c):
"""Convert a `RandomState` object to a native `RandomStateNumbaModel` structure.
Note that this will create a 'fake' structure which will just get the
`RandomState` objects accepted in Numba functions but the actual information
of the Numba's random state is stored internally and can be accessed
anytime using ``numba._helperlib.rnd_get_np_state_ptr()``.
from pytensor.tensor.utils import _parse_gufunc_signature
@overload(copy)
def copy_NumPyRandomGenerator(rng):
    # Numba overload of `copy.copy` for NumPy Generator objects: Numba has no
    # native implementation, so fall back to object mode for the copy.
    def impl(rng):
        # TODO: Open issue on Numba?
        with numba.objmode(new_rng=types.npy_rng):
            new_rng = copy(rng)
        return new_rng
    return impl
@singledispatch
def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
    """Return the core (unbatched) sampling function for a random variable operation.

    Per-`Op` implementations are registered via ``singledispatch``; this
    fallback raises for operations without a Numba core implementation.
    """
    raise NotImplementedError(f"Core implementation of {op} not implemented.")
@numba_core_rv_funcify.register(ptr.UniformRV)
@numba_core_rv_funcify.register(ptr.TriangularRV)
@numba_core_rv_funcify.register(ptr.BetaRV)
@numba_core_rv_funcify.register(ptr.NormalRV)
@numba_core_rv_funcify.register(ptr.LogNormalRV)
@numba_core_rv_funcify.register(ptr.GammaRV)
@numba_core_rv_funcify.register(ptr.ExponentialRV)
@numba_core_rv_funcify.register(ptr.WeibullRV)
@numba_core_rv_funcify.register(ptr.LogisticRV)
@numba_core_rv_funcify.register(ptr.VonMisesRV)
@numba_core_rv_funcify.register(ptr.PoissonRV)
@numba_core_rv_funcify.register(ptr.GeometricRV)
# @numba_core_rv_funcify.register(ptr.HyperGeometricRV) # Not implemented in numba
@numba_core_rv_funcify.register(ptr.WaldRV)
@numba_core_rv_funcify.register(ptr.LaplaceRV)
@numba_core_rv_funcify.register(ptr.BinomialRV)
@numba_core_rv_funcify.register(ptr.NegBinomialRV)
@numba_core_rv_funcify.register(ptr.MultinomialRV)
@numba_core_rv_funcify.register(ptr.PermutationRV)
@numba_core_rv_funcify.register(ptr.IntegersRV)
def numba_core_rv_default(op, node):
"""Create a default RV core numba function.
@njit
def random(rng, i0, i1, ..., in):
return rng.name(i0, i1, ..., in)
"""
interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(interval._getvalue(), is_error=is_error)
@box(RandomStateNumbaType)
def box_random_state(typ, val, c):
    """Convert a native `RandomStateNumbaModel` structure to an `RandomState` object
    using Numba's internal state array.

    Note that `RandomStateNumbaModel` is just a placeholder structure with no
    inherent information about Numba internal random state, all that information
    is instead retrieved from Numba using ``_helperlib.rnd_get_state()`` and a new
    `RandomState` is constructed using the Numba's current internal state.
    """
    # Read the Mersenne-Twister key and position out of Numba's internal buffer.
    pos, state_list = _helperlib.rnd_get_state(_helperlib.rnd_get_np_state_ptr())
    rng = RandomState()
    rng.set_state(("MT19937", state_list, pos))
    # Serialize the freshly built RandomState into a Python object the caller owns.
    class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(rng))
    return class_obj
name = op.name
@numba_typify.register(RandomState)
def numba_typify_RandomState(state, **kwargs):
    """Synchronize Numba's internal RNG with `state` and return it unchanged.

    This typify is effectively a passthrough: its only purpose is the side
    effect of copying the MT19937 key and position of the given `RandomState`
    into Numba's internal random-state buffer.
    """
    _, key, position = state.get_state()[:3]
    state_ptr = _helperlib.rnd_get_np_state_ptr()
    _helperlib.rnd_set_state(state_ptr, (position, [int(k) for k in key]))
    return state
inputs = [f"i{i}" for i in range(len(op.ndims_params))]
input_signature = ",".join(inputs)
func_src = dedent(f"""
def {name}(rng, {input_signature}):
return rng.{name}({input_signature})
""")
def make_numba_random_fn(node, np_random_func):
"""Create Numba implementations for existing Numba-supported ``np.random`` functions.
func = compile_function_src(func_src, name, {**globals()})
return numba_basic.numba_njit(func)
The functions generated here add parameter broadcasting and the ``size``
argument to the Numba-supported scalar ``np.random`` functions.
"""
op: ptr.RandomVariable = node.op
rng_param = op.rng_param(node)
if not isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `Generator`s")
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
dist_params = op.dist_params(node)
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name = f"pytensor_random_{get_name_for_object(np_random_func)}"
sized_fn_name = "sized_random_variable"
unique_names = unique_name_generator(
[
bcast_fn_name,
sized_fn_name,
"np",
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"size_len",
"size_dims",
"rng",
"size",
],
suffix_sep="_",
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in dist_params]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba_basic.numba_vectorize,
}
bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(
bcast_fn_src, bcast_fn_name, {**globals(), **bcast_fn_global_env}
)
random_fn_input_names = ", ".join(
["rng", "size"] + [unique_names(i) for i in dist_params]
)
# Now, create a Numba JITable function that implements the `size` parameter
@numba_core_rv_funcify.register(ptr.BernoulliRV)
def numba_core_BernoulliRV(op, node):
out_dtype = node.outputs[1].type.numpy_dtype
random_fn_global_env = {
bcast_fn_name: bcast_fn,
"out_dtype": out_dtype,
}
if size_len is not None:
size_dims = size_len - max(i.ndim for i in dist_params)
random_fn_body = dedent(
f"""
size = to_fixed_tuple(size, size_len)
@numba_basic.numba_njit()
def random(rng, p):
return (
direct_cast(0, out_dtype)
if p < rng.uniform()
else direct_cast(1, out_dtype)
)
data = np.empty(size, dtype=out_dtype)
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
return random
"""
)
random_fn_global_env.update(
{
"np": np,
"to_fixed_tuple": numba_ndarray.to_fixed_tuple,
"size_len": size_len,
"size_dims": size_dims,
}
)
else:
random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src = dedent(
f"""
def {sized_fn_name}({random_fn_input_names}):
{indent(random_fn_body, " " * 4)}
return (rng, data)
"""
)
random_fn = compile_function_src(
sized_fn_src, sized_fn_name, {**globals(), **random_fn_global_env}
)
random_fn = numba_basic.numba_njit(random_fn)
@numba_core_rv_funcify.register(ptr.HalfNormalRV)
def numba_core_HalfNormalRV(op, node):
    """Core sampler for the half-normal distribution: ``loc + scale * |N(0, 1)|``."""

    @numba_basic.numba_njit
    def random_fn(rng, loc, scale):
        draw = np.abs(rng.standard_normal())
        return loc + scale * draw

    return random_fn
@numba_funcify.register(ptr.UniformRV)
@numba_funcify.register(ptr.TriangularRV)
@numba_funcify.register(ptr.BetaRV)
@numba_funcify.register(ptr.NormalRV)
@numba_funcify.register(ptr.LogNormalRV)
@numba_funcify.register(ptr.GammaRV)
@numba_funcify.register(ptr.ParetoRV)
@numba_funcify.register(ptr.GumbelRV)
@numba_funcify.register(ptr.ExponentialRV)
@numba_funcify.register(ptr.WeibullRV)
@numba_funcify.register(ptr.LogisticRV)
@numba_funcify.register(ptr.VonMisesRV)
@numba_funcify.register(ptr.PoissonRV)
@numba_funcify.register(ptr.GeometricRV)
@numba_funcify.register(ptr.HyperGeometricRV)
@numba_funcify.register(ptr.WaldRV)
@numba_funcify.register(ptr.LaplaceRV)
@numba_funcify.register(ptr.BinomialRV)
@numba_funcify.register(ptr.MultinomialRV)
@numba_funcify.register(ptr.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
return make_numba_random_fn(node, np_random_func)
@numba_core_rv_funcify.register(ptr.CauchyRV)
def numba_core_CauchyRV(op, node):
    """Core sampler for the Cauchy distribution.

    A location-scale Cauchy draw is ``loc + scale * z`` where ``z`` is a
    standard Cauchy variate (the NumPy/SciPy parameterization).  The previous
    expression ``(loc + z) / scale`` divided by the scale instead of
    multiplying, which actually samples Cauchy(loc / scale, 1 / scale).
    """

    @numba_basic.numba_njit
    def random(rng, loc, scale):
        return loc + scale * rng.standard_cauchy()

    return random
def create_numba_random_fn(
op: Op,
node: Apply,
scalar_fn: Callable[[str], str],
global_env: dict[str, Any] | None = None,
) -> Callable:
"""Create a vectorized function from a callable that generates the ``str`` function body.
TODO: This could/should be generalized for other simple function
construction cases that need unique-ified symbol names.
"""
np_random_fn_name = f"pytensor_random_{get_name_for_object(op.name)}"
@numba_core_rv_funcify.register(ptr.ParetoRV)
def numba_core_ParetoRV(op, node):
@numba_basic.numba_njit
def random(rng, b, scale):
# Follows scipy implementation
U = rng.random()
return np.power(1 - U, -1 / b) * scale
if global_env:
np_global_env = global_env.copy()
else:
np_global_env = {}
return random
np_global_env["np"] = np
np_global_env["numba_vectorize"] = numba_basic.numba_vectorize
unique_names = unique_name_generator(
[np_random_fn_name, *np_global_env.keys(), "rng", "size"],
suffix_sep="_",
)
@numba_core_rv_funcify.register(ptr.CategoricalRV)
def core_CategoricalRV(op, node):
@numba_basic.numba_njit
def random_fn(rng, p):
unif_sample = rng.uniform(0, 1)
return np.searchsorted(np.cumsum(p), unif_sample)
dist_params = op.dist_params(node)
np_names = [unique_names(i, force_unique=True) for i in dist_params]
np_input_names = ", ".join(np_names)
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
{scalar_fn(*np_names)}
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, {**globals(), **np_global_env}
)
return random_fn
return make_numba_random_fn(node, np_random_fn)
@numba_core_rv_funcify.register(ptr.MvNormalRV)
def core_MvNormalRV(op, node):
@numba_basic.numba_njit
def random_fn(rng, mean, cov):
chol = np.linalg.cholesky(cov)
stdnorm = rng.normal(size=cov.shape[-1])
return np.dot(chol, stdnorm) + mean
@numba_funcify.register(ptr.NegBinomialRV)
def numba_funcify_NegBinomialRV(op, node, **kwargs):
    # Explicit mapping to `np.random.negative_binomial` — presumably the op's
    # name does not match the NumPy attribute used by the generic
    # name-based dispatch above (TODO confirm).
    return make_numba_random_fn(node, np.random.negative_binomial)
random_fn.handles_out = True
return random_fn
@numba_funcify.register(ptr.CauchyRV)
def numba_funcify_CauchyRV(op, node, **kwargs):
def body_fn(loc, scale):
return f" return ({loc} + np.random.standard_cauchy()) / {scale}"
@numba_core_rv_funcify.register(ptr.DirichletRV)
def core_DirichletRV(op, node):
@numba_basic.numba_njit
def random_fn(rng, alpha):
y = np.empty_like(alpha)
for i in range(len(alpha)):
y[i] = rng.gamma(alpha[i], 1.0)
return y / y.sum()
return create_numba_random_fn(op, node, body_fn)
return random_fn
@numba_funcify.register(ptr.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
def body_fn(a, b):
return f" return {a} + {b} * abs(np.random.normal(0, 1))"
@numba_core_rv_funcify.register(ptr.GumbelRV)
def core_GumbelRV(op, node):
"""Code adapted from Numpy Implementation
return create_numba_random_fn(op, node, body_fn)
https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L502-L511
"""
@numba_basic.numba_njit
def random_fn(rng, loc, scale):
U = 1.0 - rng.random()
if U < 1.0:
return loc - scale * np.log(-np.log(U))
else:
return random_fn(rng, loc, scale)
@numba_funcify.register(ptr.BernoulliRV)
def numba_funcify_BernoulliRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
return random_fn
def body_fn(a):
return f"""
if {a} < np.random.uniform(0, 1):
return direct_cast(0, out_dtype)
else:
return direct_cast(1, out_dtype)
"""
return create_numba_random_fn(
op,
node,
body_fn,
{"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast},
)
@numba_core_rv_funcify.register(ptr.VonMisesRV)
def core_VonMisesRV(op, node):
"""Code adapted from Numpy Implementation
@numba_funcify.register(ptr.CategoricalRV)
def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
p_ndim = node.inputs[-1].ndim
https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L855-L925
"""
@numba_basic.numba_njit
def categorical_rv(rng, size, p):
if size_len is None:
size_tpl = p.shape[:-1]
else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
p = np.broadcast_to(p, size_tpl + p.shape[-1:])
# Workaround https://github.com/numba/numba/issues/8975
if size_len is None and p_ndim == 1:
unif_samples = np.asarray(np.random.uniform(0, 1))
def random_fn(rng, mu, kappa):
if np.isnan(kappa):
return np.nan
if kappa < 1e-8:
# Use a uniform for very small values of kappa
return np.pi * (2 * rng.random() - 1)
else:
unif_samples = np.random.uniform(0, 1, size_tpl)
res = np.empty(size_tpl, dtype=out_dtype)
for idx in np.ndindex(*size_tpl):
res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx])
return (rng, res)
# with double precision rho is zero until 1.4e-8
if kappa < 1e-5:
# second order taylor expansion around kappa = 0
# precise until relatively large kappas as second order is 0
s = 1.0 / kappa + kappa
else:
if kappa <= 1e6:
# Path for 1e-5 <= kappa <= 1e6
r = 1 + np.sqrt(1 + 4 * kappa * kappa)
rho = (r - np.sqrt(2 * r)) / (2 * kappa)
s = (1 + rho * rho) / (2 * rho)
else:
# Fallback to wrapped normal distribution for kappa > 1e6
result = mu + np.sqrt(1.0 / kappa) * rng.standard_normal()
# Ensure result is within bounds
if result < -np.pi:
result += 2 * np.pi
if result > np.pi:
result -= 2 * np.pi
return result
while True:
U = rng.random()
Z = np.cos(np.pi * U)
W = (1 + s * Z) / (s + Z)
Y = kappa * (s - W)
V = rng.random()
# V == 0.0 is ok here since Y >= 0 always leads
# to accept, while Y < 0 always rejects
if (Y * (2 - Y) - V >= 0) or (np.log(Y / V) + 1 - Y >= 0):
break
U = rng.random()
result = np.arccos(W)
if U < 0.5:
result = -result
result += mu
neg = result < 0
mod = np.abs(result)
mod = np.mod(mod + np.pi, 2 * np.pi) - np.pi
if neg:
mod *= -1
return mod
return categorical_rv
return random_fn
@numba_funcify.register(ptr.DirichletRV)
def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
alphas_ndim = op.dist_params(node)[0].type.ndim
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
@numba_core_rv_funcify.register(ptr.ChoiceWithoutReplacement)
def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node):
[core_shape_len_sig] = _parse_gufunc_signature(op.signature)[0][-1]
core_shape_len = int(core_shape_len_sig)
implicit_arange = op.ndims_params[0] == 0
if alphas_ndim > 1:
if op.has_p_param:
@numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas):
if size_len is None:
samples_shape = alphas.shape
def random_fn(rng, a, p, core_shape):
# Adapted from Numpy: https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L922-L941
size = np.prod(core_shape)
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
if implicit_arange:
pop_size = a
else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
samples_shape = size_tpl + alphas.shape[-1:]
res = np.empty(samples_shape, dtype=out_dtype)
alphas_bcast = np.broadcast_to(alphas, samples_shape)
for index in np.ndindex(*samples_shape[:-1]):
res[index] = np.random.dirichlet(alphas_bcast[index])
return (rng, res)
pop_size = a.shape[0]
if size > pop_size:
raise ValueError(
"Cannot take a larger sample than population without replacement"
)
if np.count_nonzero(p > 0) < size:
raise ValueError("Fewer non-zero entries in p than size")
p = p.copy()
n_uniq = 0
idx = np.zeros(core_shape, dtype=np.int64)
flat_idx = idx.ravel()
while n_uniq < size:
x = rng.random((size - n_uniq,))
# Set the probabilities of items that have already been found to 0
p[flat_idx[:n_uniq]] = 0
# Take new (unique) categorical draws from the remaining probabilities
cdf = np.cumsum(p)
cdf /= cdf[-1]
new = np.searchsorted(cdf, x, side="right")
# Numba doesn't support return_index in np.unique
# _, unique_indices = np.unique(new, return_index=True)
# unique_indices.sort()
new.sort()
unique_indices = [
idx
for idx, prev_item in enumerate(new[:-1], 1)
if new[idx] != prev_item
]
unique_indices = np.array([0] + unique_indices) # noqa: RUF005
new = new[unique_indices]
flat_idx[n_uniq : n_uniq + new.size] = new
n_uniq += new.size
if implicit_arange:
return idx
else:
# Numba doesn't support advanced indexing, so we ravel index and reshape
return a[idx.ravel()].reshape(core_shape + a.shape[1:])
else:
@numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas):
if size_len is not None:
size = numba_ndarray.to_fixed_tuple(size, size_len)
return (rng, np.random.dirichlet(alphas, size))
return dirichlet_rv
@numba_funcify.register(ptr.ChoiceWithoutReplacement)
def numba_funcify_choice_without_replacement(op, node, **kwargs):
batch_ndim = op.batch_ndim(node)
if batch_ndim:
# The code isn't too hard to write, but Numba doesn't support a with ndim > 1,
# and I don't want to change the batched tests for this
# We'll just raise an error for now
raise NotImplementedError(
"ChoiceWithoutReplacement with batch_ndim not supported in Numba backend"
)
def random_fn(rng, a, core_shape):
# Until Numba supports generator.choice we use a poor implementation
# that permutates the whole arange array and takes the first `size` elements
# This is widely inefficient when size << a.shape[0]
size = np.prod(core_shape)
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
idx = rng.permutation(size)[:size]
[core_shape_len] = node.inputs[-1].type.shape
# Numba doesn't support advanced indexing so index on the flat dimension and reshape
# idx = idx.reshape(core_shape)
# if implicit_arange:
# return idx
# else:
# return a[idx]
if op.has_p_param:
if implicit_arange:
return idx.reshape(core_shape)
else:
return a[idx].reshape(core_shape + a.shape[1:])
@numba_basic.numba_njit
def choice_without_replacement_rv(rng, size, a, p, core_shape):
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
samples = np.random.choice(a, size=core_shape, replace=False, p=p)
return (rng, samples)
else:
return random_fn
@numba_basic.numba_njit
def choice_without_replacement_rv(rng, size, a, core_shape):
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
samples = np.random.choice(a, size=core_shape, replace=False)
return (rng, samples)
return choice_without_replacement_rv
@numba_funcify.register
def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
    # Bare `RandomVariable` nodes should have been wrapped into
    # `RandomVariableWithCoreShape` by the `introduce_explicit_core_shape_rv`
    # rewrite before funcification; reaching this dispatch entry means the
    # rewrite did not run (e.g. a non-default rewrite set was used).
    raise RuntimeError(
        "It is necessary to replace RandomVariable with RandomVariableWithCoreShape. "
        "This is done by the default rewrites during compilation."
    )
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
size_is_none = isinstance(op.size_param(node).type, NoneTypeT)
batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
@numba_funcify.register
def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
core_shape = node.inputs[0]
@numba_basic.numba_njit
def permutation_rv(rng, size, x):
if batch_ndim:
x_core_shape = x.shape[x_batch_ndim:]
if size_is_none:
size = x.shape[:batch_ndim]
else:
size = numba_ndarray.to_fixed_tuple(size, batch_ndim)
x = np.broadcast_to(x, size + x_core_shape)
[rv_node] = op.fgraph.apply_nodes
rv_op: RandomVariable = rv_node.op
rng_param = rv_op.rng_param(rv_node)
if isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `RandomStateType`s")
size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
core_shape_len = get_vector_length(core_shape)
inplace = rv_op.inplace
samples = np.empty(size + x_core_shape, dtype=x.dtype)
for index in np.ndindex(size):
samples[index] = np.random.permutation(x[index])
core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
nin = 1 + len(dist_params) # rng + params
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
else:
samples = np.random.permutation(x)
batch_ndim = rv_op.batch_ndim(rv_node)
# numba doesn't support nested literals right now...
input_bc_patterns = encode_literals(
tuple(input_var.type.broadcastable[:batch_ndim] for input_var in dist_params)
)
output_bc_patterns = encode_literals(
(rv_node.outputs[1].type.broadcastable[:batch_ndim],)
)
output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
inplace_pattern = encode_literals(())
def random_wrapper(core_shape, rng, size, *dist_params):
if not inplace:
rng = copy(rng)
draws = _vectorized(
core_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
(rng,),
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len),
)
return rng, draws
def random(core_shape, rng, size, *dist_params):
pass
return (rng, samples)
@overload(random, jit_options=_jit_options)
def ov_random(core_shape, rng, size, *dist_params):
return random_wrapper
return permutation_rv
return random
......@@ -58,7 +58,11 @@ def numba_funcify_Scan(op, node, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase
rewriter = op.mode_instance.excluding(*NUMBA._optimizer.exclude).optimizer
rewriter = (
op.mode_instance.including("numba")
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
rewriter(op.fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
......
......@@ -5,6 +5,7 @@ from typing import Any, cast
import numpy as np
from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.null_type import NullType
......@@ -377,3 +378,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
_vectorize_node.register(Blockwise, _vectorize_not_needed)
class OpWithCoreShape(OpFromGraph):
    """Generalizes an `Op` to include core shape as an additional input.

    Subclasses wrap an inner graph whose extra leading input carries the core
    output shape — presumably so backends can pre-allocate output buffers
    (see the Numba RandomVariable lowering); confirm against callers.
    """
......@@ -2082,10 +2082,7 @@ def choice(a, size=None, replace=True, p=None, rng=None):
# This is equivalent to the numpy implementation:
# https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
if p is None:
if rng is not None and isinstance(rng.type, RandomStateType):
idxs = randint(0, a_size, size=size, rng=rng)
else:
idxs = integers(0, a_size, size=size, rng=rng)
idxs = integers(0, a_size, size=size, rng=rng)
else:
idxs = categorical(p, size=size, rng=rng)
......
......@@ -19,6 +19,7 @@ from pytensor.tensor.basic import (
get_vector_length,
infer_static_shape,
)
from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
......@@ -476,3 +477,11 @@ def vectorize_random_variable(
size = concatenate([new_size_dims, size])
return op.make_node(rng, size, *dist_params)
class RandomVariableWithCoreShape(OpWithCoreShape):
    """Generalizes a random variable `Op` to include a core shape parameter."""

    def __str__(self):
        # The inner graph holds exactly one apply node: the wrapped RV.
        (rv_node,) = self.fgraph.apply_nodes
        return f"[{rv_node.op}]"
......@@ -4,7 +4,8 @@ from pytensor.tensor.random.rewriting.basic import *
# isort: off
# Register JAX specializations
# Register Numba and JAX specializations
import pytensor.tensor.random.rewriting.numba
import pytensor.tensor.random.rewriting.jax
# isort: on
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor import as_tensor, constant
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature
@node_rewriter([RandomVariable])
def introduce_explicit_core_shape_rv(fgraph, node):
    """Introduce the core shape of a RandomVariable.

    We wrap RandomVariable graphs into a RandomVariableWithCoreShape OpFromGraph
    that has an extra "non-functional" input that represents the core shape of the random variable.
    This core_shape is used by the numba backend to pre-allocate the output array.

    If available, the core shape is extracted from the shape feature of the graph,
    which has a higher chance of having been simplified, optimized, constant-folded.
    If missing, we fall back to the op._supp_shape_from_params method.

    This rewrite is required for the numba backend implementation of RandomVariable.

    Example
    -------

    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        x = pt.random.dirichlet(alphas=[1, 2, 3], size=(5,))
        pytensor.dprint(x, print_type=True)
        # dirichlet_rv{"(a)->(a)"}.1 [id A] <Matrix(float64, shape=(5, 3))>
        # ├─ RNG(<Generator(PCG64) at 0x7F09E59C18C0>) [id B] <RandomGeneratorType>
        # ├─ [5] [id C] <Vector(int64, shape=(1,))>
        # └─ ExpandDims{axis=0} [id D] <Matrix(int64, shape=(1, 3))>
        #    └─ [1 2 3] [id E] <Vector(int64, shape=(3,))>

        # After the rewrite, note the new core shape input [3] [id B]
        fn = pytensor.function([], x, mode="NUMBA")
        pytensor.dprint(fn.maker.fgraph)
        # [dirichlet_rv{"(a)->(a)"}].1 [id A] 0
        # ├─ [3] [id B]
        # ├─ RNG(<Generator(PCG64) at 0x7F15B8E844A0>) [id C]
        # ├─ [5] [id D]
        # └─ [[1 2 3]] [id E]
        # Inner graphs:
        # [dirichlet_rv{"(a)->(a)"}] [id A]
        # ← dirichlet_rv{"(a)->(a)"}.0 [id F]
        #    ├─ *1-<RandomGeneratorType> [id G]
        #    ├─ *2-<Vector(int64, shape=(1,))> [id H]
        #    └─ *3-<Matrix(int64, shape=(1, 3))> [id I]
        # ← dirichlet_rv{"(a)->(a)"}.1 [id F]
        #    └─ ···
    """
    op: RandomVariable = node.op  # type: ignore[annotation-unchecked]
    next_rng, rv = node.outputs
    # Prefer the (possibly simplified/constant-folded) shapes tracked by the
    # graph's ShapeFeature; fall back to the op's own supp-shape computation.
    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
    if shape_feature:
        core_shape = [
            shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))
        ]
    else:
        core_shape = op._supp_shape_from_params(op.dist_params(node))
    if len(core_shape) == 0:
        core_shape = constant([], dtype="int64")
    else:
        core_shape = as_tensor(core_shape)
    # destroy_map marks the rng input (index 1, after the new core_shape input
    # at index 0) as destroyed when the wrapped RV samples in place.
    return (
        RandomVariableWithCoreShape(
            [core_shape, *node.inputs],
            node.outputs,
            destroy_map={0: [1]} if op.inplace else None,
        )
        .make_node(core_shape, *node.inputs)
        .outputs
    )
# Register under the "numba" tag so the rewrite only runs for the Numba
# backend; position=100 places it very late in the rewrite pipeline —
# presumably after shape simplification/constant folding (see docstring).
optdb.register(
    introduce_explicit_core_shape_rv.__name__,
    out2in(introduce_explicit_core_shape_rv),
    "numba",
    position=100,
)
......@@ -740,13 +740,13 @@ class UnShapeOptimizer(GraphRewriter):
# Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node.
pytensor.compile.mode.optdb.register( # type: ignore
pytensor.compile.mode.optdb.register(
"ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
)
# Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) # type: ignore
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
def local_reshape_chain(op):
......
......@@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py
pytensor/ifelse.py
pytensor/link/basic.py
pytensor/link/numba/dispatch/elemwise.py
pytensor/link/numba/dispatch/random.py
pytensor/link/numba/dispatch/scan.py
pytensor/printing.py
pytensor/raise_op.py
......
......@@ -29,7 +29,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type
from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_typify
from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
......@@ -120,7 +119,7 @@ my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
numba_mode = Mode(NumbaLinker(), opts.including("numba"))
py_mode = Mode("py", opts)
rng = np.random.default_rng(42849)
......@@ -229,6 +228,7 @@ def compare_numba_and_py(
numba_mode=numba_mode,
py_mode=py_mode,
updates=None,
eval_obj_mode: bool = True,
) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality
......@@ -247,6 +247,8 @@ def compare_numba_and_py(
provided uses `np.testing.assert_allclose`.
updates
Updates to be passed to `pytensor.function`.
eval_obj_mode : bool, default True
Whether to do an isolated call in object mode. Used for test coverage
Returns
-------
......@@ -283,7 +285,8 @@ def compare_numba_and_py(
numba_res = pytensor_numba_fn(*inputs)
# Get some coverage
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if eval_obj_mode:
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if len(fn_outputs) > 1:
for j, p in zip(numba_res, py_res):
......@@ -359,26 +362,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected
@pytest.mark.parametrize(
    "value, wrapper_fn, check_fn",
    [
        (
            np.random.RandomState(1),
            numba_typify,
            lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]),
        )
    ],
)
def test_box_unbox(value, wrapper_fn, check_fn):
    """A value must round-trip through Numba boxing/unboxing unchanged.

    Renamed the parameter from ``input`` to ``value`` so it no longer
    shadows the ``input`` builtin.
    """
    # Wrap the raw value in its Numba-typed representation first.
    value = wrapper_fn(value)
    # Identity function compiled with Numba: forces box -> unbox -> box.
    pass_through = numba.njit(lambda x: x)
    res = pass_through(value)
    # Same Python type back ...
    assert isinstance(res, type(value))
    # ... and an equivalent value, per the pairwise check function
    # (here: identical RandomState state arrays).
    assert check_fn(res, value)
@pytest.mark.parametrize(
"x, indices",
[
......
......@@ -8,13 +8,13 @@ import scipy.stats as stats
import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr
from pytensor import shared
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from tests.link.numba.test_basic import (
compare_numba_and_py,
eval_python_only,
numba_mode,
set_test_value,
)
......@@ -28,37 +28,139 @@ from tests.tensor.random.test_basic import (
rng = np.random.default_rng(42849)
@pytest.mark.xfail(
reason="Most RVs are not working correctly with explicit expand_dims"
)
@pytest.mark.parametrize("mu_shape", [(), (3,), (5, 1)])
@pytest.mark.parametrize("sigma_shape", [(), (1,), (5, 3)])
@pytest.mark.parametrize("size_type", (None, "constant", "mutable"))
def test_random_size(mu_shape, sigma_shape, size_type):
    """Numba draws must match NumPy for implicit, constant and shared sizes."""
    # Parameter values come from an unrelated rng so they do not interact
    # with the generator under test.
    test_value_rng = np.random.default_rng(637)
    mu = test_value_rng.normal(size=mu_shape)
    sigma = np.exp(test_value_rng.normal(size=sigma_shape))

    # For testing
    rng = np.random.default_rng(123)
    pt_rng = shared(rng)
    if size_type is None:
        size = None
        pt_size = None
    elif size_type == "constant":
        size = (5, 3)
        pt_size = pt.as_tensor(size, dtype="int64")
    else:
        # A shared (mutable) size can be changed after compilation; exercised
        # at the end of the test.
        size = (5, 3)
        pt_size = shared(np.array(size, dtype="int64"), shape=(2,))

    next_rng, x = pt.random.normal(mu, sigma, rng=pt_rng, size=pt_size).owner.outputs
    fn = function([], x, updates={pt_rng: next_rng}, mode="NUMBA")

    # Consecutive calls must advance the state exactly like the reference
    # NumPy generator.
    res1 = fn()
    np.testing.assert_allclose(
        res1,
        rng.normal(mu, sigma, size=size),
    )
    res2 = fn()
    np.testing.assert_allclose(
        res2,
        rng.normal(mu, sigma, size=size),
    )

    # Re-seeding the shared rng reproduces the first draw exactly.
    pt_rng.set_value(np.random.default_rng(123))
    res3 = fn()
    np.testing.assert_array_equal(res1, res3)

    # Resize without recompiling — only when the parameters can broadcast to
    # the new shape (both at most 1-d here).
    if size_type == "mutable" and len(mu_shape) < 2 and len(sigma_shape) < 2:
        pt_size.set_value(np.array((6, 3), dtype="int64"))
        res4 = fn()
        assert res4.shape == (6, 3)
def test_rng_copy():
    """A function without rng updates must not mutate the shared rng in place."""
    rng = shared(np.random.default_rng(123))
    x = pt.random.normal(rng=rng)
    fn = function([], x, mode="NUMBA")

    # Without updates, both calls start from the same rng state and must
    # produce the same draw.
    np.testing.assert_array_equal(fn(), fn())

    # The shared container should still hold the original seeded state.
    # The original discarded this boolean, so the check was a no-op — assert it.
    assert rng.type.values_eq(rng.get_value(), np.random.default_rng(123))
def test_rng_non_default_update():
    """A user-supplied rng update (rng_a -> rng_b) is honored by the Numba backend."""
    rng_a = shared(np.random.default_rng(1))
    rng_b = shared(np.random.default_rng(2))
    draws = pt.random.normal(size=10, rng=rng_a)
    fn = function([], draws, updates={rng_a: rng_b}, mode=numba_mode)

    # First call consumes the state seeded with 1.
    np.testing.assert_allclose(fn(), np.random.default_rng(1).normal(size=10))

    # After the update rng_a holds rng_b's (seed 2) state; rng_b itself is
    # never advanced, so every subsequent call replays the same seed-2 draw.
    expected = np.random.default_rng(2).normal(size=10)
    for _ in range(2):
        np.testing.assert_allclose(fn(), expected)
def test_categorical_rv():
    """This is also a smoke test for a vector input scalar output RV"""
    # Each probability row is one-hot, so every draw is deterministic and
    # must equal the argmax over the last axis.
    p = np.array(
        [
            [
                [1.0, 0, 0, 0],
                [0.0, 1.0, 0, 0],
                [0.0, 0, 1.0, 0],
            ],
            [
                [0, 0, 0, 1.0],
                [0, 0, 0, 1.0],
                [0, 0, 0, 1.0],
            ],
        ]
    )
    x = pt.random.categorical(p=p, size=None)
    updates = {x.owner.inputs[0]: x.owner.outputs[0]}
    fn = function([], x, updates=updates, mode="NUMBA")
    res = fn()
    assert np.all(np.argmax(p, axis=-1) == res)

    # Batch size
    # NOTE(review): `updates` still points at the rng of the *first* graph,
    # not the freshly built batched one — presumably harmless since draws are
    # deterministic here; confirm this is intentional.
    x = pt.random.categorical(p=p, size=(3, *p.shape[:-1]))
    fn = function([], x, updates=updates, mode="NUMBA")
    new_res = fn()
    # Batched output prepends the new leading dim; every batch row repeats
    # the deterministic base draw.
    assert new_res.shape == (3, *res.shape)
    for new_res_row in new_res:
        assert np.all(new_res_row == res)
def test_multivariate_normal():
    """Smoke test for a multivariate (vector-supported) RV in the Numba backend."""
    seeded_rng = np.random.default_rng(123)
    draws = pt.random.multivariate_normal(
        mean=np.zeros((3, 2)),
        cov=np.eye(2),
        rng=shared(seeded_rng),
    )
    fn = function([], draws, mode="NUMBA")

    # The batched (3, 2)-mean draw must equal NumPy's size=(3,) call from the
    # same seeded generator.
    result = fn()
    expected = seeded_rng.multivariate_normal(np.zeros(2), np.eye(2), size=(3,))
    np.testing.assert_array_equal(result, expected)
@pytest.mark.parametrize(
"rv_op, dist_args, size",
[
(
ptr.normal,
ptr.uniform,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
pt.as_tensor([3, 2]),
),
(
ptr.uniform,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
pt.as_tensor([3, 2]),
),
......@@ -94,7 +196,7 @@ rng = np.random.default_rng(42849)
],
pt.as_tensor([3, 2]),
),
pytest.param(
(
ptr.pareto,
[
set_test_value(
......@@ -107,7 +209,6 @@ rng = np.random.default_rng(42849)
),
],
pt.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(
ptr.exponential,
......@@ -153,7 +254,7 @@ rng = np.random.default_rng(42849)
],
pt.as_tensor([3, 2]),
),
(
pytest.param(
ptr.hypergeometric,
[
set_test_value(
......@@ -170,6 +271,7 @@ rng = np.random.default_rng(42849)
),
],
pt.as_tensor([3, 2]),
marks=pytest.mark.xfail, # Not implemented
),
(
ptr.wald,
......@@ -262,33 +364,70 @@ rng = np.random.default_rng(42849)
None,
),
(
ptr.randint,
ptr.beta,
[
set_test_value(
pt.lscalar(),
np.array(0, dtype=np.int64),
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.lscalar(),
np.array(5, dtype=np.int64),
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
pt.as_tensor([3, 2]),
(2,),
),
pytest.param(
ptr.multivariate_normal,
(
ptr._gamma,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64),
),
],
(2,),
),
(
ptr.chisquare,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
),
(
ptr.negative_binomial,
[
set_test_value(
pt.dmatrix(),
np.array([[1, 2], [3, 4]], dtype=np.float64),
pt.lvector(),
np.array([100, 200], dtype=np.int64),
),
set_test_value(
pt.tensor(dtype="float64", shape=(1, None, None)),
np.eye(2)[None, ...],
pt.dscalar(),
np.array(0.09, dtype=np.float64),
),
],
pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [4, 3, 2])),
marks=pytest.mark.xfail(reason="Not implemented"),
(2,),
),
(
ptr.vonmises,
[
set_test_value(
pt.dvector(),
np.array([-0.5, 0.5], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
),
(
ptr.permutation,
......@@ -312,17 +451,21 @@ rng = np.random.default_rng(42849)
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value(
pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64)
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
],
(),
(pt.as_tensor([2, 3])),
),
(
pytest.param(
partial(ptr.choice, replace=False),
[
set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)),
],
pt.as_tensor([2]),
marks=pytest.mark.xfail(
AssertionError,
reason="Not aligned with NumPy implementation",
),
),
pytest.param(
partial(ptr.choice, replace=False),
......@@ -331,28 +474,23 @@ rng = np.random.default_rng(42849)
],
pt.as_tensor([2]),
marks=pytest.mark.xfail(
raises=ValueError,
reason="Numba random.choice does not support >=1D `a`",
raises=AssertionError,
reason="Not aligned with NumPy implementation",
),
),
pytest.param(
(
# p must be passed by kwarg
lambda a, p, size, rng: ptr.choice(
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.vector(), np.arange(5, dtype=np.float64)),
# Boring p, because the variable is not truly "aligned"
set_test_value(
pt.dvector(),
np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64),
),
],
(),
marks=pytest.mark.xfail(
raises=Exception, # numba.TypeError
reason="Numba random.choice does not support `p` parameter",
),
pt.as_tensor([2]),
),
pytest.param(
# p must be passed by kwarg
......@@ -361,23 +499,31 @@ rng = np.random.default_rng(42849)
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
# Boring p, because the variable is not truly "aligned"
set_test_value(
pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64)
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
],
(),
marks=pytest.mark.xfail(
raises=ValueError,
reason="Numba random.choice does not support >=1D `a`",
),
pytest.param(
# p must be passed by kwarg
lambda a, p, size, rng: ptr.choice(
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value(
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
],
(pt.as_tensor([2, 1])),
),
],
ids=str,
)
def test_aligned_RandomVariable(rv_op, dist_args, size):
"""Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers."""
rng = shared(np.random.RandomState(29402))
rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
......@@ -388,45 +534,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
eval_obj_mode=False, # No python impl
)
@pytest.mark.xfail(reason="Test is not working correctly with explicit expand_dims")
@pytest.mark.parametrize(
"rv_op, dist_args, base_size, cdf_name, params_conv",
[
(
ptr.beta,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"beta",
lambda *args: args,
),
(
ptr._gamma,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64),
),
],
(2,),
"gamma",
lambda a, b: (a, 0.0, b),
),
(
ptr.cauchy,
[
......@@ -443,18 +557,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
"cauchy",
lambda *args: args,
),
(
ptr.chisquare,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
"chi2",
lambda *args: args,
),
(
ptr.gumbel,
[
......@@ -471,49 +573,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
"gumbel_r",
lambda *args: args,
),
(
ptr.negative_binomial,
[
set_test_value(
pt.lvector(),
np.array([100, 200], dtype=np.int64),
),
set_test_value(
pt.dscalar(),
np.array(0.09, dtype=np.float64),
),
],
(2,),
"nbinom",
lambda *args: args,
),
pytest.param(
ptr.vonmises,
[
set_test_value(
pt.dvector(),
np.array([-0.5, 0.5], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"vonmises_line",
lambda mu, kappa: (kappa, mu),
marks=pytest.mark.xfail(
reason=(
"Numba's parameterization of `vonmises` does not match NumPy's."
"See https://github.com/numba/numba/issues/7886"
)
),
),
],
)
def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv):
"""Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers."""
rng = shared(np.random.RandomState(29402))
rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=(2000, *base_size), rng=rng)
g_fn = function(dist_args, g, mode=numba_mode)
samples = g_fn(
......@@ -534,78 +598,6 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
assert test_res.pvalue > 0.1
@pytest.mark.parametrize(
"dist_args, size, cm",
[
pytest.param(
[
set_test_value(
pt.dvector(),
np.array([100000, 1, 1], dtype=np.float64),
),
],
None,
contextlib.suppress(),
),
pytest.param(
[
set_test_value(
pt.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
],
(10, 3),
contextlib.suppress(),
),
pytest.param(
[
set_test_value(
pt.dmatrix(),
np.array(
[[100000, 1, 1]],
dtype=np.float64,
),
),
],
(5, 4, 3),
contextlib.suppress(),
),
pytest.param(
[
set_test_value(
pt.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
],
(10, 4),
pytest.raises(
ValueError, match="objects cannot be broadcast to a single shape"
),
),
],
)
def test_CategoricalRV(dist_args, size, cm):
rng = shared(np.random.RandomState(29402))
g = ptr.categorical(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize(
"a, size, cm",
[
......@@ -637,21 +629,21 @@ def test_CategoricalRV(dist_args, size, cm):
),
),
(10, 4),
pytest.raises(ValueError, match="operands could not be broadcast together"),
pytest.raises(
ValueError,
match="Vectorized input 0 has an incompatible shape in axis 1.",
),
),
],
)
def test_DirichletRV(a, size, cm):
rng = shared(np.random.RandomState(29402))
rng = shared(np.random.default_rng(29402))
g = ptr.dirichlet(a, size=size, rng=rng)
g_fn = function([a], g, mode=numba_mode)
with cm:
a_val = a.tag.test_value
# For coverage purposes only...
eval_python_only([a], [g], [a_val])
all_samples = []
for i in range(1000):
samples = g_fn(a_val)
......@@ -662,48 +654,34 @@ def test_DirichletRV(a, size, cm):
assert np.allclose(res, exp_res, atol=1e-4)
@pytest.mark.xfail(reason="RandomState is not aligned with explicit expand_dims")
def test_RandomState_updates():
rng = shared(np.random.RandomState(1))
rng_new = shared(np.random.RandomState(2))
x = pt.random.normal(size=10, rng=rng)
res = function([], x, updates={rng: rng_new}, mode=numba_mode)()
def test_rv_inside_ofg():
rng_np = np.random.default_rng(562)
rng = shared(rng_np)
ref = np.random.RandomState(2).normal(size=10)
assert np.allclose(res, ref)
rng_dummy = rng.type()
next_rng_dummy, rv_dummy = ptr.normal(
0, 1, size=(3, 2), rng=rng_dummy
).owner.outputs
out_dummy = rv_dummy.T
next_rng, out = OpFromGraph([rng_dummy], [next_rng_dummy, out_dummy])(rng)
fn = function([], out, updates={rng: next_rng}, mode=numba_mode)
def test_random_Generator():
rng = shared(np.random.default_rng(29402))
g = ptr.normal(rng=rng)
g_fg = FunctionGraph(outputs=[g])
res1, res2 = fn(), fn()
assert res1.shape == (2, 3)
with pytest.raises(TypeError):
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
np.testing.assert_allclose(res1, rng_np.normal(0, 1, size=(3, 2)).T)
np.testing.assert_allclose(res2, rng_np.normal(0, 1, size=(3, 2)).T)
@pytest.mark.parametrize(
"batch_dims_tester",
[
pytest.param(
batched_unweighted_choice_without_replacement_tester,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
pytest.param(
batched_weighted_choice_without_replacement_tester,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
batched_unweighted_choice_without_replacement_tester,
batched_weighted_choice_without_replacement_tester,
batched_permutation_tester,
],
)
def test_unnatural_batched_dims(batch_dims_tester):
"""Tests for RVs that don't have natural batch dims in Numba API."""
batch_dims_tester(mode="NUMBA", rng_ctor=np.random.RandomState)
batch_dims_tester(mode="NUMBA")
......@@ -77,15 +77,13 @@ from tests.link.numba.test_basic import compare_numba_and_py
),
# nit-sot, shared input/output
(
lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal(
0, 1, name="a"
),
lambda: RandomStream(seed=1930).normal(0, 1, name="a"),
[],
[{}],
[],
3,
[],
[np.array([-1.63408257, 0.18046406, 2.43265803])],
[np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_shared_outs > 0,
),
# mit-sot (that's also a type of sit-sot)
......
......@@ -1452,9 +1452,7 @@ def test_permutation_shape():
assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5)
def batched_unweighted_choice_without_replacement_tester(
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
def batched_unweighted_choice_without_replacement_tester(mode="FAST_RUN"):
"""Test unweighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the
......@@ -1462,7 +1460,7 @@ def batched_unweighted_choice_without_replacement_tester(
It can be triggered by manual buiding the Op or during automatic vectorization.
"""
rng = shared(rng_ctor())
rng = shared(np.random.default_rng())
# Batched a implicit size
rv_op = ChoiceWithoutReplacement(
......@@ -1499,9 +1497,7 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
def batched_weighted_choice_without_replacement_tester(
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
def batched_weighted_choice_without_replacement_tester(mode="FAST_RUN"):
"""Test weighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the
......@@ -1509,7 +1505,7 @@ def batched_weighted_choice_without_replacement_tester(
It can be triggered by manual buiding the Op or during automatic vectorization.
"""
rng = shared(rng_ctor())
rng = shared(np.random.default_rng())
rv_op = ChoiceWithoutReplacement(
signature="(a0,a1),(a0),(1)->(s0,a1)",
......@@ -1574,7 +1570,7 @@ def batched_weighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10))
def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng):
def batched_permutation_tester(mode="FAST_RUN"):
"""Test permutation with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the
......@@ -1583,7 +1579,7 @@ def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng):
It can be triggered by manual buiding the Op or during automatic vectorization.
"""
rng = shared(rng_ctor())
rng = shared(np.random.default_rng())
rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64")
x = np.arange(5 * 3 * 2).reshape((5, 3, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论