Commit 7886cf83, authored by Ricardo Vieira, committed by Ricardo Vieira

Implement XTensorVariable version of RandomVariables

Parent 33d04c36
......@@ -1625,8 +1625,7 @@ class NegBinomialRV(ScipyRandomVariable):
return stats.nbinom.rvs(n, p, size=size, random_state=rng)
nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
nbinom = negative_binomial = NegBinomialRV()
class BetaBinomialRV(ScipyRandomVariable):
......@@ -1808,6 +1807,7 @@ class MultinomialRV(RandomVariable):
multinomial = MultinomialRV()
vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")
......
import abc
import warnings
from collections.abc import Sequence
from copy import deepcopy
......@@ -32,7 +33,20 @@ from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable
class RandomVariable(Op):
class RNGConsumerOp(Op):
    """Base class for Ops that consume random number generator (RNG) inputs.

    Subclasses must implement :meth:`update` so that compiled functions can
    thread the RNG state correctly between repeated evaluations.
    """

    @abc.abstractmethod
    def update(self, node: Apply) -> dict[Variable, Variable]:
        """Symbolic update expression for input RNG variables.

        Returns a dictionary mapping each input RNG variable of ``node`` to
        the symbolic expression required for correct updating of RNG
        variables in repeated function evaluations.
        """
        pass
class RandomVariable(RNGConsumerOp):
"""An `Op` that produces a sample from a random variable.
This is essentially `RandomFunction`, except that it removes the
......@@ -123,6 +137,9 @@ class RandomVariable(Op):
if self.inplace:
self.destroy_map = {0: [0]}
def update(self, node: Apply) -> dict[Variable, Variable]:
return {node.inputs[0]: node.outputs[0]}
def _supp_shape_from_params(self, dist_params, param_shapes=None):
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
......
......@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
return [new_var]
@register_infer_shape
@node_rewriter([Assert])
def local_remove_all_assert(fgraph, node):
r"""A rewrite that removes all `Assert`\s from a graph.
......@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
See the :ref:`unsafe` section.
"""
if not isinstance(node.op, Assert):
return
return [node.inputs[0]]
......
......@@ -9,6 +9,7 @@ from numpy import nditer
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.utils import hash_from_code
......@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
https://github.com/numpy/numpy/issues/28921
"""
return product(*(range(s) for s in shape))
def get_static_shape_from_size_variables(
    size_vars: Sequence[Variable],
) -> tuple[int | None, ...]:
    """Extract the statically known lengths from a sequence of size variables.

    Parameters
    ----------
    size_vars : Sequence[Variable]
        A sequence of variables representing the size of each dimension.

    Returns
    -------
    tuple[int | None, ...]
        The static length of each dimension, with None for entries whose
        length cannot be determined at compile time.
    """
    # Imported here to avoid a circular import at module load time.
    from pytensor.tensor.basic import get_scalar_constant_value

    static_lengths: list[int | None] = []
    for length in size_vars:
        try:
            static_lengths.append(int(get_scalar_constant_value(length)))
        except NotScalarConstantError:
            # Length is symbolic / data-dependent; not statically known.
            static_lengths.append(None)
    return tuple(static_lengths)
