提交 af88b456 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't implement default _supp_shape_from_params.

The errors raised by the default when it fails are rather cryptic. Also fix a bug in a helper function.
上级 ecd9c3b8
...@@ -5,10 +5,13 @@ import numpy as np ...@@ -5,10 +5,13 @@ import numpy as np
import scipy.stats as stats import scipy.stats as stats
import pytensor import pytensor
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import as_tensor_variable, arange
from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import broadcast_params from pytensor.tensor.random.utils import (
broadcast_params,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.random.var import ( from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable, RandomGeneratorSharedVariable,
RandomStateSharedVariable, RandomStateSharedVariable,
...@@ -855,6 +858,14 @@ class MvNormalRV(RandomVariable): ...@@ -855,6 +858,14 @@ class MvNormalRV(RandomVariable):
dtype = "floatX" dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Infer the support shape from the trailing dimension of the mean.

    The multivariate normal's support shape is the last dimension of its
    first parameter (the mean), so we delegate to the shared helper with
    ``ref_param_idx=0``.
    """
    return supp_shape_from_ref_param_shape(
        ref_param_idx=0,  # the mean parameter carries the support dim
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
    )
def __call__(self, mean=None, cov=None, size=None, **kwargs): def __call__(self, mean=None, cov=None, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution. r""" "Draw samples from a multivariate normal distribution.
...@@ -933,6 +944,14 @@ class DirichletRV(RandomVariable): ...@@ -933,6 +944,14 @@ class DirichletRV(RandomVariable):
dtype = "floatX" dtype = "floatX"
_print_name = ("Dirichlet", "\\operatorname{Dirichlet}") _print_name = ("Dirichlet", "\\operatorname{Dirichlet}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Infer the support shape from the trailing dimension of ``alphas``.

    A Dirichlet draw has the same trailing dimension as its concentration
    parameter (the first and only distribution parameter), so we delegate
    to the shared helper with ``ref_param_idx=0``.
    """
    return supp_shape_from_ref_param_shape(
        ref_param_idx=0,  # the alphas parameter carries the support dim
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
    )
def __call__(self, alphas, size=None, **kwargs): def __call__(self, alphas, size=None, **kwargs):
r"""Draw samples from a dirichlet distribution. r"""Draw samples from a dirichlet distribution.
...@@ -1776,9 +1795,12 @@ class MultinomialRV(RandomVariable): ...@@ -1776,9 +1795,12 @@ class MultinomialRV(RandomVariable):
""" """
return super().__call__(n, p, size=size, **kwargs) return super().__call__(n, p, size=size, **kwargs)
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): def _supp_shape_from_params(self, dist_params, param_shapes=None):
return default_supp_shape_from_params( return supp_shape_from_ref_param_shape(
self.ndim_supp, dist_params, rep_param_idx, param_shapes ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=1,
) )
@classmethod @classmethod
......
...@@ -24,64 +24,6 @@ from pytensor.tensor.type_other import NoneConst ...@@ -24,64 +24,6 @@ from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.var import TensorVariable from pytensor.tensor.var import TensorVariable
def default_supp_shape_from_params(
ndim_supp: int,
dist_params: Sequence[Variable],
rep_param_idx: int = 0,
param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None,
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
"""Infer the dimensions for the output of a `RandomVariable`.
This is a function that derives a random variable's support
shape/dimensions from one of its parameters.
XXX: It's not always possible to determine a random variable's support
shape from its parameters, so this function has fundamentally limited
applicability and must be replaced by custom logic in such cases.
XXX: This function is not expected to handle `ndim_supp = 0` (i.e.
scalars), since that is already definitively handled in the `Op` that
calls this.
TODO: Consider using `pytensor.compile.ops.shape_i` alongside `ShapeFeature`.
Parameters
----------
ndim_supp: int
Total number of dimensions for a single draw of the random variable
(e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
dist_params: list of `pytensor.graph.basic.Variable`
The distribution parameters.
rep_param_idx: int (optional)
The index of the distribution parameter to use as a reference
In other words, a parameter in `dist_param` with a shape corresponding
to the support's shape.
The default is the first parameter (i.e. the value 0).
param_shapes: list of tuple of `ScalarVariable` (optional)
Symbolic shapes for each distribution parameter. These will
be used in place of distribution parameter-generated shapes.
Results
-------
out: a tuple representing the support shape for a distribution with the
given `dist_params`.
"""
if ndim_supp <= 0:
raise ValueError("ndim_supp must be greater than 0")
if param_shapes is not None:
ref_param = param_shapes[rep_param_idx]
return (ref_param[-ndim_supp],)
else:
ref_param = dist_params[rep_param_idx]
if ref_param.ndim < ndim_supp:
raise ValueError(
"Reference parameter does not match the "
f"expected dimensions; {ref_param} has less than {ndim_supp} dim(s)."
)
return ref_param.shape[-ndim_supp:]
class RandomVariable(Op): class RandomVariable(Op):
"""An `Op` that produces a sample from a random variable. """An `Op` that produces a sample from a random variable.
...@@ -151,15 +93,29 @@ class RandomVariable(Op): ...@@ -151,15 +93,29 @@ class RandomVariable(Op):
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def _supp_shape_from_params(self, dist_params, **kwargs): def _supp_shape_from_params(self, dist_params, param_shapes=None):
"""Determine the support shape of a `RandomVariable`'s output given its parameters. """Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
This does *not* consider the extra dimensions added by the `size` parameter This does *not* consider the extra dimensions added by the `size` parameter
or independent (batched) parameters. or independent (batched) parameters.
Defaults to `param_supp_shape_fn`. When provided, `param_shapes` should be given preference over `[d.shape for d in dist_params]`,
as it will avoid redundancies in PyTensor shape inference.
Examples
--------
Common multivariate `RandomVariable`s derive their support shapes implicitly from the
last dimension of some of their parameters. For example `multivariate_normal` support shape
corresponds to the last dimension of the mean or covariance parameters, `support_shape=(mu.shape[-1])`.
For this case the helper `pytensor.tensor.random.utils.supp_shape_from_ref_param_shape` can be used.
Other variables have fixed support shape such as `support_shape=(2,)` or it is determined by the
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`.
""" """
return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs) raise NotImplementedError(
"`_supp_shape_from_params` must be implemented for multivariate RVs"
)
def rng_fn(self, rng, *args, **kwargs) -> Union[int, float, np.ndarray]: def rng_fn(self, rng, *args, **kwargs) -> Union[int, float, np.ndarray]:
"""Sample a numeric random variate.""" """Sample a numeric random variate."""
......
from collections.abc import Sequence
from functools import wraps from functools import wraps
from itertools import zip_longest from itertools import zip_longest
from types import ModuleType from types import ModuleType
from typing import TYPE_CHECKING, Literal, Optional, Union from typing import TYPE_CHECKING, Literal, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
from pytensor.compile.sharedvalue import shared from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
from pytensor.scalar import ScalarVariable
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant from pytensor.tensor.basic import as_tensor_variable, cast, constant
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_to
...@@ -285,3 +285,50 @@ class RandomStream: ...@@ -285,3 +285,50 @@ class RandomStream:
rng.default_update = new_rng rng.default_update = new_rng
return out return out
def supp_shape_from_ref_param_shape(
    *,
    ndim_supp: int,
    dist_params: Sequence["Variable"],
    param_shapes: Optional[Sequence[Tuple["ScalarVariable", ...]]] = None,
    ref_param_idx: int,
) -> Union["TensorVariable", Tuple["ScalarVariable", ...]]:
    """Extract the support shape of a multivariate `RandomVariable` from the shape of a reference parameter.

    Several multivariate `RandomVariable`\\s have a support shape determined by
    the last dimensions of a parameter. For example
    ``multivariate_normal(zeros(5, 3), eye(3))`` has a support shape of ``(3,)``
    that is determined by the last dimension of the mean or the covariance.

    Parameters
    ----------
    ndim_supp: int
        Support dimensionality of the `RandomVariable`
        (e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
    dist_params: list of `pytensor.graph.basic.Variable`
        The distribution parameters.
    param_shapes: list of tuple of `ScalarVariable` (optional)
        Symbolic shapes for each distribution parameter. These will
        be used in place of distribution parameter-generated shapes.
    ref_param_idx: int
        The index of the distribution parameter to use as a reference.

    Returns
    -------
    out: tuple
        Representing the support shape for a `RandomVariable` with the given
        `dist_params`.

    Raises
    ------
    ValueError
        If ``ndim_supp`` is not positive, or the reference parameter has
        fewer than ``ndim_supp`` dimensions.
    """
    if ndim_supp <= 0:
        raise ValueError("ndim_supp must be greater than 0")
    if param_shapes is not None:
        ref_param = param_shapes[ref_param_idx]
        # BUG FIX: return *all* of the last `ndim_supp` entries. The previous
        # form, `(ref_param[-ndim_supp],)`, produced a 1-tuple holding only
        # the entry at index -ndim_supp, silently dropping dimensions whenever
        # `ndim_supp > 1` — inconsistent with the `dist_params` branch below.
        return tuple(ref_param[i] for i in range(-ndim_supp, 0))
    else:
        ref_param = dist_params[ref_param_idx]
        if ref_param.ndim < ndim_supp:
            raise ValueError(
                "Reference parameter does not match the expected dimensions; "
                f"{ref_param} has less than {ndim_supp} dim(s)."
            )
        # Symbolic (or numpy) shape slice: the trailing `ndim_supp` dims.
        return ref_param.shape[-ndim_supp:]
...@@ -6,12 +6,7 @@ from pytensor import config, function ...@@ -6,12 +6,7 @@ from pytensor import config, function
from pytensor.gradient import NullTypeGradError, grad from pytensor.gradient import NullTypeGradError, grad
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor.math import eq from pytensor.tensor.math import eq
from pytensor.tensor.random.op import ( from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
RandomState,
RandomVariable,
default_rng,
default_supp_shape_from_params,
)
from pytensor.tensor.shape import specify_shape from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import all_dtypes, iscalar, tensor from pytensor.tensor.type import all_dtypes, iscalar, tensor
...@@ -22,29 +17,6 @@ def set_pytensor_flags(): ...@@ -22,29 +17,6 @@ def set_pytensor_flags():
yield yield
def test_default_supp_shape_from_params():
    """Exercise `default_supp_shape_from_params` on both branches and its error paths."""
    # ndim_supp must be strictly positive.
    with pytest.raises(ValueError, match="^ndim_supp*"):
        default_supp_shape_from_params(0, (np.array([1, 2]), 0))

    # Shape derived from the first parameter's value.
    shape = default_supp_shape_from_params(
        1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0
    )
    assert shape == (2,)

    # Shape derived from explicitly provided `param_shapes`.
    shape = default_supp_shape_from_params(
        1, (np.array([1, 2]), 0), param_shapes=((2,), ())
    )
    assert shape == (2,)

    # A reference parameter with too few dims is rejected.
    with pytest.raises(ValueError, match="^Reference parameter*"):
        default_supp_shape_from_params(1, (np.array(1),), rep_param_idx=0)

    # Multi-dimensional support taken from a non-default parameter index.
    shape = default_supp_shape_from_params(
        2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1
    )
    assert shape == (3, 4)
def test_RandomVariable_basics(): def test_RandomVariable_basics():
str_res = str( str_res = str(
RandomVariable( RandomVariable(
......
...@@ -4,7 +4,11 @@ import pytest ...@@ -4,7 +4,11 @@ import pytest
from pytensor import config, function from pytensor import config, function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor.random.utils import RandomStream, broadcast_params from pytensor.tensor.random.utils import (
RandomStream,
broadcast_params,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.type import matrix, tensor from pytensor.tensor.type import matrix, tensor
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -271,3 +275,41 @@ class TestSharedRandomStream: ...@@ -271,3 +275,41 @@ class TestSharedRandomStream:
su2[0].set_value(su1[0].get_value()) su2[0].set_value(su1[0].get_value())
np.testing.assert_array_almost_equal(f1(), f2(), decimal=6) np.testing.assert_array_almost_equal(f1(), f2(), decimal=6)
def test_supp_shape_from_ref_param_shape():
    """Exercise `supp_shape_from_ref_param_shape` on both branches and its error paths."""
    # ndim_supp must be strictly positive.
    with pytest.raises(ValueError, match="^ndim_supp*"):
        supp_shape_from_ref_param_shape(
            ndim_supp=0,
            dist_params=(np.array([1, 2]), 0),
            ref_param_idx=0,
        )

    # Shape derived from the reference parameter's value.
    shape = supp_shape_from_ref_param_shape(
        ndim_supp=1,
        dist_params=(np.array([1, 2]), np.eye(2)),
        ref_param_idx=0,
    )
    assert shape == (2,)

    # Shape derived from explicitly provided `param_shapes`.
    shape = supp_shape_from_ref_param_shape(
        ndim_supp=1,
        dist_params=(np.array([1, 2]), 0),
        param_shapes=((2,), ()),
        ref_param_idx=0,
    )
    assert shape == (2,)

    # A reference parameter with too few dims is rejected.
    with pytest.raises(ValueError, match="^Reference parameter*"):
        supp_shape_from_ref_param_shape(
            ndim_supp=1,
            dist_params=(np.array(1),),
            ref_param_idx=0,
        )

    # Multi-dimensional support taken from a non-default parameter index.
    shape = supp_shape_from_ref_param_shape(
        ndim_supp=2,
        dist_params=(np.array([1, 2]), np.ones((2, 3, 4))),
        ref_param_idx=1,
    )
    assert shape == (3, 4)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论