提交 af88b456 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't implement default _supp_shape_from_params.

The errors raised by the default when it fails are rather cryptic. Also fix a bug in the helper function.
上级 ecd9c3b8
......@@ -5,10 +5,13 @@ import numpy as np
import scipy.stats as stats
import pytensor
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params
from pytensor.tensor.basic import as_tensor_variable, arange
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.random.utils import (
broadcast_params,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
......@@ -855,6 +858,14 @@ class MvNormalRV(RandomVariable):
dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=0,
)
def __call__(self, mean=None, cov=None, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution.
......@@ -933,6 +944,14 @@ class DirichletRV(RandomVariable):
dtype = "floatX"
_print_name = ("Dirichlet", "\\operatorname{Dirichlet}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=0,
)
def __call__(self, alphas, size=None, **kwargs):
r"""Draw samples from a dirichlet distribution.
......@@ -1776,9 +1795,12 @@ class MultinomialRV(RandomVariable):
"""
return super().__call__(n, p, size=size, **kwargs)
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_supp_shape_from_params(
self.ndim_supp, dist_params, rep_param_idx, param_shapes
def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=1,
)
@classmethod
......
......@@ -24,64 +24,6 @@ from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorVariable
def default_supp_shape_from_params(
ndim_supp: int,
dist_params: Sequence[Variable],
rep_param_idx: int = 0,
param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None,
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Infer the dimensions for the output of a `RandomVariable`.
This is a function that derives a random variable's support
shape/dimensions from one of its parameters.
XXX: It's not always possible to determine a random variable's support
shape from its parameters, so this function has fundamentally limited
applicability and must be replaced by custom logic in such cases.
XXX: This function is not expected to handle `ndim_supp = 0` (i.e.
scalars), since that is already definitively handled in the `Op` that
calls this.
TODO: Consider using `pytensor.compile.ops.shape_i` alongside `ShapeFeature`.
Parameters
----------
ndim_supp: int
Total number of dimensions for a single draw of the random variable
(e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
dist_params: list of `pytensor.graph.basic.Variable`
The distribution parameters.
rep_param_idx: int (optional)
The index of the distribution parameter to use as a reference
In other words, a parameter in `dist_param` with a shape corresponding
to the support's shape.
The default is the first parameter (i.e. the value 0).
param_shapes: list of tuple of `ScalarVariable` (optional)
Symbolic shapes for each distribution parameter. These will
be used in place of distribution parameter-generated shapes.
Results
-------
out: a tuple representing the support shape for a distribution with the
given `dist_params`.
"""
if ndim_supp <= 0:
raise ValueError("ndim_supp must be greater than 0")
if param_shapes is not None:
ref_param = param_shapes[rep_param_idx]
return (ref_param[-ndim_supp],)
else:
ref_param = dist_params[rep_param_idx]
if ref_param.ndim < ndim_supp:
raise ValueError(
"Reference parameter does not match the "
f"expected dimensions; {ref_param} has less than {ndim_supp} dim(s)."
)
return ref_param.shape[-ndim_supp:]
class RandomVariable(Op):
"""An `Op` that produces a sample from a random variable.
......@@ -151,15 +93,29 @@ class RandomVariable(Op):
if self.inplace:
self.destroy_map = {0: [0]}
def _supp_shape_from_params(self, dist_params, **kwargs):
"""Determine the support shape of a `RandomVariable`'s output given its parameters.
def _supp_shape_from_params(self, dist_params, param_shapes=None):
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
This does *not* consider the extra dimensions added by the `size` parameter
or independent (batched) parameters.
Defaults to `param_supp_shape_fn`.
When provided, `param_shapes` should be given preference over `[d.shape for d in dist_params]`,
as it will avoid redundancies in PyTensor shape inference.
Examples
--------
Common multivariate `RandomVariable`s derive their support shapes implicitly from the
last dimension of some of their parameters. For example `multivariate_normal` support shape
corresponds to the last dimension of the mean or covariance parameters, `support_shape=(mu.shape[-1])`.
For this case the helper `pytensor.tensor.random.utils.supp_shape_from_ref_param_shape` can be used.
Other variables have fixed support shape such as `support_shape=(2,)` or it is determined by the
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`.
"""
return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs)
raise NotImplementedError(
"`_supp_shape_from_params` must be implemented for multivariate RVs"
)
def rng_fn(self, rng, *args, **kwargs) -> Union[int, float, np.ndarray]:
"""Sample a numeric random variate."""
......
from collections.abc import Sequence
from functools import wraps
from itertools import zip_longest
from types import ModuleType
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable
from pytensor.scalar import ScalarVariable
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant
from pytensor.tensor.extra_ops import broadcast_to
......@@ -285,3 +285,50 @@ class RandomStream:
rng.default_update = new_rng
return out
def supp_shape_from_ref_param_shape(
*,
ndim_supp: int,
dist_params: Sequence[Variable],
param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None,
ref_param_idx: int,
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Extract the support shape of a multivariate `RandomVariable` from the shape of a reference parameter.
Several multivariate `RandomVariable`s have a support shape determined by the last dimensions of a parameter.
For example `multivariate_normal(zeros(5, 3), eye(3)) has a support shape of (3,) that is determined by the
last dimension of the mean or the covariance.
Parameters
----------
ndim_supp: int
Support dimensionality of the `RandomVariable`.
(e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
dist_params: list of `pytensor.graph.basic.Variable`
The distribution parameters.
param_shapes: list of tuple of `ScalarVariable` (optional)
Symbolic shapes for each distribution parameter. These will
be used in place of distribution parameter-generated shapes.
ref_param_idx: int
The index of the distribution parameter to use as a reference
Returns
-------
out: tuple
Representing the support shape for a `RandomVariable` with the given `dist_params`.
"""
if ndim_supp <= 0:
raise ValueError("ndim_supp must be greater than 0")
if param_shapes is not None:
ref_param = param_shapes[ref_param_idx]
return (ref_param[-ndim_supp],)
else:
ref_param = dist_params[ref_param_idx]
if ref_param.ndim < ndim_supp:
raise ValueError(
"Reference parameter does not match the expected dimensions; "
f"{ref_param} has less than {ndim_supp} dim(s)."
)
return ref_param.shape[-ndim_supp:]
......@@ -6,12 +6,7 @@ from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad
from pytensor.raise_op import Assert
from pytensor.tensor.math import eq
from pytensor.tensor.random.op import (
RandomState,
RandomVariable,
default_rng,
default_supp_shape_from_params,
)
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import all_dtypes, iscalar, tensor
......@@ -22,29 +17,6 @@ def set_pytensor_flags():
yield
def test_default_supp_shape_from_params():
with pytest.raises(ValueError, match="^ndim_supp*"):
default_supp_shape_from_params(0, (np.array([1, 2]), 0))
res = default_supp_shape_from_params(
1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0
)
assert res == (2,)
res = default_supp_shape_from_params(
1, (np.array([1, 2]), 0), param_shapes=((2,), ())
)
assert res == (2,)
with pytest.raises(ValueError, match="^Reference parameter*"):
default_supp_shape_from_params(1, (np.array(1),), rep_param_idx=0)
res = default_supp_shape_from_params(
2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1
)
assert res == (3, 4)
def test_RandomVariable_basics():
str_res = str(
RandomVariable(
......
......@@ -4,7 +4,11 @@ import pytest
from pytensor import config, function
from pytensor.compile.mode import Mode
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.utils import RandomStream, broadcast_params
from pytensor.tensor.random.utils import (
RandomStream,
broadcast_params,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.type import matrix, tensor
from tests import unittest_tools as utt
......@@ -271,3 +275,41 @@ class TestSharedRandomStream:
su2[0].set_value(su1[0].get_value())
np.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
def test_supp_shape_from_ref_param_shape():
with pytest.raises(ValueError, match="^ndim_supp*"):
supp_shape_from_ref_param_shape(
ndim_supp=0,
dist_params=(np.array([1, 2]), 0),
ref_param_idx=0,
)
res = supp_shape_from_ref_param_shape(
ndim_supp=1,
dist_params=(np.array([1, 2]), np.eye(2)),
ref_param_idx=0,
)
assert res == (2,)
res = supp_shape_from_ref_param_shape(
ndim_supp=1,
dist_params=(np.array([1, 2]), 0),
param_shapes=((2,), ()),
ref_param_idx=0,
)
assert res == (2,)
with pytest.raises(ValueError, match="^Reference parameter*"):
supp_shape_from_ref_param_shape(
ndim_supp=1,
dist_params=(np.array(1),),
ref_param_idx=0,
)
res = supp_shape_from_ref_param_shape(
ndim_supp=2,
dist_params=(np.array([1, 2]), np.ones((2, 3, 4))),
ref_param_idx=1,
)
assert res == (3, 4)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论