from collections.abc import Sequence
from functools import wraps
from typing import Literal
import pytensor.tensor.random.basic as ptr
from pytensor.graph.basic import Variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.xtensor import as_xtensor
from pytensor.xtensor.math import sqrt
from pytensor.xtensor.vectorization import XRV
def _as_xrv(
    core_op: RandomVariable,
    core_inps_dims_map: Sequence[Sequence[int]] | None = None,
    core_out_dims_map: Sequence[int] | None = None,
):
    """Build an XTensor-flavored constructor around a core RandomVariable.

    Parameters
    ----------
    core_op : RandomVariable
        The core random variable operation to wrap.
    core_inps_dims_map : Sequence[Sequence[int]] | None, optional
        For each input parameter, the positions of the user-specified core
        dimensions. This is used when lowering to a RandomVariable operation,
        to decide the ordering of the core dimensions for each input.
        If None, the core dimensions map positionally from left to right.
    core_out_dims_map : Sequence[int] | None, optional
        The positions of the user-specified core dimensions for the output
        variable, used when lowering to a RandomVariable operation to decide
        the ordering of the output core dimensions.
        If None, the core dimensions map positionally from left to right.
    """
    if core_inps_dims_map is None:
        # Default: core dims map positionally, left to right, per parameter.
        core_inps_dims_map = [tuple(range(ndim)) for ndim in core_op.ndims_params]
    if core_out_dims_map is None:
        # Default: core dims map positionally, left to right, for the output.
        core_out_dims_map = tuple(range(core_op.ndim_supp))

    # How many user-provided core dims the constructor must receive.
    n_core_dims_per_input = [len(dims_map) for dims_map in core_inps_dims_map]
    core_dims_needed = max(
        [*n_core_dims_per_input, len(core_out_dims_map)], default=0
    )

    @wraps(core_op)
    def xrv_constructor(
        *params,
        core_dims: Sequence[str] | str | None = None,
        extra_dims: dict[str, Variable] | None = None,
        rng: Variable | None = None,
    ):
        if core_dims is None:
            core_dims = ()
            if core_dims_needed:
                raise ValueError(
                    f"{core_op.name} needs {core_dims_needed} core_dims to be specified"
                )
        elif isinstance(core_dims, str):
            # Allow a single dim name to be given as a bare string.
            core_dims = (core_dims,)

        if len(core_dims) != core_dims_needed:
            raise ValueError(
                f"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
            )

        # Resolve the user-given core dims into per-input and output orderings.
        input_core_dims = tuple(
            tuple(core_dims[idx] for idx in dims_map)
            for dims_map in core_inps_dims_map
        )
        output_core_dims = tuple(core_dims[idx] for idx in core_out_dims_map)

        if extra_dims is None:
            extra_dims = {}

        xop = XRV(
            core_op,
            core_dims=(input_core_dims, output_core_dims),
            extra_dims=tuple(extra_dims.keys()),
        )
        return xop(rng, *extra_dims.values(), *params)

    return xrv_constructor
# XTensor-variable constructors for the standard random variables, built by
# wrapping the tensor-level Ops from pytensor.tensor.random.basic.
bernoulli = _as_xrv(ptr.bernoulli)
beta = _as_xrv(ptr.beta)
betabinom = _as_xrv(ptr.betabinom)
binomial = _as_xrv(ptr.binomial)
categorical = _as_xrv(ptr.categorical)
cauchy = _as_xrv(ptr.cauchy)
dirichlet = _as_xrv(ptr.dirichlet)
exponential = _as_xrv(ptr.exponential)
# NOTE(review): wraps the private ptr._gamma — presumably the parametrization
# intended for the public gamma constructor; confirm against
# pytensor.tensor.random.basic.
gamma = _as_xrv(ptr._gamma)
gengamma = _as_xrv(ptr.gengamma)
geometric = _as_xrv(ptr.geometric)
gumbel = _as_xrv(ptr.gumbel)
halfcauchy = _as_xrv(ptr.halfcauchy)
halfnormal = _as_xrv(ptr.halfnormal)
hypergeometric = _as_xrv(ptr.hypergeometric)
integers = _as_xrv(ptr.integers)
invgamma = _as_xrv(ptr.invgamma)
laplace = _as_xrv(ptr.laplace)
logistic = _as_xrv(ptr.logistic)
lognormal = _as_xrv(ptr.lognormal)
multinomial = _as_xrv(ptr.multinomial)
# Two aliases for the same negative-binomial constructor, mirroring
# pytensor.tensor.random.basic.
nbinom = negative_binomial = _as_xrv(ptr.negative_binomial)
normal = _as_xrv(ptr.normal)
pareto = _as_xrv(ptr.pareto)
poisson = _as_xrv(ptr.poisson)
t = _as_xrv(ptr.t)
triangular = _as_xrv(ptr.triangular)
truncexpon = _as_xrv(ptr.truncexpon)
uniform = _as_xrv(ptr.uniform)
vonmises = _as_xrv(ptr.vonmises)
wald = _as_xrv(ptr.wald)
weibull = _as_xrv(ptr.weibull)
def multivariate_normal(
    mean,
    cov,
    *,
    core_dims: Sequence[str],
    extra_dims=None,
    rng=None,
    method: Literal["cholesky", "svd", "eigh"] = "cholesky",
):
    """Multivariate normal random variable with dim-aware (xarray-like) semantics.

    Parameters
    ----------
    mean
        Mean vector; carries exactly one of the two ``core_dims``.
    cov
        Covariance matrix; carries both ``core_dims``.
    core_dims : Sequence[str]
        The two core dimension names. The one present in ``mean`` becomes the
        core dimension of the output.
    extra_dims : dict[str, Variable] | None, optional
        Extra batch dimensions of the output, mapping dim name to length.
    rng : Variable | None, optional
        Random generator input.
    method : {"cholesky", "svd", "eigh"}
        Factorization method used by the underlying MvNormalRV.

    Raises
    ------
    ValueError
        If ``core_dims`` is a string or does not contain exactly two names.
    """
    mean = as_xtensor(mean)
    # A bare 2-character string would pass the length check below
    # (len("ab") == 2) and only fail later with a confusing error from the
    # XRV constructor; reject it explicitly here.
    if isinstance(core_dims, str):
        raise ValueError(
            f"multivariate_normal requires 2 core_dims, got a single string: {core_dims!r}"
        )
    if len(core_dims) != 2:
        raise ValueError(
            f"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
        )

    # Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov
    # This will be the core dimension of the output
    if core_dims[0] not in mean.type.dims:
        core_dims = core_dims[::-1]

    xop = _as_xrv(ptr.MvNormalRV(method=method))
    return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng)
