Commit 7886cf83 authored by Ricardo Vieira, committed by Ricardo Vieira

Implement XTensorVariable version of RandomVariables

Parent 33d04c36
@@ -1625,8 +1625,7 @@ class NegBinomialRV(ScipyRandomVariable):
         return stats.nbinom.rvs(n, p, size=size, random_state=rng)

-nbinom = NegBinomialRV()
-negative_binomial = NegBinomialRV()
+nbinom = negative_binomial = NegBinomialRV()

 class BetaBinomialRV(ScipyRandomVariable):
@@ -1808,6 +1807,7 @@ class MultinomialRV(RandomVariable):
 multinomial = MultinomialRV()

 vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")
...
+import abc
 import warnings
 from collections.abc import Sequence
 from copy import deepcopy
@@ -32,7 +33,20 @@ from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
 from pytensor.tensor.variable import TensorVariable

-class RandomVariable(Op):
+class RNGConsumerOp(Op):
+    """Baseclass for Ops that consume RNGs."""
+
+    @abc.abstractmethod
+    def update(self, node: Apply) -> dict[Variable, Variable]:
+        """Symbolic update expression for input RNG variables.
+
+        Returns a dictionary with the symbolic expressions required for correct updating
+        of RNG variables in repeated function evaluations.
+        """
+        pass
+
+
+class RandomVariable(RNGConsumerOp):
     """An `Op` that produces a sample from a random variable.

     This is essentially `RandomFunction`, except that it removes the
@@ -123,6 +137,9 @@ class RandomVariable(Op):
         if self.inplace:
             self.destroy_map = {0: [0]}

+    def update(self, node: Apply) -> dict[Variable, Variable]:
+        return {node.inputs[0]: node.outputs[0]}
+
     def _supp_shape_from_params(self, dist_params, param_shapes=None):
         """Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
...
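The `update` mapping above is what gets passed to `pytensor.function` as `updates`, so the RNG state advances between evaluations. A minimal sketch of consuming it (illustrative only, assuming the shared RNG that a `RandomVariable` call creates when none is given):

import pytensor
from pytensor.tensor.random import normal

x = normal(0, 1)  # implicitly creates a shared RNG input
node = x.owner
updates = node.op.update(node)  # {rng_input: next_rng_output}
fn = pytensor.function([], x, updates=updates)
fn(), fn()  # different draws, because the RNG state advances on each call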
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
     return [new_var]

+@register_infer_shape
 @node_rewriter([Assert])
 def local_remove_all_assert(fgraph, node):
     r"""A rewrite that removes all `Assert`\s from a graph.
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
     See the :ref:`unsafe` section.
     """
-    if not isinstance(node.op, Assert):
-        return
-
     return [node.inputs[0]]
...
@@ -9,6 +9,7 @@ from numpy import nditer
 import pytensor
 from pytensor.graph import FunctionGraph, Variable
 from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.utils import hash_from_code
@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
     https://github.com/numpy/numpy/issues/28921
     """
     return product(*(range(s) for s in shape))
+
+
+def get_static_shape_from_size_variables(
+    size_vars: Sequence[Variable],
+) -> tuple[int | None, ...]:
+    """Get static shape from size variables.
+
+    Parameters
+    ----------
+    size_vars : Sequence[Variable]
+        A sequence of variables representing the size of each dimension.
+
+    Returns
+    -------
+    tuple[int | None, ...]
+        A tuple containing the static lengths of each dimension, or None if
+        the length is not statically known.
+    """
+    from pytensor.tensor.basic import get_scalar_constant_value
+
+    static_lengths: list[None | int] = [None] * len(size_vars)
+    for i, length in enumerate(size_vars):
+        try:
+            static_length = get_scalar_constant_value(length)
+        except NotScalarConstantError:
+            pass
+        else:
+            static_lengths[i] = int(static_length)
+    return tuple(static_lengths)
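As a quick illustration of the helper just added (illustrative snippet, not part of the commit):

import pytensor.tensor as pt
from pytensor.tensor.utils import get_static_shape_from_size_variables

# Constants yield static lengths; purely symbolic scalars yield None
sizes = [pt.constant(5), pt.scalar("n", dtype="int64")]
assert get_static_shape_from_size_variables(sizes) == (5, None)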
from collections.abc import Sequence
from functools import wraps
from typing import Literal

import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.math import sqrt
from pytensor.xtensor.vectorization import XRV


def _as_xrv(
    core_op: RandomVariable,
    core_inps_dims_map: Sequence[Sequence[int]] | None = None,
    core_out_dims_map: Sequence[int] | None = None,
):
    """Helper function to define an XRV constructor.

    Parameters
    ----------
    core_op : RandomVariable
        The core random variable operation to wrap.
    core_inps_dims_map : Sequence[Sequence[int]] | None, optional
        A sequence of sequences mapping the core dimensions (specified by the user)
        for each input parameter. This is used when lowering to a RandomVariable operation,
        to decide the ordering of the core dimensions for each input.
        If None, it assumes the core dimensions are positional from left to right.
    core_out_dims_map : Sequence[int] | None, optional
        A sequence mapping the core dimensions (specified by the user) for the output variable.
        This is used when lowering to a RandomVariable operation,
        to decide the ordering of the core dimensions for the output.
        If None, it assumes the core dimensions are positional from left to right.
    """
    if core_inps_dims_map is None:
        # Assume core_dims map positionally from left to right
        core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params]
    if core_out_dims_map is None:
        # Assume core_dims map positionally from left to right
        core_out_dims_map = tuple(range(core_op.ndim_supp))

    core_dims_needed = max(
        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
    )

    @wraps(core_op)
    def xrv_constructor(
        *params,
        core_dims: Sequence[str] | str | None = None,
        extra_dims: dict[str, Variable] | None = None,
        rng: Variable | None = None,
    ):
        if core_dims is None:
            core_dims = ()
            if core_dims_needed:
                raise ValueError(
                    f"{core_op.name} needs {core_dims_needed} core_dims to be specified"
                )
        elif isinstance(core_dims, str):
            core_dims = (core_dims,)

        if len(core_dims) != core_dims_needed:
            raise ValueError(
                f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
            )

        full_input_core_dims = tuple(
            tuple(core_dims[i] for i in inp_dims_map)
            for inp_dims_map in core_inps_dims_map
        )
        full_output_core_dims = tuple(core_dims[i] for i in core_out_dims_map)
        full_core_dims = (full_input_core_dims, full_output_core_dims)

        if extra_dims is None:
            extra_dims = {}

        return XRV(
            core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
        )(rng, *extra_dims.values(), *params)

    return xrv_constructor
bernoulli = _as_xrv(ptr.bernoulli)
beta = _as_xrv(ptr.beta)
betabinom = _as_xrv(ptr.betabinom)
binomial = _as_xrv(ptr.binomial)
categorical = _as_xrv(ptr.categorical)
cauchy = _as_xrv(ptr.cauchy)
dirichlet = _as_xrv(ptr.dirichlet)
exponential = _as_xrv(ptr.exponential)
gamma = _as_xrv(ptr._gamma)
gengamma = _as_xrv(ptr.gengamma)
geometric = _as_xrv(ptr.geometric)
gumbel = _as_xrv(ptr.gumbel)
halfcauchy = _as_xrv(ptr.halfcauchy)
halfnormal = _as_xrv(ptr.halfnormal)
hypergeometric = _as_xrv(ptr.hypergeometric)
integers = _as_xrv(ptr.integers)
invgamma = _as_xrv(ptr.invgamma)
laplace = _as_xrv(ptr.laplace)
logistic = _as_xrv(ptr.logistic)
lognormal = _as_xrv(ptr.lognormal)
multinomial = _as_xrv(ptr.multinomial)
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial)
normal = _as_xrv(ptr.normal)
pareto = _as_xrv(ptr.pareto)
poisson = _as_xrv(ptr.poisson)
t = _as_xrv(ptr.t)
triangular = _as_xrv(ptr.triangular)
truncexpon = _as_xrv(ptr.truncexpon)
uniform = _as_xrv(ptr.uniform)
vonmises = _as_xrv(ptr.vonmises)
wald = _as_xrv(ptr.wald)
weibull = _as_xrv(ptr.weibull)
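As an illustrative sketch (not part of the commit) of how these constructors are called: a vector-valued RV such as `dirichlet` names its single core dimension, while batch dimensions broadcast by name. The dim names below are made up, and `xtensor` is assumed to accept `name`, `dims`, and `shape` as in the snippets elsewhere in this commit:

import pytensor.xtensor.random as pxr
from pytensor.xtensor import xtensor

alpha = xtensor("alpha", dims=("batch", "support"), shape=(4, 3))
draw = pxr.dirichlet(alpha, core_dims="support")
# draw.type.dims == ("batch", "support")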
def multivariate_normal(
    mean,
    cov,
    *,
    core_dims: Sequence[str],
    extra_dims=None,
    rng=None,
    method: Literal["cholesky", "svd", "eigh"] = "cholesky",
):
    mean = as_xtensor(mean)
    if len(core_dims) != 2:
        raise ValueError(
            f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
        )

    # Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov
    # This will be the core dimension of the output
    if core_dims[0] not in mean.type.dims:
        core_dims = core_dims[::-1]

    xop = _as_xrv(ptr.MvNormalRV(method=method))
    return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)


def standard_normal(
    extra_dims: dict[str, Variable] | None = None,
    rng: Variable | None = None,
):
    """Standard normal random variable."""
    return normal(0, 1, extra_dims=extra_dims, rng=rng)


def chisquare(
    df,
    extra_dims: dict[str, Variable] | None = None,
    rng: Variable | None = None,
):
    """Chi-square random variable."""
    return gamma(df / 2.0, 2.0, extra_dims=extra_dims, rng=rng)


def rayleigh(
    scale,
    extra_dims: dict[str, Variable] | None = None,
    rng: Variable | None = None,
):
    """Rayleigh random variable."""
    df = scale * 0 + 2  # Poor man's broadcasting, to pass dimensions of scale to the RV
    return sqrt(chisquare(df, extra_dims=extra_dims, rng=rng)) * scale
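To round out the API surface, a hedged sketch of `extra_dims`, which adds batch dimensions that no parameter implies (names made up; the tests below exercise the same path):

import pytensor.xtensor.random as pxr
from pytensor.xtensor import xtensor

n = xtensor("n", dims=(), dtype="int64")  # symbolic length of the new dim
x = pxr.normal(0, 1, extra_dims={"trial": n})
# x.type.dims == ("trial",); its length is only known at run time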
 from pytensor.graph import node_rewriter
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.random.utils import compute_batch_shape
 from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
 from pytensor.xtensor.rewriting.utils import register_lower_xtensor
-from pytensor.xtensor.vectorization import XBlockwise, XElemwise
+from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise

 @register_lower_xtensor
@@ -74,3 +75,49 @@ def lower_blockwise(fgraph, node):
         for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
     ]
     return new_outs

+@register_lower_xtensor
+@node_rewriter(tracks=[XRV])
+def lower_rv(fgraph, node):
+    op: XRV = node.op
+    core_op = op.core_op
+
+    _, old_out = node.outputs
+    rng, *extra_dim_lengths_and_params = node.inputs
+    extra_dim_lengths = extra_dim_lengths_and_params[: len(op.extra_dims)]
+    params = extra_dim_lengths_and_params[len(op.extra_dims) :]
+
+    batch_ndim = old_out.type.ndim - len(op.core_dims[1])
+    param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]
+
+    # Convert param XTensors to Tensors: align batch dimensions and place core dimensions at the end
+    tensor_params = []
+    for inp, core_dims in zip(params, op.core_dims[0]):
+        inp_dims = inp.type.dims
+        # Align the batch dims of the input, and place the core dims on the right
+        batch_order = [
+            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
+            for batch_dim in param_batch_dims
+        ]
+        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
+        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
+        tensor_params.append(tensor_inp)
+
+    size = None
+    if op.extra_dims:
+        # RV size contains the lengths of all batch dimensions, including those coming from the parameters
+        if tensor_params:
+            param_batch_shape = tuple(
+                compute_batch_shape(tensor_params, ndims_params=core_op.ndims_params)
+            )
+        else:
+            param_batch_shape = ()
+        size = [*extra_dim_lengths, *param_batch_shape]
+
+    # RVs are their own core Op
+    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
+
+    # Convert the output Tensor back to an XTensor
+    new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
+
+    return [new_next_rng, new_out]
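For intuition (mirroring `test_normal` below), the rewrite replaces dim-based alignment with positional broadcasting; lowering `normal(mu_xr, sigma_xr)` where `mu_xr` has dims `("mu_dim",)` and `sigma_xr` has dims `("sigma_dim",)` is equivalent to:

# mu is dimshuffled to (mu_dim, 1), sigma to (1, sigma_dim), and the
# core op then runs with ordinary positional broadcasting:
ptr.normal(mu[:, None], sigma[None, :], rng=rng)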
@@ -11,6 +11,7 @@ from pytensor.scalar import discrete_dtypes, upcast
 from pytensor.tensor import as_tensor, get_scalar_constant_value
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.type import integer_dtypes
+from pytensor.tensor.utils import get_static_shape_from_size_variables
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor
@@ -131,14 +132,9 @@ class UnStack(XOp):
             )
         )

-        static_unstacked_lengths = [None] * len(unstacked_lengths)
-        for i, length in enumerate(unstacked_lengths):
-            try:
-                static_length = get_scalar_constant_value(length)
-            except NotScalarConstantError:
-                pass
-            else:
-                static_unstacked_lengths[i] = int(static_length)
+        static_unstacked_lengths = get_static_shape_from_size_variables(
+            unstacked_lengths
+        )

         output = xtensor(
             dtype=x.type.dtype,
...
 from itertools import chain

+import numpy as np
+
 from pytensor import scalar as ps
+from pytensor import shared
 from pytensor.graph import Apply, Op
+from pytensor.scalar import discrete_dtypes
 from pytensor.tensor import tensor
+from pytensor.tensor.random.op import RNGConsumerOp
+from pytensor.tensor.random.type import RandomType
+from pytensor.tensor.utils import (
+    get_static_shape_from_size_variables,
+)
 from pytensor.xtensor.basic import XOp
 from pytensor.xtensor.type import as_xtensor, xtensor
@@ -108,3 +117,139 @@ class XBlockwise(XOp):
             for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
         ]
         return Apply(self, inputs, outputs)

+class XRV(XOp, RNGConsumerOp):
+    """Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
+
+    Xarray does not offer random generators, so this class implements a new API.
+
+    It mostly works like a gufunc (or XBlockwise), which specifies core dimensions for inputs and output, and
+    enforces dim-based broadcasting between inputs and output.
+
+    It differs from XBlockwise in a couple of ways:
+    1. It is restricted to one sample output
+    2. It takes a random generator as the first input and returns the consumed generator as the first output.
+    3. It has the concept of extra dimensions, which determine extra batch dimensions of the output, that are not
+       implied by batch dimensions of the parameters.
+    """
+
+    default_output = 1
+    __props__ = ("core_op", "core_dims", "extra_dims")
+
+    def __init__(
+        self,
+        core_op,
+        core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
+        extra_dims: tuple[str, ...],
+    ):
+        super().__init__()
+        self.core_op = core_op
+        inps_core_dims, out_core_dims = core_dims
+        for operand_dims in (*inps_core_dims, out_core_dims):
+            if len(set(operand_dims)) != len(operand_dims):
+                raise ValueError(f"Operand has repeated dims {operand_dims}")
+        self.core_dims = (tuple(i for i in inps_core_dims), tuple(out_core_dims))
+        if len(set(extra_dims)) != len(extra_dims):
+            raise ValueError("size_dims must be unique")
+        self.extra_dims = tuple(extra_dims)
+
+    def update(self, node):
+        # RNG input and update are the first input and output respectively
+        return {node.inputs[0]: node.outputs[0]}
+
+    def make_node(self, rng, *extra_dim_lengths_and_params):
+        if rng is None:
+            rng = shared(np.random.default_rng())
+        elif not isinstance(rng.type, RandomType):
+            raise TypeError(
+                "The type of rng should be an instance of RandomGeneratorType"
+            )
+
+        extra_dim_lengths = [
+            as_xtensor(dim_length).values
+            for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)]
+        ]
+        if not all(
+            (dim_length.type.ndim == 0 and dim_length.type.dtype in discrete_dtypes)
+            for dim_length in extra_dim_lengths
+        ):
+            raise TypeError("All dimension lengths should be scalar discrete dtype.")
+
+        params = [
+            as_xtensor(param)
+            for param in extra_dim_lengths_and_params[len(self.extra_dims) :]
+        ]
+        if len(params) != len(self.core_op.ndims_params):
+            raise ValueError(
+                f"Expected {len(self.core_op.ndims_params)} parameters + {len(self.extra_dims)} dim_lengths, "
+                f"got {len(extra_dim_lengths_and_params)}"
+            )
+
+        param_core_dims, output_core_dims = self.core_dims
+        input_core_dims_set = set(chain.from_iterable(param_core_dims))
+
+        # Check parameters don't have core dimensions they shouldn't have
+        for param, core_param_dims in zip(params, param_core_dims):
+            if invalid_core_dims := (
+                set(param.type.dims) - set(core_param_dims)
+            ).intersection(input_core_dims_set):
+                raise ValueError(
+                    f"Parameter {param} has invalid core dimensions {sorted(invalid_core_dims)}"
+                )
+
+        extra_dims_and_shape = dict(
+            zip(
+                self.extra_dims, get_static_shape_from_size_variables(extra_dim_lengths)
+            )
+        )
+        params_dims_and_shape = combine_dims_and_shape(params)
+
+        # Check that no parameter dims conflict with size dims
+        if conflict_dims := set(extra_dims_and_shape).intersection(
+            params_dims_and_shape
+        ):
+            raise ValueError(
+                f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
+            )
+
+        batch_dims_and_shape = [
+            (dim, dim_length)
+            for dim, dim_length in (
+                extra_dims_and_shape | params_dims_and_shape
+            ).items()
+            if dim not in input_core_dims_set
+        ]
+        if batch_dims_and_shape:
+            batch_output_dims, batch_output_shape = zip(*batch_dims_and_shape)
+        else:
+            batch_output_dims, batch_output_shape = (), ()
+
+        dummy_core_inputs = []
+        for param, core_param_dims in zip(params, param_core_dims):
+            try:
+                core_static_shape = [
+                    param.type.shape[param.type.dims.index(d)] for d in core_param_dims
+                ]
+            except ValueError:
+                raise ValueError(
+                    f"At least one core dim={core_param_dims} missing from input {param} with dims={param.type.dims}"
+                )
+            dummy_core_inputs.append(
+                tensor(dtype=param.type.dtype, shape=core_static_shape)
+            )
+
+        core_node = self.core_op.make_node(rng, None, *dummy_core_inputs)
+        if len(core_node.outputs) != 2:
+            raise NotImplementedError(
+                "XRandomVariable only supports core ops with two outputs (rng, out)"
+            )
+
+        _, core_out = core_node.outputs
+        out = xtensor(
+            dtype=core_out.type.dtype,
+            shape=batch_output_shape + core_out.type.shape,
+            dims=batch_output_dims + output_core_dims,
+        )
+
+        return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
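A minimal sketch of constructing an `XRV` directly; the `_as_xrv` helper in `pytensor.xtensor.random` normally builds this for you, so treat the snippet as illustrative only:

import pytensor.tensor.random.basic as ptr
from pytensor.xtensor import xtensor
from pytensor.xtensor.vectorization import XRV

# normal has two scalar parameters and a scalar output, hence the empty core dims
xrv = XRV(ptr.normal, core_dims=(((), ()), ()), extra_dims=("trial",))
n = xtensor("n", dims=(), dtype="int64")
draw = xrv(None, n, 0.0, 1.0)  # rng=None falls back to a fresh shared RNG
# draw.type.dims == ("trial",)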
import inspect
import re
from copy import deepcopy
import numpy as np
import pytest
import pytensor.tensor.random as ptr
import pytensor.xtensor.random as pxr
from pytensor import function, shared
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import equal_computations
from pytensor.tensor import broadcast_arrays, tensor
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import random_generator_type
from pytensor.xtensor import as_xtensor, xtensor
from pytensor.xtensor.random import (
    categorical,
    multinomial,
    multivariate_normal,
    normal,
)
from pytensor.xtensor.vectorization import XRV
def lower_rewrite(vars):
    return rewrite_graph(
        vars,
        include=(
            "lower_xtensor",
            "canonicalize",
        ),
    )
def test_all_basic_rvs_are_wrapped():
    # This ignores wrapper functions
    pxr_members = {name for name, _ in inspect.getmembers(pxr)}
    for name, op in inspect.getmembers(ptr.basic):
        if name == "_gamma":
            name = "gamma"
        if isinstance(op, RandomVariable) and name not in pxr_members:
            raise NotImplementedError(f"Variable {name} not implemented as XRV")
def test_updates():
    rng = shared(np.random.default_rng(40))
    next_rng, draws = normal(0, 1, rng=rng).owner.outputs
    fn = function([], [draws], updates=[(rng, next_rng)])
    res1, res2 = fn(), fn()

    rng = np.random.default_rng(40)
    expected_res1, expected_res2 = rng.normal(0, 1), rng.normal(0, 1)
    np.testing.assert_allclose(res1, expected_res1)
    np.testing.assert_allclose(res2, expected_res2)
def test_zero_inputs():
    class ZeroInputRV(RandomVariable):
        signature = "->()"
        dtype = "floatX"
        name = "ZeroInputRV"

        @classmethod
        def rng_fn(cls, rng, size=None):
            return rng.random(size=size)

    zero_input_rv = ZeroInputRV()
    zero_input_xrv = XRV(zero_input_rv, core_dims=((), ()), extra_dims=["a"])

    rng = random_generator_type("rng")
    a_size = xtensor("a_size", dims=(), dtype=int)
    rv = zero_input_xrv(rng, a_size)
    assert rv.type.dims == ("a",)
    assert rv.type.shape == (None,)

    rng_test = np.random.default_rng(12345)
    a_size_val = np.array(5)
    np.testing.assert_allclose(
        rv.eval({rng: rng_test, a_size: a_size_val}),
        rng_test.random(size=(a_size_val,)),
    )
def test_output_dim_does_not_map_from_input_dims():
    class NewDimRV(RandomVariable):
        signature = "()->(p)"
        dtype = "floatX"
        name = "NewDimRV"

        @classmethod
        def rng_fn(cls, rng, n, size=None):
            r = np.stack([n, n + 1], axis=-1)
            if size is None:
                return r
            return np.broadcast_to(r, (*size, 2))

        def _supp_shape_from_params(self, dist_params, param_shapes=None):
            return (2,)

    new_dim_rv = NewDimRV()
    new_dim_xrv = XRV(new_dim_rv, core_dims=(((),), ("p",)), extra_dims=["a"])

    a_size = xtensor("a_size", dims=(), dtype=int)
    rv = new_dim_xrv(None, a_size, 1)
    assert rv.type.dims == ("a", "p")
    assert rv.type.shape == (None, 2)

    a_size_val = np.array(5)
    np.testing.assert_allclose(
        rv.eval({a_size: a_size_val}), np.broadcast_to((1, 2), (a_size_val, 2))
    )
def test_normal():
    rng = random_generator_type("rng")
    c_size = tensor("c_size", shape=(), dtype=int)
    mu = tensor("mu", shape=(3,))
    sigma = tensor("sigma", shape=(2,))

    mu_val = np.array([-10, 0.0, 10.0])
    sigma_val = np.array([1.0, 10.0])
    c_size_val = np.array(5)
    rng_val = np.random.default_rng(12345)

    c_size_xr = as_xtensor(c_size, name="c_size_xr")
    mu_xr = as_xtensor(mu, dims=("mu_dim",), name="mu_xr")
    sigma_xr = as_xtensor(sigma, dims=("sigma_dim",), name="sigma_xr")

    out = normal(mu_xr, sigma_xr, rng=rng)
    assert out.type.dims == ("mu_dim", "sigma_dim")
    assert out.type.shape == (3, 2)
    assert equal_computations(
        [lower_rewrite(out.values)],
        [rewrite_graph(ptr.normal(mu[:, None], sigma[None, :], rng=rng))],
    )

    out_eval = out.eval(
        {
            mu: mu_val,
            sigma: sigma_val,
            rng: rng_val,
        }
    )
    out_expected = deepcopy(rng_val).normal(mu_val[:, None], sigma_val[None, :])
    np.testing.assert_allclose(out_eval, out_expected)

    # Test with batch dimension
    out = normal(mu_xr, sigma_xr, extra_dims=dict(c_dim=c_size_xr), rng=rng)
    assert out.type.dims == ("c_dim", "mu_dim", "sigma_dim")
    assert out.type.shape == (None, 3, 2)
    lowered_size = (c_size, *broadcast_arrays(mu[:, None], sigma[None, :])[0].shape)
    assert equal_computations(
        [lower_rewrite(out.values)],
        [
            rewrite_graph(
                ptr.normal(mu[:, None], sigma[None, :], size=lowered_size, rng=rng)
            )
        ],
    )

    out_eval = out.eval(
        {
            mu: mu_val,
            sigma: sigma_val,
            c_size: c_size_val,
            rng: rng_val,
        }
    )
    out_expected = deepcopy(rng_val).normal(
        mu_val[:, None],
        sigma_val[None, :],
        size=(c_size_val, mu_val.shape[0], sigma_val.shape[0]),
    )
    np.testing.assert_allclose(out_eval, out_expected)

    # Test invalid core_dims
    with pytest.raises(
        ValueError,
        match=re.escape("normal needs 0 core_dims, but got 1"),
    ):
        normal(mu_xr, sigma_xr, core_dims=("a",), rng=rng)

    # Test invalid extra_dims (conflicting with existing batch dims)
    with pytest.raises(
        ValueError,
        match=re.escape(
            "Size dimensions ['mu_dim'] conflict with parameter dimensions. They should be unique."
        ),
    ):
        pxr.normal(mu_xr, sigma_xr, extra_dims=dict(mu_dim=c_size_xr), rng=rng)
def test_categorical():
    rng = random_generator_type("rng")
    p = tensor("p", shape=(2, 3))
    c_size = tensor("c", shape=(), dtype=int)

    p_xr = as_xtensor(p, dims=("p", "batch_dim"), name="p_xr")
    c_size_xr = as_xtensor(c_size, name="c_size_xr")

    out = categorical(p_xr, core_dims=("p",), rng=rng)
    assert out.type.dims == ("batch_dim",)
    assert out.type.shape == (3,)
    assert equal_computations(
        [lower_rewrite(out.values)], [ptr.categorical(p.T, rng=rng)]
    )

    np.testing.assert_allclose(
        out.eval(
            {
                p: np.array([[1.0, 0], [0, 1.0], [1.0, 0]]).T,
                rng: np.random.default_rng(),
            }
        ),
        np.array([0, 1, 0]),
    )

    out = categorical(
        p_xr, core_dims=("p",), extra_dims=dict(cp1=c_size_xr + 1, c=c_size_xr), rng=rng
    )
    assert out.type.dims == ("cp1", "c", "batch_dim")
    assert out.type.shape == (None, None, 3)
    assert equal_computations(
        [lower_rewrite(out.values)],
        [
            rewrite_graph(
                ptr.categorical(
                    p.T, size=(1 + c_size, c_size, p[0].shape.squeeze()), rng=rng
                )
            )
        ],
    )

    np.testing.assert_allclose(
        out.eval(
            {
                p: np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]).T,
                c_size: np.array(5),
                rng: np.random.default_rng(),
            }
        ),
        np.broadcast_to([0, 1, 0], shape=(6, 5, 3)),
    )

    # Test invalid core dims
    with pytest.raises(
        ValueError, match="categorical needs 1 core_dims to be specified"
    ):
        categorical(p_xr, rng=rng)

    with pytest.raises(
        ValueError,
        match=re.escape(
            "At least one core dim=('px',) missing from input p_xr with dims=('p', 'batch_dim')"
        ),
    ):
        categorical(p_xr, core_dims=("px",), rng=rng)
def test_multinomial():
    rng = random_generator_type("rng")
    n = tensor("n", shape=(2,))
    p = tensor("p", shape=(3, None))
    c_size = tensor("c", shape=(), dtype=int)

    n_xr = as_xtensor(n, dims=("a",), name="a_xr")
    p_xr = as_xtensor(p, dims=("p", "a"), name="p_xr")
    c_size_xr = as_xtensor(c_size, name="c_size_xr")
    a_size_xr = n_xr.sizes["a"]

    out = multinomial(n_xr, p_xr, core_dims=("p",), rng=rng)
    assert out.type.dims == ("a", "p")
    assert out.type.shape == (2, 3)
    assert equal_computations(
        [lower_rewrite(out.values)],
        [ptr.multinomial(n, p.T, size=None, rng=rng)],
    )

    # Test we can actually evaluate it
    np.testing.assert_allclose(
        out.eval(
            {
                n: [5, 10],
                p: np.array([[1.0, 0, 0], [0, 0, 1.0]]).T,
                rng: np.random.default_rng(),
            }
        ),
        np.array([[5, 0, 0], [0, 0, 10]]),
    )

    out = multinomial(
        n_xr, p_xr, core_dims=("p",), extra_dims=dict(c=c_size_xr), rng=rng
    )
    assert out.type.dims == ("c", "a", "p")
    assert equal_computations(
        [lower_rewrite(out.values)],
        [rewrite_graph(ptr.multinomial(n, p.T, size=(c_size, n.shape[0]), rng=rng))],
    )

    # Test we can actually evaluate it with extra_dims
    np.testing.assert_allclose(
        out.eval(
            {
                n: [5, 10],
                p: np.array([[1.0, 0, 0], [0, 0, 1.0]]).T,
                c_size: 5,
                rng: np.random.default_rng(),
            }
        ),
        np.broadcast_to(
            [[5, 0, 0], [0, 0, 10]],
            shape=(5, 2, 3),
        ),
    )

    # Test invalid core_dims
    with pytest.raises(
        ValueError, match="multinomial needs 1 core_dims to be specified"
    ):
        multinomial(n_xr, p_xr, rng=rng)

    with pytest.raises(ValueError, match="multinomial needs 1 core_dims, but got 2"):
        multinomial(n_xr, p_xr, core_dims=("p1", "p2"), rng=rng)

    with pytest.raises(
        ValueError, match=re.escape("Parameter a_xr has invalid core dimensions ['a']")
    ):
        # n cannot have a core dimension
        multinomial(n_xr, p_xr, core_dims=("a",), rng=rng)

    with pytest.raises(
        ValueError,
        match=re.escape(
            "At least one core dim=('px',) missing from input p_xr with dims=('p', 'a')"
        ),
    ):
        multinomial(n_xr, p_xr, core_dims=("px",), rng=rng)

    # Test invalid extra_dims
    with pytest.raises(
        ValueError,
        match=re.escape(
            "Size dimensions ['a'] conflict with parameter dimensions. They should be unique."
        ),
    ):
        multinomial(
            n_xr,
            p_xr,
            core_dims=("p",),
            extra_dims=dict(c=c_size_xr, a=a_size_xr),
            rng=rng,
        )
def test_multivariate_normal():
    rng = random_generator_type("rng")
    mu = tensor("mu", shape=(4, 2))
    cov = tensor("cov", shape=(2, 3, 2, 4))

    mu_xr = as_xtensor(mu, dims=("b1", "rows"), name="mu_xr")
    cov_xr = as_xtensor(cov, dims=("cols", "b2", "rows", "b1"), name="cov_xr")

    out = multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "cols"), rng=rng)
    assert out.type.dims == ("b1", "b2", "rows")
    assert out.type.shape == (4, 3, 2)
    assert equal_computations(
        [lower_rewrite(out.values)],
        [ptr.multivariate_normal(mu[:, None], cov.transpose(3, 1, 2, 0), rng=rng)],
    )

    # Order of core_dims doesn't matter
    out = multivariate_normal(mu_xr, cov_xr, core_dims=("cols", "rows"), rng=rng)
    assert out.type.dims == ("b1", "b2", "rows")
    assert out.type.shape == (4, 3, 2)
    assert equal_computations(
        [lower_rewrite(out.values)],
        [ptr.multivariate_normal(mu[:, None], cov.transpose(3, 1, 2, 0), rng=rng)],
    )

    # Test method
    out = multivariate_normal(
        mu_xr, cov_xr, core_dims=("rows", "cols"), rng=rng, method="svd"
    )
    assert equal_computations(
        [lower_rewrite(out.values)],
        [
            ptr.multivariate_normal(
                mu[:, None], cov.transpose(3, 1, 2, 0), rng=rng, method="svd"
            )
        ],
    )

    # Test invalid core_dims
    with pytest.raises(
        TypeError,
        match=re.escape(
            "multivariate_normal() missing 1 required keyword-only argument: 'core_dims'"
        ),
    ):
        multivariate_normal(mu_xr, cov_xr)

    with pytest.raises(
        ValueError, match="multivariate_normal requires 2 core_dims, got 3"
    ):
        multivariate_normal(mu_xr, cov_xr, core_dims=("b1", "rows", "cols"))

    with pytest.raises(
        ValueError, match=re.escape("Operand has repeated dims ('rows', 'rows')")
    ):
        multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "rows"))

    with pytest.raises(
        ValueError,
        match=re.escape("Parameter mu_xr has invalid core dimensions ['b1']"),
    ):
        # mu cannot have two core_dims
        multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "b1"))

    with pytest.raises(
        ValueError,
        match=re.escape(
            "At least one core dim=('rows', 'missing_cols') missing from input cov_xr with dims=('cols', 'b2', 'rows', 'b1')"
        ),
    ):
        # cov must have both core_dims
        multivariate_normal(mu_xr, cov_xr, core_dims=("rows", "missing_cols"))