提交 14da898c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add support for RandomVariable with Generators in Numba backend and drop support for RandomState

上级 47874eb9
...@@ -27,7 +27,6 @@ from pytensor.graph.op import HasInnerGraph, Op ...@@ -27,7 +27,6 @@ from pytensor.graph.op import HasInnerGraph, Op
from pytensor.graph.replace import clone_replace from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature
def infer_shape(outs, inputs, input_shapes): def infer_shape(outs, inputs, input_shapes):
...@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes): ...@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
# inside. We don't use the full ShapeFeature interface, but we # inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will # let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually # need to do it manually
# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature
for inp, inp_shp in zip(inputs, input_shapes): for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.type.ndim: if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim assert len(inp_shp) == inp.type.ndim
...@@ -307,6 +310,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -307,6 +310,7 @@ class OpFromGraph(Op, HasInnerGraph):
connection_pattern: list[list[bool]] | None = None, connection_pattern: list[list[bool]] | None = None,
strict: bool = False, strict: bool = False,
name: str | None = None, name: str | None = None,
destroy_map: dict[int, tuple[int, ...]] | None = None,
**kwargs, **kwargs,
): ):
""" """
...@@ -464,6 +468,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -464,6 +468,7 @@ class OpFromGraph(Op, HasInnerGraph):
if name is not None: if name is not None:
assert isinstance(name, str), "name must be None or string object" assert isinstance(name, str), "name must be None or string object"
self.name = name self.name = name
self.destroy_map = destroy_map if destroy_map is not None else {}
def __eq__(self, other): def __eq__(self, other):
# TODO: recognize a copy # TODO: recognize a copy
...@@ -862,6 +867,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -862,6 +867,7 @@ class OpFromGraph(Op, HasInnerGraph):
rop_overrides=self.rop_overrides, rop_overrides=self.rop_overrides,
connection_pattern=self._connection_pattern, connection_pattern=self._connection_pattern,
name=self.name, name=self.name,
destroy_map=self.destroy_map,
**self.kwargs, **self.kwargs,
) )
new_inputs = ( new_inputs = (
......
...@@ -463,7 +463,7 @@ JAX = Mode( ...@@ -463,7 +463,7 @@ JAX = Mode(
NUMBA = Mode( NUMBA = Mode(
NumbaLinker(), NumbaLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(
include=["fast_run"], include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"], exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
), ),
) )
......
...@@ -18,6 +18,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 ...@@ -18,6 +18,7 @@ from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from numba.extending import box, overload from numba.extending import box, overload
from pytensor import config from pytensor import config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
...@@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
def numba_funcify_OpFromGraph(op, node=None, **kwargs): def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None) _ = kwargs.pop("storage_map", None)
# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs)) fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
if len(op.fgraph.outputs) == 1: if len(op.fgraph.outputs) == 1:
......
from collections.abc import Callable from collections.abc import Callable
from textwrap import dedent, indent from copy import copy
from typing import Any from functools import singledispatch
from textwrap import dedent
import numba
import numba.np.unsafe.ndarray as numba_ndarray import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np import numpy as np
from numba import _helperlib, types from numba import types
from numba.core import cgutils from numba.core.extending import overload
from numba.extending import NativeValue, box, models, register_model, typeof_impl, unbox
from numpy.random import RandomState
import pytensor.tensor.random.basic as ptr import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Apply from pytensor.graph import Apply
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify from pytensor.link.numba.dispatch.basic import direct_cast, numba_funcify
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
_vectorized,
encode_literals,
store_core_outputs,
)
from pytensor.link.utils import ( from pytensor.link.utils import (
compile_function_src, compile_function_src,
get_name_for_object,
unique_name_generator,
) )
from pytensor.tensor.basic import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.random.type import RandomStateType from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature
@overload(copy)
def copy_NumPyRandomGenerator(rng):
def impl(rng):
# TODO: Open issue on Numba?
with numba.objmode(new_rng=types.npy_rng):
new_rng = copy(rng)
return new_rng
return impl
@singledispatch
def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
"""Return the core function for a random variable operation."""
raise NotImplementedError(f"Core implementation of {op} not implemented.")
@numba_core_rv_funcify.register(ptr.UniformRV)
@numba_core_rv_funcify.register(ptr.TriangularRV)
@numba_core_rv_funcify.register(ptr.BetaRV)
@numba_core_rv_funcify.register(ptr.NormalRV)
@numba_core_rv_funcify.register(ptr.LogNormalRV)
@numba_core_rv_funcify.register(ptr.GammaRV)
@numba_core_rv_funcify.register(ptr.ExponentialRV)
@numba_core_rv_funcify.register(ptr.WeibullRV)
@numba_core_rv_funcify.register(ptr.LogisticRV)
@numba_core_rv_funcify.register(ptr.VonMisesRV)
@numba_core_rv_funcify.register(ptr.PoissonRV)
@numba_core_rv_funcify.register(ptr.GeometricRV)
# @numba_core_rv_funcify.register(ptr.HyperGeometricRV) # Not implemented in numba
@numba_core_rv_funcify.register(ptr.WaldRV)
@numba_core_rv_funcify.register(ptr.LaplaceRV)
@numba_core_rv_funcify.register(ptr.BinomialRV)
@numba_core_rv_funcify.register(ptr.NegBinomialRV)
@numba_core_rv_funcify.register(ptr.MultinomialRV)
@numba_core_rv_funcify.register(ptr.PermutationRV)
@numba_core_rv_funcify.register(ptr.IntegersRV)
def numba_core_rv_default(op, node):
"""Create a default RV core numba function.
@njit
def random(rng, i0, i1, ..., in):
return rng.name(i0, i1, ..., in)
"""
name = op.name
inputs = [f"i{i}" for i in range(len(op.ndims_params))]
input_signature = ",".join(inputs)
class RandomStateNumbaType(types.Type): func_src = dedent(f"""
def __init__(self): def {name}(rng, {input_signature}):
super().__init__(name="RandomState") return rng.{name}({input_signature})
""")
random_state_numba_type = RandomStateNumbaType()
@typeof_impl.register(RandomState)
def typeof_index(val, c):
return random_state_numba_type
@register_model(RandomStateNumbaType)
class RandomStateNumbaModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
# TODO: We can add support for boxing and unboxing
# the attributes that describe a RandomState so that
# they can be accessed inside njit functions, if required.
("state_key", types.Array(types.uint32, 1, "C")),
]
models.StructModel.__init__(self, dmm, fe_type, members)
func = compile_function_src(func_src, name, {**globals()})
return numba_basic.numba_njit(func)
@unbox(RandomStateNumbaType)
def unbox_random_state(typ, obj, c):
"""Convert a `RandomState` object to a native `RandomStateNumbaModel` structure.
Note that this will create a 'fake' structure which will just get the @numba_core_rv_funcify.register(ptr.BernoulliRV)
`RandomState` objects accepted in Numba functions but the actual information def numba_core_BernoulliRV(op, node):
of the Numba's random state is stored internally and can be accessed out_dtype = node.outputs[1].type.numpy_dtype
anytime using ``numba._helperlib.rnd_get_np_state_ptr()``.
"""
interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(interval._getvalue(), is_error=is_error)
@numba_basic.numba_njit()
def random(rng, p):
return (
direct_cast(0, out_dtype)
if p < rng.uniform()
else direct_cast(1, out_dtype)
)
@box(RandomStateNumbaType) return random
def box_random_state(typ, val, c):
"""Convert a native `RandomStateNumbaModel` structure to an `RandomState` object
using Numba's internal state array.
Note that `RandomStateNumbaModel` is just a placeholder structure with no
inherent information about Numba internal random state, all that information
is instead retrieved from Numba using ``_helperlib.rnd_get_state()`` and a new
`RandomState` is constructed using the Numba's current internal state.
"""
pos, state_list = _helperlib.rnd_get_state(_helperlib.rnd_get_np_state_ptr())
rng = RandomState()
rng.set_state(("MT19937", state_list, pos))
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(rng))
return class_obj
@numba_core_rv_funcify.register(ptr.HalfNormalRV)
def numba_core_HalfNormalRV(op, node):
@numba_basic.numba_njit
def random_fn(rng, loc, scale):
return loc + scale * np.abs(rng.standard_normal())
@numba_typify.register(RandomState) return random_fn
def numba_typify_RandomState(state, **kwargs):
# The numba_typify in this case is just an passthrough function
# that synchronizes Numba's internal random state with the current
# RandomState object
ints, index = state.get_state()[1:3]
ptr = _helperlib.rnd_get_np_state_ptr()
_helperlib.rnd_set_state(ptr, (index, [int(x) for x in ints]))
return state
def make_numba_random_fn(node, np_random_func): @numba_core_rv_funcify.register(ptr.CauchyRV)
"""Create Numba implementations for existing Numba-supported ``np.random`` functions. def numba_core_CauchyRV(op, node):
@numba_basic.numba_njit
def random(rng, loc, scale):
return (loc + rng.standard_cauchy()) / scale
The functions generated here add parameter broadcasting and the ``size`` return random
argument to the Numba-supported scalar ``np.random`` functions.
"""
op: ptr.RandomVariable = node.op
rng_param = op.rng_param(node)
if not isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `Generator`s")
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
dist_params = op.dist_params(node)
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
bcast_fn_name = f"pytensor_random_{get_name_for_object(np_random_func)}"
sized_fn_name = "sized_random_variable"
unique_names = unique_name_generator(
[
bcast_fn_name,
sized_fn_name,
"np",
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"size_len",
"size_dims",
"rng",
"size",
],
suffix_sep="_",
)
bcast_fn_input_names = ", ".join(
[unique_names(i, force_unique=True) for i in dist_params]
)
bcast_fn_global_env = {
"np_random_func": np_random_func,
"numba_vectorize": numba_basic.numba_vectorize,
}
bcast_fn_src = f"""
@numba_vectorize
def {bcast_fn_name}({bcast_fn_input_names}):
return np_random_func({bcast_fn_input_names})
"""
bcast_fn = compile_function_src(
bcast_fn_src, bcast_fn_name, {**globals(), **bcast_fn_global_env}
)
random_fn_input_names = ", ".join( @numba_core_rv_funcify.register(ptr.ParetoRV)
["rng", "size"] + [unique_names(i) for i in dist_params] def numba_core_ParetoRV(op, node):
) @numba_basic.numba_njit
def random(rng, b, scale):
# Follows scipy implementation
U = rng.random()
return np.power(1 - U, -1 / b) * scale
# Now, create a Numba JITable function that implements the `size` parameter return random
out_dtype = node.outputs[1].type.numpy_dtype
random_fn_global_env = {
bcast_fn_name: bcast_fn,
"out_dtype": out_dtype,
}
if size_len is not None:
size_dims = size_len - max(i.ndim for i in dist_params)
random_fn_body = dedent( @numba_core_rv_funcify.register(ptr.CategoricalRV)
f""" def core_CategoricalRV(op, node):
size = to_fixed_tuple(size, size_len) @numba_basic.numba_njit
def random_fn(rng, p):
unif_sample = rng.uniform(0, 1)
return np.searchsorted(np.cumsum(p), unif_sample)
data = np.empty(size, dtype=out_dtype) return random_fn
for i in np.ndindex(size[:size_dims]):
data[i] = {bcast_fn_name}({bcast_fn_input_names})
"""
)
random_fn_global_env.update(
{
"np": np,
"to_fixed_tuple": numba_ndarray.to_fixed_tuple,
"size_len": size_len,
"size_dims": size_dims,
}
)
else:
random_fn_body = f"""data = {bcast_fn_name}({bcast_fn_input_names})"""
sized_fn_src = dedent( @numba_core_rv_funcify.register(ptr.MvNormalRV)
f""" def core_MvNormalRV(op, node):
def {sized_fn_name}({random_fn_input_names}): @numba_basic.numba_njit
{indent(random_fn_body, " " * 4)} def random_fn(rng, mean, cov):
return (rng, data) chol = np.linalg.cholesky(cov)
""" stdnorm = rng.normal(size=cov.shape[-1])
) return np.dot(chol, stdnorm) + mean
random_fn = compile_function_src(
sized_fn_src, sized_fn_name, {**globals(), **random_fn_global_env}
)
random_fn = numba_basic.numba_njit(random_fn)
random_fn.handles_out = True
return random_fn return random_fn
@numba_funcify.register(ptr.UniformRV) @numba_core_rv_funcify.register(ptr.DirichletRV)
@numba_funcify.register(ptr.TriangularRV) def core_DirichletRV(op, node):
@numba_funcify.register(ptr.BetaRV) @numba_basic.numba_njit
@numba_funcify.register(ptr.NormalRV) def random_fn(rng, alpha):
@numba_funcify.register(ptr.LogNormalRV) y = np.empty_like(alpha)
@numba_funcify.register(ptr.GammaRV) for i in range(len(alpha)):
@numba_funcify.register(ptr.ParetoRV) y[i] = rng.gamma(alpha[i], 1.0)
@numba_funcify.register(ptr.GumbelRV) return y / y.sum()
@numba_funcify.register(ptr.ExponentialRV)
@numba_funcify.register(ptr.WeibullRV)
@numba_funcify.register(ptr.LogisticRV)
@numba_funcify.register(ptr.VonMisesRV)
@numba_funcify.register(ptr.PoissonRV)
@numba_funcify.register(ptr.GeometricRV)
@numba_funcify.register(ptr.HyperGeometricRV)
@numba_funcify.register(ptr.WaldRV)
@numba_funcify.register(ptr.LaplaceRV)
@numba_funcify.register(ptr.BinomialRV)
@numba_funcify.register(ptr.MultinomialRV)
@numba_funcify.register(ptr.RandIntRV) # only the first two arguments are supported
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_RandomVariable(op, node, **kwargs):
name = op.name
np_random_func = getattr(np.random, name)
return make_numba_random_fn(node, np_random_func) return random_fn
def create_numba_random_fn( @numba_core_rv_funcify.register(ptr.GumbelRV)
op: Op, def core_GumbelRV(op, node):
node: Apply, """Code adapted from Numpy Implementation
scalar_fn: Callable[[str], str],
global_env: dict[str, Any] | None = None,
) -> Callable:
"""Create a vectorized function from a callable that generates the ``str`` function body.
TODO: This could/should be generalized for other simple function https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L502-L511
construction cases that need unique-ified symbol names.
""" """
np_random_fn_name = f"pytensor_random_{get_name_for_object(op.name)}"
if global_env: @numba_basic.numba_njit
np_global_env = global_env.copy() def random_fn(rng, loc, scale):
U = 1.0 - rng.random()
if U < 1.0:
return loc - scale * np.log(-np.log(U))
else: else:
np_global_env = {} return random_fn(rng, loc, scale)
np_global_env["np"] = np
np_global_env["numba_vectorize"] = numba_basic.numba_vectorize
unique_names = unique_name_generator(
[np_random_fn_name, *np_global_env.keys(), "rng", "size"],
suffix_sep="_",
)
dist_params = op.dist_params(node)
np_names = [unique_names(i, force_unique=True) for i in dist_params]
np_input_names = ", ".join(np_names)
np_random_fn_src = f"""
@numba_vectorize
def {np_random_fn_name}({np_input_names}):
{scalar_fn(*np_names)}
"""
np_random_fn = compile_function_src(
np_random_fn_src, np_random_fn_name, {**globals(), **np_global_env}
)
return make_numba_random_fn(node, np_random_fn)
return random_fn
@numba_funcify.register(ptr.NegBinomialRV)
def numba_funcify_NegBinomialRV(op, node, **kwargs):
return make_numba_random_fn(node, np.random.negative_binomial)
@numba_funcify.register(ptr.CauchyRV)
def numba_funcify_CauchyRV(op, node, **kwargs):
def body_fn(loc, scale):
return f" return ({loc} + np.random.standard_cauchy()) / {scale}"
return create_numba_random_fn(op, node, body_fn)
@numba_funcify.register(ptr.HalfNormalRV)
def numba_funcify_HalfNormalRV(op, node, **kwargs):
def body_fn(a, b):
return f" return {a} + {b} * abs(np.random.normal(0, 1))"
return create_numba_random_fn(op, node, body_fn)
@numba_funcify.register(ptr.BernoulliRV) @numba_core_rv_funcify.register(ptr.VonMisesRV)
def numba_funcify_BernoulliRV(op, node, **kwargs): def core_VonMisesRV(op, node):
out_dtype = node.outputs[1].type.numpy_dtype """Code adapted from Numpy Implementation
def body_fn(a): https://github.com/numpy/numpy/blob/6f6be042c6208815b15b90ba87d04159bfa25fd3/numpy/random/src/distributions/distributions.c#L855-L925
return f"""
if {a} < np.random.uniform(0, 1):
return direct_cast(0, out_dtype)
else:
return direct_cast(1, out_dtype)
""" """
return create_numba_random_fn(
op,
node,
body_fn,
{"out_dtype": out_dtype, "direct_cast": numba_basic.direct_cast},
)
@numba_funcify.register(ptr.CategoricalRV)
def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit @numba_basic.numba_njit
def categorical_rv(rng, size, p): def random_fn(rng, mu, kappa):
if size_len is None: if np.isnan(kappa):
size_tpl = p.shape[:-1] return np.nan
if kappa < 1e-8:
# Use a uniform for very small values of kappa
return np.pi * (2 * rng.random() - 1)
else: else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) # with double precision rho is zero until 1.4e-8
p = np.broadcast_to(p, size_tpl + p.shape[-1:]) if kappa < 1e-5:
# second order taylor expansion around kappa = 0
# Workaround https://github.com/numba/numba/issues/8975 # precise until relatively large kappas as second order is 0
if size_len is None and p_ndim == 1: s = 1.0 / kappa + kappa
unif_samples = np.asarray(np.random.uniform(0, 1))
else: else:
unif_samples = np.random.uniform(0, 1, size_tpl) if kappa <= 1e6:
# Path for 1e-5 <= kappa <= 1e6
res = np.empty(size_tpl, dtype=out_dtype) r = 1 + np.sqrt(1 + 4 * kappa * kappa)
for idx in np.ndindex(*size_tpl): rho = (r - np.sqrt(2 * r)) / (2 * kappa)
res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx]) s = (1 + rho * rho) / (2 * rho)
else:
return (rng, res) # Fallback to wrapped normal distribution for kappa > 1e6
result = mu + np.sqrt(1.0 / kappa) * rng.standard_normal()
# Ensure result is within bounds
if result < -np.pi:
result += 2 * np.pi
if result > np.pi:
result -= 2 * np.pi
return result
while True:
U = rng.random()
Z = np.cos(np.pi * U)
W = (1 + s * Z) / (s + Z)
Y = kappa * (s - W)
V = rng.random()
# V == 0.0 is ok here since Y >= 0 always leads
# to accept, while Y < 0 always rejects
if (Y * (2 - Y) - V >= 0) or (np.log(Y / V) + 1 - Y >= 0):
break
U = rng.random()
result = np.arccos(W)
if U < 0.5:
result = -result
result += mu
neg = result < 0
mod = np.abs(result)
mod = np.mod(mod + np.pi, 2 * np.pi) - np.pi
if neg:
mod *= -1
return mod
return categorical_rv return random_fn
@numba_funcify.register(ptr.DirichletRV) @numba_core_rv_funcify.register(ptr.ChoiceWithoutReplacement)
def numba_funcify_DirichletRV(op, node, **kwargs): def core_ChoiceWithoutReplacement(op: ptr.ChoiceWithoutReplacement, node):
out_dtype = node.outputs[1].type.numpy_dtype [core_shape_len_sig] = _parse_gufunc_signature(op.signature)[0][-1]
alphas_ndim = op.dist_params(node)[0].type.ndim core_shape_len = int(core_shape_len_sig)
size_param = op.size_param(node) implicit_arange = op.ndims_params[0] == 0
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
if alphas_ndim > 1: if op.has_p_param:
@numba_basic.numba_njit @numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas): def random_fn(rng, a, p, core_shape):
if size_len is None: # Adapted from Numpy: https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L922-L941
samples_shape = alphas.shape size = np.prod(core_shape)
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
if implicit_arange:
pop_size = a
else: else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) pop_size = a.shape[0]
samples_shape = size_tpl + alphas.shape[-1:]
res = np.empty(samples_shape, dtype=out_dtype) if size > pop_size:
alphas_bcast = np.broadcast_to(alphas, samples_shape) raise ValueError(
"Cannot take a larger sample than population without replacement"
)
if np.count_nonzero(p > 0) < size:
raise ValueError("Fewer non-zero entries in p than size")
p = p.copy()
n_uniq = 0
idx = np.zeros(core_shape, dtype=np.int64)
flat_idx = idx.ravel()
while n_uniq < size:
x = rng.random((size - n_uniq,))
# Set the probabilities of items that have already been found to 0
p[flat_idx[:n_uniq]] = 0
# Take new (unique) categorical draws from the remaining probabilities
cdf = np.cumsum(p)
cdf /= cdf[-1]
new = np.searchsorted(cdf, x, side="right")
# Numba doesn't support return_index in np.unique
# _, unique_indices = np.unique(new, return_index=True)
# unique_indices.sort()
new.sort()
unique_indices = [
idx
for idx, prev_item in enumerate(new[:-1], 1)
if new[idx] != prev_item
]
unique_indices = np.array([0] + unique_indices) # noqa: RUF005
for index in np.ndindex(*samples_shape[:-1]): new = new[unique_indices]
res[index] = np.random.dirichlet(alphas_bcast[index]) flat_idx[n_uniq : n_uniq + new.size] = new
n_uniq += new.size
return (rng, res) if implicit_arange:
return idx
else:
# Numba doesn't support advanced indexing, so we ravel index and reshape
return a[idx.ravel()].reshape(core_shape + a.shape[1:])
else: else:
@numba_basic.numba_njit @numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas): def random_fn(rng, a, core_shape):
if size_len is not None: # Until Numba supports generator.choice we use a poor implementation
size = numba_ndarray.to_fixed_tuple(size, size_len) # that permutates the whole arange array and takes the first `size` elements
return (rng, np.random.dirichlet(alphas, size)) # This is widely inefficient when size << a.shape[0]
size = np.prod(core_shape)
return dirichlet_rv core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
idx = rng.permutation(size)[:size]
@numba_funcify.register(ptr.ChoiceWithoutReplacement)
def numba_funcify_choice_without_replacement(op, node, **kwargs):
batch_ndim = op.batch_ndim(node)
if batch_ndim:
# The code isn't too hard to write, but Numba doesn't support a with ndim > 1,
# and I don't want to change the batched tests for this
# We'll just raise an error for now
raise NotImplementedError(
"ChoiceWithoutReplacement with batch_ndim not supported in Numba backend"
)
[core_shape_len] = node.inputs[-1].type.shape
if op.has_p_param: # Numba doesn't support advanced indexing so index on the flat dimension and reshape
# idx = idx.reshape(core_shape)
# if implicit_arange:
# return idx
# else:
# return a[idx]
@numba_basic.numba_njit if implicit_arange:
def choice_without_replacement_rv(rng, size, a, p, core_shape): return idx.reshape(core_shape)
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
samples = np.random.choice(a, size=core_shape, replace=False, p=p)
return (rng, samples)
else: else:
return a[idx].reshape(core_shape + a.shape[1:])
@numba_basic.numba_njit return random_fn
def choice_without_replacement_rv(rng, size, a, core_shape):
core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
samples = np.random.choice(a, size=core_shape, replace=False)
return (rng, samples)
return choice_without_replacement_rv
@numba_funcify.register
def numba_funcify_RandomVariable_core(op: RandomVariable, **kwargs):
raise RuntimeError(
"It is necessary to replace RandomVariable with RandomVariableWithCoreShape. "
"This is done by the default rewrites during compilation."
)
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
size_is_none = isinstance(op.size_param(node).type, NoneTypeT)
batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
@numba_basic.numba_njit @numba_funcify.register
def permutation_rv(rng, size, x): def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs):
if batch_ndim: core_shape = node.inputs[0]
x_core_shape = x.shape[x_batch_ndim:]
if size_is_none:
size = x.shape[:batch_ndim]
else:
size = numba_ndarray.to_fixed_tuple(size, batch_ndim)
x = np.broadcast_to(x, size + x_core_shape)
samples = np.empty(size + x_core_shape, dtype=x.dtype) [rv_node] = op.fgraph.apply_nodes
for index in np.ndindex(size): rv_op: RandomVariable = rv_node.op
samples[index] = np.random.permutation(x[index]) rng_param = rv_op.rng_param(rv_node)
if isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `RandomStateType`s")
size = rv_op.size_param(rv_node)
dist_params = rv_op.dist_params(rv_node)
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)
core_shape_len = get_vector_length(core_shape)
inplace = rv_op.inplace
else: core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
samples = np.random.permutation(x) nin = 1 + len(dist_params) # rng + params
core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
batch_ndim = rv_op.batch_ndim(rv_node)
# numba doesn't support nested literals right now...
input_bc_patterns = encode_literals(
tuple(input_var.type.broadcastable[:batch_ndim] for input_var in dist_params)
)
output_bc_patterns = encode_literals(
(rv_node.outputs[1].type.broadcastable[:batch_ndim],)
)
output_dtypes = encode_literals((rv_node.default_output().type.dtype,))
inplace_pattern = encode_literals(())
def random_wrapper(core_shape, rng, size, *dist_params):
if not inplace:
rng = copy(rng)
draws = _vectorized(
core_op_fn,
input_bc_patterns,
output_bc_patterns,
output_dtypes,
inplace_pattern,
(rng,),
dist_params,
(numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
None if size_len is None else numba_ndarray.to_fixed_tuple(size, size_len),
)
return rng, draws
def random(core_shape, rng, size, *dist_params):
pass
return (rng, samples) @overload(random, jit_options=_jit_options)
def ov_random(core_shape, rng, size, *dist_params):
return random_wrapper
return permutation_rv return random
...@@ -58,7 +58,11 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -58,7 +58,11 @@ def numba_funcify_Scan(op, node, **kwargs):
# TODO: Not sure this is the right place to do this, should we have a rewrite that # TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of Scan? # explicitly triggers the optimization of the inner graphs of Scan?
# The C-code defers it to the make_thunk phase # The C-code defers it to the make_thunk phase
rewriter = op.mode_instance.excluding(*NUMBA._optimizer.exclude).optimizer rewriter = (
op.mode_instance.including("numba")
.excluding(*NUMBA._optimizer.exclude)
.optimizer
)
rewriter(op.fgraph) rewriter(op.fgraph)
scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph)) scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
......
...@@ -5,6 +5,7 @@ from typing import Any, cast ...@@ -5,6 +5,7 @@ from typing import Any, cast
import numpy as np import numpy as np
from pytensor import config from pytensor import config
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply, Constant from pytensor.graph.basic import Apply, Constant
from pytensor.graph.null_type import NullType from pytensor.graph.null_type import NullType
...@@ -377,3 +378,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: ...@@ -377,3 +378,7 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
_vectorize_node.register(Blockwise, _vectorize_not_needed) _vectorize_node.register(Blockwise, _vectorize_not_needed)
class OpWithCoreShape(OpFromGraph):
"""Generalizes an `Op` to include core shape as an additional input."""
...@@ -2082,9 +2082,6 @@ def choice(a, size=None, replace=True, p=None, rng=None): ...@@ -2082,9 +2082,6 @@ def choice(a, size=None, replace=True, p=None, rng=None):
# This is equivalent to the numpy implementation: # This is equivalent to the numpy implementation:
# https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914 # https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/random/_generator.pyx#L905-L914
if p is None: if p is None:
if rng is not None and isinstance(rng.type, RandomStateType):
idxs = randint(0, a_size, size=size, rng=rng)
else:
idxs = integers(0, a_size, size=size, rng=rng) idxs = integers(0, a_size, size=size, rng=rng)
else: else:
idxs = categorical(p, size=size, rng=rng) idxs = categorical(p, size=size, rng=rng)
......
...@@ -19,6 +19,7 @@ from pytensor.tensor.basic import ( ...@@ -19,6 +19,7 @@ from pytensor.tensor.basic import (
get_vector_length, get_vector_length,
infer_static_shape, infer_static_shape,
) )
from pytensor.tensor.blockwise import OpWithCoreShape
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
compute_batch_shape, compute_batch_shape,
...@@ -476,3 +477,11 @@ def vectorize_random_variable( ...@@ -476,3 +477,11 @@ def vectorize_random_variable(
size = concatenate([new_size_dims, size]) size = concatenate([new_size_dims, size])
return op.make_node(rng, size, *dist_params) return op.make_node(rng, size, *dist_params)
class RandomVariableWithCoreShape(OpWithCoreShape):
    """Random variable `Op` extended with an explicit core-shape input."""

    def __str__(self):
        # The inner graph wraps exactly one RV apply node; name ourselves
        # after it, bracketed to distinguish from the plain RV.
        (inner_node,) = self.fgraph.apply_nodes
        return "[" + str(inner_node.op) + "]"
...@@ -4,7 +4,8 @@ from pytensor.tensor.random.rewriting.basic import * ...@@ -4,7 +4,8 @@ from pytensor.tensor.random.rewriting.basic import *
# isort: off # isort: off
# Register JAX specializations # Register Numba and JAX specializations
import pytensor.tensor.random.rewriting.numba
import pytensor.tensor.random.rewriting.jax import pytensor.tensor.random.rewriting.jax
# isort: on # isort: on
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor import as_tensor, constant
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
from pytensor.tensor.rewriting.shape import ShapeFeature
@node_rewriter([RandomVariable])
def introduce_explicit_core_shape_rv(fgraph, node):
    """Introduce the core shape of a RandomVariable.

    We wrap RandomVariable graphs into a RandomVariableWithCoreShape OpFromGraph
    that has an extra "non-functional" input that represents the core shape of the random variable.
    This core_shape is used by the numba backend to pre-allocate the output array.

    If available, the core shape is extracted from the shape feature of the graph,
    which has a higher chance of having been simplified, optimized, constant-folded.
    If missing, we fall back to the op._supp_shape_from_params method.

    This rewrite is required for the numba backend implementation of RandomVariable.

    Example
    -------

    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        x = pt.random.dirichlet(alphas=[1, 2, 3], size=(5,))
        pytensor.dprint(x, print_type=True)
        # dirichlet_rv{"(a)->(a)"}.1 [id A] <Matrix(float64, shape=(5, 3))>
        # ├─ RNG(<Generator(PCG64) at 0x7F09E59C18C0>) [id B] <RandomGeneratorType>
        # ├─ [5] [id C] <Vector(int64, shape=(1,))>
        # └─ ExpandDims{axis=0} [id D] <Matrix(int64, shape=(1, 3))>
        #    └─ [1 2 3] [id E] <Vector(int64, shape=(3,))>

        # After the rewrite, note the new core shape input [3] [id B]
        fn = pytensor.function([], x, mode="NUMBA")
        pytensor.dprint(fn.maker.fgraph)
        # [dirichlet_rv{"(a)->(a)"}].1 [id A] 0
        #  ├─ [3] [id B]
        #  ├─ RNG(<Generator(PCG64) at 0x7F15B8E844A0>) [id C]
        #  ├─ [5] [id D]
        #  └─ [[1 2 3]] [id E]

        # Inner graphs:
        # [dirichlet_rv{"(a)->(a)"}] [id A]
        #  ← dirichlet_rv{"(a)->(a)"}.0 [id F]
        #     ├─ *1-<RandomGeneratorType> [id G]
        #     ├─ *2-<Vector(int64, shape=(1,))> [id H]
        #     └─ *3-<Matrix(int64, shape=(1, 3))> [id I]
        #  ← dirichlet_rv{"(a)->(a)"}.1 [id F]
        #     └─ ···
    """
    op: RandomVariable = node.op  # type: ignore[annotation-unchecked]
    next_rng, rv = node.outputs

    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
    if shape_feature:
        # Prefer the shapes tracked by the graph's ShapeFeature: the core
        # shape is the trailing `ndim_supp` dims of the RV output.
        core_shape = [
            shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))
        ]
    else:
        # Fallback: derive the support shape symbolically from the dist params.
        core_shape = op._supp_shape_from_params(op.dist_params(node))

    if len(core_shape) == 0:
        # Scalar-support RV (ndim_supp == 0): use an empty int64 vector.
        core_shape = constant([], dtype="int64")
    else:
        core_shape = as_tensor(core_shape)

    # Wrap the RV node in an OpFromGraph that takes core_shape as a leading,
    # non-functional input. If the RV writes into its rng input (inplace),
    # that input is now at position 1, hence the shifted destroy_map.
    return (
        RandomVariableWithCoreShape(
            [core_shape, *node.inputs],
            node.outputs,
            destroy_map={0: [1]} if op.inplace else None,
        )
        .make_node(core_shape, *node.inputs)
        .outputs
    )