def standard_normal(
    extra_dims: dict[str, Variable] | None = None,
    rng: Variable | None = None,
):
    """Standard normal random variable (mean 0, standard deviation 1).

    Parameters
    ----------
    extra_dims : dict[str, Variable] | None, optional
        Extra batch dimensions of the output, mapping dim name to length.
    rng : Variable | None, optional
        Random generator input.
    """
    return normal(0, 1, extra_dims=extra_dims, rng=rng)
def chisquare(
    df,
    extra_dims: dict[str, Variable] | None = None,
    rng: Variable | None = None,
):
    """Chi-square random variable with ``df`` degrees of freedom.

    Implemented via the identity ChiSquare(df) == Gamma(shape=df / 2, scale=2).

    Parameters
    ----------
    df
        Degrees of freedom.
    extra_dims : dict[str, Variable] | None, optional
        Extra batch dimensions of the output, mapping dim name to length.
    rng : Variable | None, optional
        Random generator input.
    """
    return gamma(df / 2.0, 2.0, extra_dims=extra_dims, rng=rng)
def rayleigh(
    scale,
    extra_dims: dict[str, Variable] | None = None,
    rng: Variable | None = None,
):
    """Rayleigh random variable.

    Uses the identity Rayleigh(scale) == scale * sqrt(ChiSquare(df=2)).

    Parameters
    ----------
    scale
        Scale parameter of the distribution.
    extra_dims : dict[str, Variable] | None, optional
        Extra batch dimensions of the output, mapping dim name to length.
    rng : Variable | None, optional
        Random generator input.
    """
    # Multiplying by 0 and adding 2 propagates the dims of `scale` onto the
    # (constant) df argument, so the chi-square draw is batched like `scale`.
    df = scale * 0 + 2  # Poor man's broadcasting, to pass dimensions of scale to the RV
    return sqrt(chisquare(df, extra_dims=extra_dims, rng=rng)) * scale
from pytensor.graph import node_rewriter
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import compute_batch_shape
from pytensor.xtensor.basic import tensor_from_xtensor, xtensor_from_tensor
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
from pytensor.xtensor.vectorization import XBlockwise, XElemwise
from pytensor.xtensor.vectorization import XRV, XBlockwise, XElemwise
@register_lower_xtensor
......@@ -74,3 +75,49 @@ def lower_blockwise(fgraph, node):
for (tensor_out, old_out) in zip(tensor_outs, node.outputs, strict=True)
]
return new_outs
@register_lower_xtensor
@node_rewriter(tracks=[XRV])
def lower_rv(fgraph, node):
    """Lower an XRV node to its core RandomVariable operating on plain tensors.

    The XRV inputs are (rng, *extra_dim_lengths, *params); params are XTensors
    whose batch dims are identified by name. This rewrite aligns each param's
    batch dims to the output's batch-dim order, moves core dims to the right,
    calls the core RandomVariable, and wraps the result back into an XTensor.
    """
    op: XRV = node.op
    core_op = op.core_op

    _, old_out = node.outputs
    rng, *extra_dim_lengths_and_params = node.inputs
    extra_dim_lengths = extra_dim_lengths_and_params[: len(op.extra_dims)]
    params = extra_dim_lengths_and_params[len(op.extra_dims) :]

    # Output dims layout: (*extra_dims, *param_batch_dims, *core_dims).
    batch_ndim = old_out.type.ndim - len(op.core_dims[1])
    param_batch_dims = old_out.type.dims[len(op.extra_dims) : batch_ndim]

    # Convert params XTensors to Tensors, align batch dimensions and place core dimensions at the end
    tensor_params = []
    for inp, core_dims in zip(params, op.core_dims[0]):
        inp_dims = inp.type.dims
        # Align the batch dims of the input, and place the core dims on the right
        # ("x" inserts a broadcastable axis for batch dims this input lacks).
        batch_order = [
            inp_dims.index(batch_dim) if batch_dim in inp_dims else "x"
            for batch_dim in param_batch_dims
        ]
        core_order = [inp_dims.index(core_dim) for core_dim in core_dims]
        tensor_inp = tensor_from_xtensor(inp).dimshuffle(batch_order + core_order)
        tensor_params.append(tensor_inp)

    size = None
    if op.extra_dims:
        # RV size contains the lengths of all batch dimensions, including those coming from the parameters
        if tensor_params:
            param_batch_shape = tuple(
                compute_batch_shape(tensor_params, ndims_params=core_op.ndims_params)
            )
        else:
            param_batch_shape = ()
        size = [*extra_dim_lengths, *param_batch_shape]

    # RVs are their own core Op
    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs

    # Convert output Tensors to XTensors
    new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)
    return [new_next_rng, new_out]