# Only enabled when the "numba" tag is included in the rewrite query;
# registered late (position=100) so it runs after the usual optimizations.
optdb.register(
    introduce_explicit_core_shape_rv.__name__,
    out2in(introduce_explicit_core_shape_rv),
    "numba",
    position=100,
)
...@@ -740,13 +740,13 @@ class UnShapeOptimizer(GraphRewriter): ...@@ -740,13 +740,13 @@ class UnShapeOptimizer(GraphRewriter):
# Register it after merge1 optimization at 0. We don't want to track # Register it after merge1 optimization at 0. We don't want to track
# the shape of merged node. # the shape of merged node.
pytensor.compile.mode.optdb.register( # type: ignore pytensor.compile.mode.optdb.register(
"ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1 "ShapeOpt", ShapeOptimizer(), "fast_run", "fast_compile", position=0.1
) )
# Not enabled by default for now. Some crossentropy opt use the # Not enabled by default for now. Some crossentropy opt use the
# shape_feature. They are at step 2.01. uncanonicalize is at step # shape_feature. They are at step 2.01. uncanonicalize is at step
# 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable. # 3. After it goes to 48.5 that move to the gpu. So 10 seems reasonable.
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) # type: ignore pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
def local_reshape_chain(op): def local_reshape_chain(op):
......
...@@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py ...@@ -8,7 +8,6 @@ pytensor/graph/rewriting/basic.py
pytensor/ifelse.py pytensor/ifelse.py
pytensor/link/basic.py pytensor/link/basic.py
pytensor/link/numba/dispatch/elemwise.py pytensor/link/numba/dispatch/elemwise.py
pytensor/link/numba/dispatch/random.py
pytensor/link/numba/dispatch/scan.py pytensor/link/numba/dispatch/scan.py
pytensor/printing.py pytensor/printing.py
pytensor/raise_op.py pytensor/raise_op.py
......
...@@ -29,7 +29,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -29,7 +29,6 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.ifelse import ifelse from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch import numba_typify
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.scalar.basic import ScalarOp, as_scalar
...@@ -120,7 +119,7 @@ my_multi_out.ufunc = MyMultiOut.impl ...@@ -120,7 +119,7 @@ my_multi_out.ufunc = MyMultiOut.impl
my_multi_out.ufunc.nin = 2 my_multi_out.ufunc.nin = 2
my_multi_out.ufunc.nout = 2 my_multi_out.ufunc.nout = 2
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts) numba_mode = Mode(NumbaLinker(), opts.including("numba"))
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -229,6 +228,7 @@ def compare_numba_and_py( ...@@ -229,6 +228,7 @@ def compare_numba_and_py(
numba_mode=numba_mode, numba_mode=numba_mode,
py_mode=py_mode, py_mode=py_mode,
updates=None, updates=None,
eval_obj_mode: bool = True,
) -> tuple[Callable, Any]: ) -> tuple[Callable, Any]:
"""Function to compare python graph output and Numba compiled output for testing equality """Function to compare python graph output and Numba compiled output for testing equality
...@@ -247,6 +247,8 @@ def compare_numba_and_py( ...@@ -247,6 +247,8 @@ def compare_numba_and_py(
provided uses `np.testing.assert_allclose`. provided uses `np.testing.assert_allclose`.
updates updates
Updates to be passed to `pytensor.function`. Updates to be passed to `pytensor.function`.
eval_obj_mode : bool, default True
Whether to do an isolated call in object mode. Used for test coverage
Returns Returns
------- -------
...@@ -283,6 +285,7 @@ def compare_numba_and_py( ...@@ -283,6 +285,7 @@ def compare_numba_and_py(
numba_res = pytensor_numba_fn(*inputs) numba_res = pytensor_numba_fn(*inputs)
# Get some coverage # Get some coverage
if eval_obj_mode:
eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode) eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode)
if len(fn_outputs) > 1: if len(fn_outputs) > 1:
...@@ -359,26 +362,6 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -359,26 +362,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected assert res == expected
@pytest.mark.parametrize(
"input, wrapper_fn, check_fn",
[
(
np.random.RandomState(1),
numba_typify,
lambda x, y: np.all(x.get_state()[1] == y.get_state()[1]),
)
],
)
def test_box_unbox(input, wrapper_fn, check_fn):
input = wrapper_fn(input)
pass_through = numba.njit(lambda x: x)
res = pass_through(input)
assert isinstance(res, type(input))
assert check_fn(res, input)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"x, indices", "x, indices",
[ [
......
...@@ -8,13 +8,13 @@ import scipy.stats as stats ...@@ -8,13 +8,13 @@ import scipy.stats as stats
import pytensor.tensor as pt import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr import pytensor.tensor.random.basic as ptr
from pytensor import shared from pytensor import shared
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from tests.link.numba.test_basic import ( from tests.link.numba.test_basic import (
compare_numba_and_py, compare_numba_and_py,
eval_python_only,
numba_mode, numba_mode,
set_test_value, set_test_value,
) )
...@@ -28,37 +28,139 @@ from tests.tensor.random.test_basic import ( ...@@ -28,37 +28,139 @@ from tests.tensor.random.test_basic import (
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
@pytest.mark.xfail( @pytest.mark.parametrize("mu_shape", [(), (3,), (5, 1)])
reason="Most RVs are not working correctly with explicit expand_dims" @pytest.mark.parametrize("sigma_shape", [(), (1,), (5, 3)])
) @pytest.mark.parametrize("size_type", (None, "constant", "mutable"))
def test_random_size(mu_shape, sigma_shape, size_type):
    # Draw the distribution parameters from an independent, fixed-seed rng so
    # the test is deterministic across runs.
    test_value_rng = np.random.default_rng(637)
    mu = test_value_rng.normal(size=mu_shape)
    sigma = np.exp(test_value_rng.normal(size=sigma_shape))

    # For testing
    # NOTE(review): `rng` is also used below to generate the reference draws;
    # this assumes the compiled function operates on a copy of the Generator
    # and does not advance this object's state — confirm (see test_rng_copy).
    rng = np.random.default_rng(123)
    pt_rng = shared(rng)
    if size_type is None:
        size = None
        pt_size = None
    elif size_type == "constant":
        size = (5, 3)
        pt_size = pt.as_tensor(size, dtype="int64")
    else:
        # size_type == "mutable": keep the size in a shared variable so it
        # can be changed after compilation.
        size = (5, 3)
        pt_size = shared(np.array(size, dtype="int64"), shape=(2,))

    next_rng, x = pt.random.normal(mu, sigma, rng=pt_rng, size=pt_size).owner.outputs
    fn = function([], x, updates={pt_rng: next_rng}, mode="NUMBA")

    # Two successive calls must match two successive NumPy draws, i.e. the
    # rng update chains the state between calls.
    res1 = fn()
    np.testing.assert_allclose(
        res1,
        rng.normal(mu, sigma, size=size),
    )

    res2 = fn()
    np.testing.assert_allclose(
        res2,
        rng.normal(mu, sigma, size=size),
    )

    # Resetting the shared rng to the same seed must reproduce the first draw.
    pt_rng.set_value(np.random.default_rng(123))
    res3 = fn()
    np.testing.assert_array_equal(res1, res3)

    if size_type == "mutable" and len(mu_shape) < 2 and len(sigma_shape) < 2:
        # A mutable size can be resized after compilation (only attempted when
        # the params still broadcast to the new shape).
        pt_size.set_value(np.array((6, 3), dtype="int64"))
        res4 = fn()
        assert res4.shape == (6, 3)
def test_rng_copy():
    """Without an rng update, repeated calls re-use the same rng state."""
    shared_rng = shared(np.random.default_rng(123))
    draw = pt.random.normal(rng=shared_rng)
    compiled = function([], draw, mode="NUMBA")

    first = compiled()
    second = compiled()
    np.testing.assert_array_equal(first, second)

    # The shared generator should still compare equal to a fresh seed-123 one.
    shared_rng.type.values_eq(shared_rng.get_value(), np.random.default_rng(123))
def test_rng_non_default_update():
    """An explicit update mapping one shared rng to another is honored."""
    rng_a = shared(np.random.default_rng(1))
    rng_b = shared(np.random.default_rng(2))
    out = pt.random.normal(size=10, rng=rng_a)
    fn = function([], out, updates={rng_a: rng_b}, mode=numba_mode)

    # The first call draws from seed 1; afterwards rng_a is replaced by
    # rng_b's (untouched) state.
    np.testing.assert_allclose(fn(), np.random.default_rng(1).normal(size=10))

    # Every later call re-draws from a fresh seed-2 state, since rng_a keeps
    # being reset to rng_b and rng_b itself never advances.
    expected = np.random.default_rng(2).normal(size=10)
    for _ in range(2):
        np.testing.assert_allclose(fn(), expected)
def test_categorical_rv():
    """This is also a smoke test for a vector input scalar output RV"""
    # Each probability row is one-hot, so the draw is deterministic and equals
    # the argmax along the last axis.
    p = np.array(
        [
            [
                [1.0, 0, 0, 0],
                [0.0, 1.0, 0, 0],
                [0.0, 0, 1.0, 0],
            ],
            [
                [0, 0, 0, 1.0],
                [0, 0, 0, 1.0],
                [0, 0, 0, 1.0],
            ],
        ]
    )
    x = pt.random.categorical(p=p, size=None)
    updates = {x.owner.inputs[0]: x.owner.outputs[0]}
    fn = function([], x, updates=updates, mode="NUMBA")
    res = fn()
    assert np.all(np.argmax(p, axis=-1) == res)

    # Batch size
    # NOTE(review): `updates` still refers to the first graph's rng input and
    # next-rng output; presumably harmless here because the draws are
    # deterministic — confirm this is intentional.
    x = pt.random.categorical(p=p, size=(3, *p.shape[:-1]))
    fn = function([], x, updates=updates, mode="NUMBA")
    new_res = fn()
    assert new_res.shape == (3, *res.shape)
    # Every batch slice must repeat the deterministic base draw.
    for new_res_row in new_res:
        assert np.all(new_res_row == res)
def test_multivariate_normal():
    """This is also a smoke test for a multivariate RV"""
    np_rng = np.random.default_rng(123)
    mean = np.zeros((3, 2))
    cov = np.eye(2)

    draw = pt.random.multivariate_normal(mean=mean, cov=cov, rng=shared(np_rng))
    compiled = function([], draw, mode="NUMBA")

    # Evaluate the compiled graph first, then the NumPy reference, matching
    # the batched (size=(3,)) parametrization of the mean.
    actual = compiled()
    np.testing.assert_array_equal(
        actual,
        np_rng.multivariate_normal(np.zeros(2), cov, size=(3,)),
    )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rv_op, dist_args, size", "rv_op, dist_args, size",
[ [
( (
ptr.normal, ptr.uniform,
[ [
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value( set_test_value(
pt.dscalar(), pt.dscalar(),
np.array(1.0, dtype=np.float64), np.array(1.0, dtype=np.float64),
), ),
],
pt.as_tensor([3, 2]),
),
(
ptr.uniform,
[
set_test_value( set_test_value(
pt.dvector(), pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
], ],
pt.as_tensor([3, 2]), pt.as_tensor([3, 2]),
), ),
...@@ -94,7 +196,7 @@ rng = np.random.default_rng(42849) ...@@ -94,7 +196,7 @@ rng = np.random.default_rng(42849)
], ],
pt.as_tensor([3, 2]), pt.as_tensor([3, 2]),
), ),
pytest.param( (
ptr.pareto, ptr.pareto,
[ [
set_test_value( set_test_value(
...@@ -107,7 +209,6 @@ rng = np.random.default_rng(42849) ...@@ -107,7 +209,6 @@ rng = np.random.default_rng(42849)
), ),
], ],
pt.as_tensor([3, 2]), pt.as_tensor([3, 2]),
marks=pytest.mark.xfail(reason="Not implemented"),
), ),
( (
ptr.exponential, ptr.exponential,
...@@ -153,7 +254,7 @@ rng = np.random.default_rng(42849) ...@@ -153,7 +254,7 @@ rng = np.random.default_rng(42849)
], ],
pt.as_tensor([3, 2]), pt.as_tensor([3, 2]),
), ),
( pytest.param(
ptr.hypergeometric, ptr.hypergeometric,
[ [
set_test_value( set_test_value(
...@@ -170,6 +271,7 @@ rng = np.random.default_rng(42849) ...@@ -170,6 +271,7 @@ rng = np.random.default_rng(42849)
), ),
], ],
pt.as_tensor([3, 2]), pt.as_tensor([3, 2]),
marks=pytest.mark.xfail, # Not implemented
), ),
( (
ptr.wald, ptr.wald,
...@@ -262,33 +364,70 @@ rng = np.random.default_rng(42849) ...@@ -262,33 +364,70 @@ rng = np.random.default_rng(42849)
None, None,
), ),
( (
ptr.randint, ptr.beta,
[ [
set_test_value( set_test_value(
pt.lscalar(), pt.dvector(),
np.array(0, dtype=np.int64), np.array([1.0, 2.0], dtype=np.float64),
), ),
set_test_value( set_test_value(
pt.lscalar(), pt.dscalar(),
np.array(5, dtype=np.int64), np.array(1.0, dtype=np.float64),
), ),
], ],
pt.as_tensor([3, 2]), (2,),
), ),
pytest.param( (
ptr.multivariate_normal, ptr._gamma,
[ [
set_test_value( set_test_value(
pt.dmatrix(), pt.dvector(),
np.array([[1, 2], [3, 4]], dtype=np.float64), np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64),
),
],
(2,),
), ),
(
ptr.chisquare,
[
set_test_value( set_test_value(
pt.tensor(dtype="float64", shape=(1, None, None)), pt.dvector(),
np.eye(2)[None, ...], np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
),
(
ptr.negative_binomial,
[
set_test_value(
pt.lvector(),
np.array([100, 200], dtype=np.int64),
),
set_test_value(
pt.dscalar(),
np.array(0.09, dtype=np.float64),
), ),
], ],
pt.as_tensor(tuple(set_test_value(pt.lscalar(), v) for v in [4, 3, 2])), (2,),
marks=pytest.mark.xfail(reason="Not implemented"), ),
(
ptr.vonmises,
[
set_test_value(
pt.dvector(),
np.array([-0.5, 0.5], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
), ),
( (
ptr.permutation, ptr.permutation,
...@@ -312,17 +451,21 @@ rng = np.random.default_rng(42849) ...@@ -312,17 +451,21 @@ rng = np.random.default_rng(42849)
[ [
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value( set_test_value(
pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64) pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
), ),
], ],
(), (pt.as_tensor([2, 3])),
), ),
( pytest.param(
partial(ptr.choice, replace=False), partial(ptr.choice, replace=False),
[ [
set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)), set_test_value(pt.dvector(), np.arange(5, dtype=np.float64)),
], ],
pt.as_tensor([2]), pt.as_tensor([2]),
marks=pytest.mark.xfail(
AssertionError,
reason="Not aligned with NumPy implementation",
),
), ),
pytest.param( pytest.param(
partial(ptr.choice, replace=False), partial(ptr.choice, replace=False),
...@@ -331,28 +474,23 @@ rng = np.random.default_rng(42849) ...@@ -331,28 +474,23 @@ rng = np.random.default_rng(42849)
], ],
pt.as_tensor([2]), pt.as_tensor([2]),
marks=pytest.mark.xfail( marks=pytest.mark.xfail(
raises=ValueError, raises=AssertionError,
reason="Numba random.choice does not support >=1D `a`", reason="Not aligned with NumPy implementation",
), ),
), ),
pytest.param( (
# p must be passed by kwarg # p must be passed by kwarg
lambda a, p, size, rng: ptr.choice( lambda a, p, size, rng: ptr.choice(
a, p=p, size=size, replace=False, rng=rng a, p=p, size=size, replace=False, rng=rng
), ),
[ [
set_test_value(pt.vector(), np.arange(5, dtype=np.float64)), set_test_value(pt.vector(), np.arange(5, dtype=np.float64)),
# Boring p, because the variable is not truly "aligned"
set_test_value( set_test_value(
pt.dvector(), pt.dvector(),
np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64), np.array([0.5, 0.0, 0.25, 0.0, 0.25], dtype=np.float64),
), ),
], ],
(), pt.as_tensor([2]),
marks=pytest.mark.xfail(
raises=Exception, # numba.TypeError
reason="Numba random.choice does not support `p` parameter",
),
), ),
pytest.param( pytest.param(
# p must be passed by kwarg # p must be passed by kwarg
...@@ -361,23 +499,31 @@ rng = np.random.default_rng(42849) ...@@ -361,23 +499,31 @@ rng = np.random.default_rng(42849)
), ),
[ [
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)), set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
# Boring p, because the variable is not truly "aligned"
set_test_value( set_test_value(
pt.dvector(), np.array([0.5, 0.0, 0.5], dtype=np.float64) pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
), ),
], ],
(), (),
marks=pytest.mark.xfail(
raises=ValueError,
reason="Numba random.choice does not support >=1D `a`",
), ),
pytest.param(
# p must be passed by kwarg
lambda a, p, size, rng: ptr.choice(
a, p=p, size=size, replace=False, rng=rng
),
[
set_test_value(pt.dmatrix(), np.eye(3, dtype=np.float64)),
set_test_value(
pt.dvector(), np.array([0.25, 0.5, 0.25], dtype=np.float64)
),
],
(pt.as_tensor([2, 1])),
), ),
], ],
ids=str, ids=str,
) )
def test_aligned_RandomVariable(rv_op, dist_args, size): def test_aligned_RandomVariable(rv_op, dist_args, size):
"""Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers.""" """Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers."""
rng = shared(np.random.RandomState(29402)) rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=size, rng=rng) g = rv_op(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g]) g_fg = FunctionGraph(outputs=[g])
...@@ -388,45 +534,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -388,45 +534,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
for i in g_fg.inputs for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant) if not isinstance(i, SharedVariable | Constant)
], ],
eval_obj_mode=False, # No python impl
) )
@pytest.mark.xfail(reason="Test is not working correctly with explicit expand_dims")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rv_op, dist_args, base_size, cdf_name, params_conv", "rv_op, dist_args, base_size, cdf_name, params_conv",
[ [
(
ptr.beta,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"beta",
lambda *args: args,
),
(
ptr._gamma,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
),
set_test_value(
pt.dvector(),
np.array([0.5, 3.0], dtype=np.float64),
),
],
(2,),
"gamma",
lambda a, b: (a, 0.0, b),
),
( (
ptr.cauchy, ptr.cauchy,
[ [
...@@ -443,18 +557,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -443,18 +557,6 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
"cauchy", "cauchy",
lambda *args: args, lambda *args: args,
), ),
(
ptr.chisquare,
[
set_test_value(
pt.dvector(),
np.array([1.0, 2.0], dtype=np.float64),
)
],
(2,),
"chi2",
lambda *args: args,
),
( (
ptr.gumbel, ptr.gumbel,
[ [
...@@ -471,49 +573,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size): ...@@ -471,49 +573,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
"gumbel_r", "gumbel_r",
lambda *args: args, lambda *args: args,
), ),
(
ptr.negative_binomial,
[
set_test_value(
pt.lvector(),
np.array([100, 200], dtype=np.int64),
),
set_test_value(
pt.dscalar(),
np.array(0.09, dtype=np.float64),
),
],
(2,),
"nbinom",
lambda *args: args,
),
pytest.param(
ptr.vonmises,
[
set_test_value(
pt.dvector(),
np.array([-0.5, 0.5], dtype=np.float64),
),
set_test_value(
pt.dscalar(),
np.array(1.0, dtype=np.float64),
),
],
(2,),
"vonmises_line",
lambda mu, kappa: (kappa, mu),
marks=pytest.mark.xfail(
reason=(
"Numba's parameterization of `vonmises` does not match NumPy's."
"See https://github.com/numba/numba/issues/7886"
)
),
),
], ],
) )
def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv): def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_conv):
"""Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers.""" """Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers."""
rng = shared(np.random.RandomState(29402)) rng = shared(np.random.default_rng(29402))
g = rv_op(*dist_args, size=(2000, *base_size), rng=rng) g = rv_op(*dist_args, size=(2000, *base_size), rng=rng)
g_fn = function(dist_args, g, mode=numba_mode) g_fn = function(dist_args, g, mode=numba_mode)
samples = g_fn( samples = g_fn(
...@@ -535,20 +599,17 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -535,20 +599,17 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dist_args, size, cm", "a, size, cm",
[ [
pytest.param( pytest.param(
[
set_test_value( set_test_value(
pt.dvector(), pt.dvector(),
np.array([100000, 1, 1], dtype=np.float64), np.array([100000, 1, 1], dtype=np.float64),
), ),
],
None, None,
contextlib.suppress(), contextlib.suppress(),
), ),
pytest.param( pytest.param(
[
set_test_value( set_test_value(
pt.dmatrix(), pt.dmatrix(),
np.array( np.array(
...@@ -556,25 +617,10 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -556,25 +617,10 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
dtype=np.float64, dtype=np.float64,
), ),
), ),
],
(10, 3), (10, 3),
contextlib.suppress(), contextlib.suppress(),
), ),
pytest.param( pytest.param(
[
set_test_value(
pt.dmatrix(),
np.array(
[[100000, 1, 1]],
dtype=np.float64,
),
),
],
(5, 4, 3),
contextlib.suppress(),
),
pytest.param(
[
set_test_value( set_test_value(
pt.dmatrix(), pt.dmatrix(),
np.array( np.array(
...@@ -582,76 +628,22 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_ ...@@ -582,76 +628,22 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
dtype=np.float64, dtype=np.float64,
), ),
), ),
],
(10, 4), (10, 4),
pytest.raises( pytest.raises(
ValueError, match="objects cannot be broadcast to a single shape" ValueError,
match="Vectorized input 0 has an incompatible shape in axis 1.",
), ),
), ),
], ],
) )
def test_CategoricalRV(dist_args, size, cm):
rng = shared(np.random.RandomState(29402))
g = ptr.categorical(*dist_args, size=size, rng=rng)
g_fg = FunctionGraph(outputs=[g])
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize(
"a, size, cm",
[
pytest.param(
set_test_value(
pt.dvector(),
np.array([100000, 1, 1], dtype=np.float64),
),
None,
contextlib.suppress(),
),
pytest.param(
set_test_value(
pt.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
(10, 3),
contextlib.suppress(),
),
pytest.param(
set_test_value(
pt.dmatrix(),
np.array(
[[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]],
dtype=np.float64,
),
),
(10, 4),
pytest.raises(ValueError, match="operands could not be broadcast together"),
),
],
)
def test_DirichletRV(a, size, cm): def test_DirichletRV(a, size, cm):
rng = shared(np.random.RandomState(29402)) rng = shared(np.random.default_rng(29402))
g = ptr.dirichlet(a, size=size, rng=rng) g = ptr.dirichlet(a, size=size, rng=rng)
g_fn = function([a], g, mode=numba_mode) g_fn = function([a], g, mode=numba_mode)
with cm: with cm:
a_val = a.tag.test_value a_val = a.tag.test_value
# For coverage purposes only...
eval_python_only([a], [g], [a_val])
all_samples = [] all_samples = []
for i in range(1000): for i in range(1000):
samples = g_fn(a_val) samples = g_fn(a_val)
...@@ -662,48 +654,34 @@ def test_DirichletRV(a, size, cm): ...@@ -662,48 +654,34 @@ def test_DirichletRV(a, size, cm):
assert np.allclose(res, exp_res, atol=1e-4) assert np.allclose(res, exp_res, atol=1e-4)
@pytest.mark.xfail(reason="RandomState is not aligned with explicit expand_dims") def test_rv_inside_ofg():
def test_RandomState_updates(): rng_np = np.random.default_rng(562)
rng = shared(np.random.RandomState(1)) rng = shared(rng_np)
rng_new = shared(np.random.RandomState(2))
x = pt.random.normal(size=10, rng=rng) rng_dummy = rng.type()
res = function([], x, updates={rng: rng_new}, mode=numba_mode)() next_rng_dummy, rv_dummy = ptr.normal(
0, 1, size=(3, 2), rng=rng_dummy
).owner.outputs
out_dummy = rv_dummy.T
ref = np.random.RandomState(2).normal(size=10) next_rng, out = OpFromGraph([rng_dummy], [next_rng_dummy, out_dummy])(rng)
assert np.allclose(res, ref) fn = function([], out, updates={rng: next_rng}, mode=numba_mode)
res1, res2 = fn(), fn()
assert res1.shape == (2, 3)
def test_random_Generator(): np.testing.assert_allclose(res1, rng_np.normal(0, 1, size=(3, 2)).T)
rng = shared(np.random.default_rng(29402)) np.testing.assert_allclose(res2, rng_np.normal(0, 1, size=(3, 2)).T)
g = ptr.normal(rng=rng)
g_fg = FunctionGraph(outputs=[g])
with pytest.raises(TypeError):
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_dims_tester", "batch_dims_tester",
[ [
pytest.param(
batched_unweighted_choice_without_replacement_tester, batched_unweighted_choice_without_replacement_tester,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
pytest.param(
batched_weighted_choice_without_replacement_tester, batched_weighted_choice_without_replacement_tester,
marks=pytest.mark.xfail(raises=NotImplementedError),
),
batched_permutation_tester, batched_permutation_tester,
], ],
) )
def test_unnatural_batched_dims(batch_dims_tester): def test_unnatural_batched_dims(batch_dims_tester):
"""Tests for RVs that don't have natural batch dims in Numba API.""" """Tests for RVs that don't have natural batch dims in Numba API."""
batch_dims_tester(mode="NUMBA", rng_ctor=np.random.RandomState) batch_dims_tester(mode="NUMBA")
...@@ -77,15 +77,13 @@ from tests.link.numba.test_basic import compare_numba_and_py ...@@ -77,15 +77,13 @@ from tests.link.numba.test_basic import compare_numba_and_py
), ),
# nit-sot, shared input/output # nit-sot, shared input/output
( (
lambda: RandomStream(seed=1930, rng_ctor=np.random.RandomState).normal( lambda: RandomStream(seed=1930).normal(0, 1, name="a"),
0, 1, name="a"
),
[], [],
[{}], [{}],
[], [],
3, 3,
[], [],
[np.array([-1.63408257, 0.18046406, 2.43265803])], [np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_shared_outs > 0, lambda op: op.info.n_shared_outs > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
......
...@@ -1452,9 +1452,7 @@ def test_permutation_shape(): ...@@ -1452,9 +1452,7 @@ def test_permutation_shape():
assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5) assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5)
def batched_unweighted_choice_without_replacement_tester( def batched_unweighted_choice_without_replacement_tester(mode="FAST_RUN"):
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
"""Test unweighted choice without replacement with batched ndims. """Test unweighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the This has no corresponding in numpy, but is supported for consistency within the
...@@ -1462,7 +1460,7 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1462,7 +1460,7 @@ def batched_unweighted_choice_without_replacement_tester(
It can be triggered by manual buiding the Op or during automatic vectorization. It can be triggered by manual buiding the Op or during automatic vectorization.
""" """
rng = shared(rng_ctor()) rng = shared(np.random.default_rng())
# Batched a implicit size # Batched a implicit size
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
...@@ -1499,9 +1497,7 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1499,9 +1497,7 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10)) assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
def batched_weighted_choice_without_replacement_tester( def batched_weighted_choice_without_replacement_tester(mode="FAST_RUN"):
mode="FAST_RUN", rng_ctor=np.random.default_rng
):
"""Test weighted choice without replacement with batched ndims. """Test weighted choice without replacement with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the This has no corresponding in numpy, but is supported for consistency within the
...@@ -1509,7 +1505,7 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1509,7 +1505,7 @@ def batched_weighted_choice_without_replacement_tester(
It can be triggered by manual buiding the Op or during automatic vectorization. It can be triggered by manual buiding the Op or during automatic vectorization.
""" """
rng = shared(rng_ctor()) rng = shared(np.random.default_rng())
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
signature="(a0,a1),(a0),(1)->(s0,a1)", signature="(a0,a1),(a0),(1)->(s0,a1)",
...@@ -1574,7 +1570,7 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1574,7 +1570,7 @@ def batched_weighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10)) assert np.all((draw >= i * 10 + 2) & (draw < (i + 1) * 10))
def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng): def batched_permutation_tester(mode="FAST_RUN"):
"""Test permutation with batched ndims. """Test permutation with batched ndims.
This has no corresponding in numpy, but is supported for consistency within the This has no corresponding in numpy, but is supported for consistency within the
...@@ -1583,7 +1579,7 @@ def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng): ...@@ -1583,7 +1579,7 @@ def batched_permutation_tester(mode="FAST_RUN", rng_ctor=np.random.default_rng):
It can be triggered by manual buiding the Op or during automatic vectorization. It can be triggered by manual buiding the Op or during automatic vectorization.
""" """
rng = shared(rng_ctor()) rng = shared(np.random.default_rng())
rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64") rv_op = PermutationRV(ndim_supp=2, ndims_params=[2], dtype="int64")
x = np.arange(5 * 3 * 2).reshape((5, 3, 2)) x = np.arange(5 * 3 * 2).reshape((5, 3, 2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论