......@@ -11,6 +11,7 @@ from pytensor.scalar import discrete_dtypes, upcast
from pytensor.tensor import as_tensor, get_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import integer_dtypes
from pytensor.tensor.utils import get_static_shape_from_size_variables
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
......@@ -131,14 +132,9 @@ class UnStack(XOp):
)
)
static_unstacked_lengths = [None] * len(unstacked_lengths)
for i, length in enumerate(unstacked_lengths):
try:
static_length = get_scalar_constant_value(length)
except NotScalarConstantError:
pass
else:
static_unstacked_lengths[i] = int(static_length)
static_unstacked_lengths = get_static_shape_from_size_variables(
unstacked_lengths
)
output = xtensor(
dtype=x.type.dtype,
......
from itertools import chain
import numpy as np
from pytensor import scalar as ps
from pytensor import shared
from pytensor.graph import Apply, Op
from pytensor.scalar import discrete_dtypes
from pytensor.tensor import tensor
from pytensor.tensor.random.op import RNGConsumerOp
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.utils import (
get_static_shape_from_size_variables,
)
from pytensor.xtensor.basic import XOp
from pytensor.xtensor.type import as_xtensor, xtensor
......@@ -108,3 +117,139 @@ class XBlockwise(XOp):
for core_out, core_out_dims in zip(core_node.outputs, core_outputs_dims)
]
return Apply(self, inputs, outputs)
class XRV(XOp, RNGConsumerOp):
    """Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.

    Xarray does not offer random generators, so this class implements a new API.

    It mostly works like a gufunc (or XBlockwise), which specifies core dimensions for inputs and output, and
    enforces dim-based broadcasting between inputs and output.

    It differs from XBlockwise in a couple of ways:
    1. It is restricted to one sample output
    2. It takes a random generator as the first input and returns the consumed generator as the first output.
    3. It has the concept of extra dimensions, which determine extra batch dimensions of the output, that are not
    implied by batch dimensions of the parameters.
    """

    # Outputs are (next_rng, sample); the sample is the default output.
    default_output = 1
    __props__ = ("core_op", "core_dims", "extra_dims")

    def __init__(
        self,
        core_op,
        core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
        extra_dims: tuple[str, ...],
    ):
        """Wrap ``core_op`` with per-input and output core dim names.

        Parameters
        ----------
        core_op
            The underlying RandomVariable Op.
        core_dims
            ``(input_core_dims, output_core_dims)`` where ``input_core_dims``
            holds one tuple of dim names per parameter.
        extra_dims
            Names of extra batch dimensions not implied by the parameters.

        Raises
        ------
        ValueError
            If any operand repeats a core dim, or ``extra_dims`` has duplicates.
        """
        super().__init__()
        self.core_op = core_op
        inps_core_dims, out_core_dims = core_dims
        for operand_dims in (*inps_core_dims, out_core_dims):
            if len(set(operand_dims)) != len(operand_dims):
                raise ValueError(f"Operand has repeated dims {operand_dims}")
        # Normalize to tuples so the Op is hashable via __props__.
        self.core_dims = (tuple(i for i in inps_core_dims), tuple(out_core_dims))
        if len(set(extra_dims)) != len(extra_dims):
            raise ValueError("size_dims must be unique")
        self.extra_dims = tuple(extra_dims)

    def update(self, node):
        # RNG input and update are the first input and output respectively
        return {node.inputs[0]: node.outputs[0]}

    def make_node(self, rng, *extra_dim_lengths_and_params):
        """Build the Apply node: inputs (rng, *dim_lengths, *params), outputs (next_rng, sample)."""
        if rng is None:
            rng = shared(np.random.default_rng())
        elif not isinstance(rng.type, RandomType):
            # NOTE(review): the message names RandomGeneratorType but the check
            # is against RandomType — confirm which name is intended.
            raise TypeError(
                "The type of rng should be an instance of RandomGeneratorType "
            )

        # Extra dim lengths come first; unwrap to their underlying tensors.
        extra_dim_lengths = [
            as_xtensor(dim_length).values
            for dim_length in extra_dim_lengths_and_params[: len(self.extra_dims)]
        ]
        if not all(
            (dim_length.type.ndim == 0 and dim_length.type.dtype in discrete_dtypes)
            for dim_length in extra_dim_lengths
        ):
            raise TypeError("All dimension lengths should be scalar discrete dtype.")

        params = [
            as_xtensor(param)
            for param in extra_dim_lengths_and_params[len(self.extra_dims) :]
        ]
        if len(params) != len(self.core_op.ndims_params):
            raise ValueError(
                f"Expected {len(self.core_op.ndims_params)} parameters + {len(self.extra_dims)} dim_lengths, "
                f"got {len(extra_dim_lengths_and_params)}"
            )

        param_core_dims, output_core_dims = self.core_dims
        input_core_dims_set = set(chain.from_iterable(param_core_dims))

        # Check parameters don't have core dimensions they shouldn't have
        for param, core_param_dims in zip(params, param_core_dims):
            if invalid_core_dims := (
                set(param.type.dims) - set(core_param_dims)
            ).intersection(input_core_dims_set):
                raise ValueError(
                    f"Parameter {param} has invalid core dimensions {sorted(invalid_core_dims)}"
                )

        # Static lengths for the extra (size) dims, where known.
        extra_dims_and_shape = dict(
            zip(
                self.extra_dims, get_static_shape_from_size_variables(extra_dim_lengths)
            )
        )
        params_dims_and_shape = combine_dims_and_shape(params)

        # Check that no parameter dims conflict with size dims
        if conflict_dims := set(extra_dims_and_shape).intersection(
            params_dims_and_shape
        ):
            raise ValueError(
                f"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
            )

        # Output batch dims: extra dims first, then parameter batch dims,
        # excluding anything that is a core dim of some input.
        batch_dims_and_shape = [
            (dim, dim_length)
            for dim, dim_length in (
                extra_dims_and_shape | params_dims_and_shape
            ).items()
            if dim not in input_core_dims_set
        ]
        if batch_dims_and_shape:
            batch_output_dims, batch_output_shape = zip(*batch_dims_and_shape)
        else:
            batch_output_dims, batch_output_shape = (), ()

        # Build dummy core-only inputs so the core op can infer the output
        # core shape/dtype without the batch dimensions.
        dummy_core_inputs = []
        for param, core_param_dims in zip(params, param_core_dims):
            try:
                core_static_shape = [
                    param.type.shape[param.type.dims.index(d)] for d in core_param_dims
                ]
            except ValueError:
                raise ValueError(
                    f"At least one core dim={core_param_dims} missing from input {param} with dims={param.type.dims}"
                )
            dummy_core_inputs.append(
                tensor(dtype=param.type.dtype, shape=core_static_shape)
            )
        core_node = self.core_op.make_node(rng, None, *dummy_core_inputs)

        if not len(core_node.outputs) == 2:
            raise NotImplementedError(
                "XRandomVariable only supports core ops with two outputs (rng, out)"
            )

        _, core_out = core_node.outputs
        out = xtensor(
            dtype=core_out.type.dtype,
            shape=batch_output_shape + core_out.type.shape,
            dims=batch_output_dims + output_core_dims,
        )

        return Apply(self, [rng, *extra_dim_lengths, *params], [rng.type(), out])
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论