提交 3e9c6a4f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Introduce signature instead of ndim_supp and ndims_params

上级 a576fa2c
...@@ -13,7 +13,6 @@ from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType ...@@ -13,7 +13,6 @@ from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
broadcast_params, broadcast_params,
normalize_size_param, normalize_size_param,
supp_shape_from_ref_param_shape,
) )
from pytensor.tensor.random.var import ( from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable, RandomGeneratorSharedVariable,
...@@ -91,8 +90,7 @@ class UniformRV(RandomVariable): ...@@ -91,8 +90,7 @@ class UniformRV(RandomVariable):
""" """
name = "uniform" name = "uniform"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Uniform", "\\operatorname{Uniform}") _print_name = ("Uniform", "\\operatorname{Uniform}")
...@@ -146,8 +144,7 @@ class TriangularRV(RandomVariable): ...@@ -146,8 +144,7 @@ class TriangularRV(RandomVariable):
""" """
name = "triangular" name = "triangular"
ndim_supp = 0 signature = "(),(),()->()"
ndims_params = [0, 0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Triangular", "\\operatorname{Triangular}") _print_name = ("Triangular", "\\operatorname{Triangular}")
...@@ -202,8 +199,7 @@ class BetaRV(RandomVariable): ...@@ -202,8 +199,7 @@ class BetaRV(RandomVariable):
""" """
name = "beta" name = "beta"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Beta", "\\operatorname{Beta}") _print_name = ("Beta", "\\operatorname{Beta}")
...@@ -249,8 +245,7 @@ class NormalRV(RandomVariable): ...@@ -249,8 +245,7 @@ class NormalRV(RandomVariable):
""" """
name = "normal" name = "normal"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Normal", "\\operatorname{Normal}") _print_name = ("Normal", "\\operatorname{Normal}")
...@@ -316,8 +311,7 @@ class HalfNormalRV(ScipyRandomVariable): ...@@ -316,8 +311,7 @@ class HalfNormalRV(ScipyRandomVariable):
""" """
name = "halfnormal" name = "halfnormal"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("HalfNormal", "\\operatorname{HalfNormal}") _print_name = ("HalfNormal", "\\operatorname{HalfNormal}")
...@@ -382,8 +376,7 @@ class LogNormalRV(RandomVariable): ...@@ -382,8 +376,7 @@ class LogNormalRV(RandomVariable):
""" """
name = "lognormal" name = "lognormal"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("LogNormal", "\\operatorname{LogNormal}") _print_name = ("LogNormal", "\\operatorname{LogNormal}")
...@@ -434,8 +427,7 @@ class GammaRV(RandomVariable): ...@@ -434,8 +427,7 @@ class GammaRV(RandomVariable):
""" """
name = "gamma" name = "gamma"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Gamma", "\\operatorname{Gamma}") _print_name = ("Gamma", "\\operatorname{Gamma}")
...@@ -567,8 +559,7 @@ class ParetoRV(ScipyRandomVariable): ...@@ -567,8 +559,7 @@ class ParetoRV(ScipyRandomVariable):
""" """
name = "pareto" name = "pareto"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Pareto", "\\operatorname{Pareto}") _print_name = ("Pareto", "\\operatorname{Pareto}")
...@@ -618,8 +609,7 @@ class GumbelRV(ScipyRandomVariable): ...@@ -618,8 +609,7 @@ class GumbelRV(ScipyRandomVariable):
""" """
name = "gumbel" name = "gumbel"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Gumbel", "\\operatorname{Gumbel}") _print_name = ("Gumbel", "\\operatorname{Gumbel}")
...@@ -680,8 +670,7 @@ class ExponentialRV(RandomVariable): ...@@ -680,8 +670,7 @@ class ExponentialRV(RandomVariable):
""" """
name = "exponential" name = "exponential"
ndim_supp = 0 signature = "()->()"
ndims_params = [0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Exponential", "\\operatorname{Exponential}") _print_name = ("Exponential", "\\operatorname{Exponential}")
...@@ -724,8 +713,7 @@ class WeibullRV(RandomVariable): ...@@ -724,8 +713,7 @@ class WeibullRV(RandomVariable):
""" """
name = "weibull" name = "weibull"
ndim_supp = 0 signature = "()->()"
ndims_params = [0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Weibull", "\\operatorname{Weibull}") _print_name = ("Weibull", "\\operatorname{Weibull}")
...@@ -769,8 +757,7 @@ class LogisticRV(RandomVariable): ...@@ -769,8 +757,7 @@ class LogisticRV(RandomVariable):
""" """
name = "logistic" name = "logistic"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Logistic", "\\operatorname{Logistic}") _print_name = ("Logistic", "\\operatorname{Logistic}")
...@@ -818,8 +805,7 @@ class VonMisesRV(RandomVariable): ...@@ -818,8 +805,7 @@ class VonMisesRV(RandomVariable):
""" """
name = "vonmises" name = "vonmises"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("VonMises", "\\operatorname{VonMises}") _print_name = ("VonMises", "\\operatorname{VonMises}")
...@@ -886,19 +872,10 @@ class MvNormalRV(RandomVariable): ...@@ -886,19 +872,10 @@ class MvNormalRV(RandomVariable):
""" """
name = "multivariate_normal" name = "multivariate_normal"
ndim_supp = 1 signature = "(n),(n,n)->(n)"
ndims_params = [1, 2]
dtype = "floatX" dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}") _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=0,
)
def __call__(self, mean=None, cov=None, size=None, **kwargs): def __call__(self, mean=None, cov=None, size=None, **kwargs):
r"""Draw samples from a multivariate normal distribution. r"""Draw samples from a multivariate normal distribution.
...@@ -942,7 +919,7 @@ class MvNormalRV(RandomVariable): ...@@ -942,7 +919,7 @@ class MvNormalRV(RandomVariable):
mean = np.broadcast_to(mean, size + mean.shape[-1:]) mean = np.broadcast_to(mean, size + mean.shape[-1:])
cov = np.broadcast_to(cov, size + cov.shape[-2:]) cov = np.broadcast_to(cov, size + cov.shape[-2:])
else: else:
mean, cov = broadcast_params([mean, cov], cls.ndims_params) mean, cov = broadcast_params([mean, cov], [1, 2])
res = np.empty(mean.shape) res = np.empty(mean.shape)
for idx in np.ndindex(mean.shape[:-1]): for idx in np.ndindex(mean.shape[:-1]):
...@@ -973,19 +950,10 @@ class DirichletRV(RandomVariable): ...@@ -973,19 +950,10 @@ class DirichletRV(RandomVariable):
""" """
name = "dirichlet" name = "dirichlet"
ndim_supp = 1 signature = "(a)->(a)"
ndims_params = [1]
dtype = "floatX" dtype = "floatX"
_print_name = ("Dirichlet", "\\operatorname{Dirichlet}") _print_name = ("Dirichlet", "\\operatorname{Dirichlet}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=0,
)
def __call__(self, alphas, size=None, **kwargs): def __call__(self, alphas, size=None, **kwargs):
r"""Draw samples from a dirichlet distribution. r"""Draw samples from a dirichlet distribution.
...@@ -1047,8 +1015,7 @@ class PoissonRV(RandomVariable): ...@@ -1047,8 +1015,7 @@ class PoissonRV(RandomVariable):
""" """
name = "poisson" name = "poisson"
ndim_supp = 0 signature = "()->()"
ndims_params = [0]
dtype = "int64" dtype = "int64"
_print_name = ("Poisson", "\\operatorname{Poisson}") _print_name = ("Poisson", "\\operatorname{Poisson}")
...@@ -1093,8 +1060,7 @@ class GeometricRV(RandomVariable): ...@@ -1093,8 +1060,7 @@ class GeometricRV(RandomVariable):
""" """
name = "geometric" name = "geometric"
ndim_supp = 0 signature = "()->()"
ndims_params = [0]
dtype = "int64" dtype = "int64"
_print_name = ("Geometric", "\\operatorname{Geometric}") _print_name = ("Geometric", "\\operatorname{Geometric}")
...@@ -1136,8 +1102,7 @@ class HyperGeometricRV(RandomVariable): ...@@ -1136,8 +1102,7 @@ class HyperGeometricRV(RandomVariable):
""" """
name = "hypergeometric" name = "hypergeometric"
ndim_supp = 0 signature = "(),(),()->()"
ndims_params = [0, 0, 0]
dtype = "int64" dtype = "int64"
_print_name = ("HyperGeometric", "\\operatorname{HyperGeometric}") _print_name = ("HyperGeometric", "\\operatorname{HyperGeometric}")
...@@ -1185,8 +1150,7 @@ class CauchyRV(ScipyRandomVariable): ...@@ -1185,8 +1150,7 @@ class CauchyRV(ScipyRandomVariable):
""" """
name = "cauchy" name = "cauchy"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Cauchy", "\\operatorname{Cauchy}") _print_name = ("Cauchy", "\\operatorname{Cauchy}")
...@@ -1236,8 +1200,7 @@ class HalfCauchyRV(ScipyRandomVariable): ...@@ -1236,8 +1200,7 @@ class HalfCauchyRV(ScipyRandomVariable):
""" """
name = "halfcauchy" name = "halfcauchy"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("HalfCauchy", "\\operatorname{HalfCauchy}") _print_name = ("HalfCauchy", "\\operatorname{HalfCauchy}")
...@@ -1291,8 +1254,7 @@ class InvGammaRV(ScipyRandomVariable): ...@@ -1291,8 +1254,7 @@ class InvGammaRV(ScipyRandomVariable):
""" """
name = "invgamma" name = "invgamma"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("InverseGamma", "\\operatorname{InverseGamma}") _print_name = ("InverseGamma", "\\operatorname{InverseGamma}")
...@@ -1342,8 +1304,7 @@ class WaldRV(RandomVariable): ...@@ -1342,8 +1304,7 @@ class WaldRV(RandomVariable):
""" """
name = "wald" name = "wald"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name_ = ("Wald", "\\operatorname{Wald}") _print_name_ = ("Wald", "\\operatorname{Wald}")
...@@ -1390,8 +1351,7 @@ class TruncExponentialRV(ScipyRandomVariable): ...@@ -1390,8 +1351,7 @@ class TruncExponentialRV(ScipyRandomVariable):
""" """
name = "truncexpon" name = "truncexpon"
ndim_supp = 0 signature = "(),(),()->()"
ndims_params = [0, 0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("TruncatedExponential", "\\operatorname{TruncatedExponential}") _print_name = ("TruncatedExponential", "\\operatorname{TruncatedExponential}")
...@@ -1446,8 +1406,7 @@ class StudentTRV(ScipyRandomVariable): ...@@ -1446,8 +1406,7 @@ class StudentTRV(ScipyRandomVariable):
""" """
name = "t" name = "t"
ndim_supp = 0 signature = "(),(),()->()"
ndims_params = [0, 0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("StudentT", "\\operatorname{StudentT}") _print_name = ("StudentT", "\\operatorname{StudentT}")
...@@ -1506,8 +1465,7 @@ class BernoulliRV(ScipyRandomVariable): ...@@ -1506,8 +1465,7 @@ class BernoulliRV(ScipyRandomVariable):
""" """
name = "bernoulli" name = "bernoulli"
ndim_supp = 0 signature = "()->()"
ndims_params = [0]
dtype = "int64" dtype = "int64"
_print_name = ("Bernoulli", "\\operatorname{Bernoulli}") _print_name = ("Bernoulli", "\\operatorname{Bernoulli}")
...@@ -1554,8 +1512,7 @@ class LaplaceRV(RandomVariable): ...@@ -1554,8 +1512,7 @@ class LaplaceRV(RandomVariable):
""" """
name = "laplace" name = "laplace"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("Laplace", "\\operatorname{Laplace}") _print_name = ("Laplace", "\\operatorname{Laplace}")
...@@ -1601,8 +1558,7 @@ class BinomialRV(RandomVariable): ...@@ -1601,8 +1558,7 @@ class BinomialRV(RandomVariable):
""" """
name = "binomial" name = "binomial"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "int64" dtype = "int64"
_print_name = ("Binomial", "\\operatorname{Binomial}") _print_name = ("Binomial", "\\operatorname{Binomial}")
...@@ -1645,9 +1601,8 @@ class NegBinomialRV(ScipyRandomVariable): ...@@ -1645,9 +1601,8 @@ class NegBinomialRV(ScipyRandomVariable):
""" """
name = "nbinom" name = "negative_binomial"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "int64" dtype = "int64"
_print_name = ("NegativeBinomial", "\\operatorname{NegativeBinomial}") _print_name = ("NegativeBinomial", "\\operatorname{NegativeBinomial}")
...@@ -1702,8 +1657,7 @@ class BetaBinomialRV(ScipyRandomVariable): ...@@ -1702,8 +1657,7 @@ class BetaBinomialRV(ScipyRandomVariable):
""" """
name = "beta_binomial" name = "beta_binomial"
ndim_supp = 0 signature = "(),(),()->()"
ndims_params = [0, 0, 0]
dtype = "int64" dtype = "int64"
_print_name = ("BetaBinomial", "\\operatorname{BetaBinomial}") _print_name = ("BetaBinomial", "\\operatorname{BetaBinomial}")
...@@ -1754,8 +1708,7 @@ class GenGammaRV(ScipyRandomVariable): ...@@ -1754,8 +1708,7 @@ class GenGammaRV(ScipyRandomVariable):
""" """
name = "gengamma" name = "gengamma"
ndim_supp = 0 signature = "(),(),()->()"
ndims_params = [0, 0, 0]
dtype = "floatX" dtype = "floatX"
_print_name = ("GeneralizedGamma", "\\operatorname{GeneralizedGamma}") _print_name = ("GeneralizedGamma", "\\operatorname{GeneralizedGamma}")
...@@ -1817,8 +1770,7 @@ class MultinomialRV(RandomVariable): ...@@ -1817,8 +1770,7 @@ class MultinomialRV(RandomVariable):
""" """
name = "multinomial" name = "multinomial"
ndim_supp = 1 signature = "(),(p)->(p)"
ndims_params = [0, 1]
dtype = "int64" dtype = "int64"
_print_name = ("Multinomial", "\\operatorname{Multinomial}") _print_name = ("Multinomial", "\\operatorname{Multinomial}")
...@@ -1845,14 +1797,6 @@ class MultinomialRV(RandomVariable): ...@@ -1845,14 +1797,6 @@ class MultinomialRV(RandomVariable):
""" """
return super().__call__(n, p, size=size, **kwargs) return super().__call__(n, p, size=size, **kwargs)
def _supp_shape_from_params(self, dist_params, param_shapes=None):
return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
dist_params=dist_params,
param_shapes=param_shapes,
ref_param_idx=1,
)
@classmethod @classmethod
def rng_fn(cls, rng, n, p, size): def rng_fn(cls, rng, n, p, size):
if n.ndim > 0 or p.ndim > 1: if n.ndim > 0 or p.ndim > 1:
...@@ -1862,7 +1806,7 @@ class MultinomialRV(RandomVariable): ...@@ -1862,7 +1806,7 @@ class MultinomialRV(RandomVariable):
n = np.broadcast_to(n, size) n = np.broadcast_to(n, size)
p = np.broadcast_to(p, size + p.shape[-1:]) p = np.broadcast_to(p, size + p.shape[-1:])
else: else:
n, p = broadcast_params([n, p], cls.ndims_params) n, p = broadcast_params([n, p], [0, 1])
res = np.empty(p.shape, dtype=cls.dtype) res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]): for idx in np.ndindex(p.shape[:-1]):
...@@ -1892,8 +1836,7 @@ class CategoricalRV(RandomVariable): ...@@ -1892,8 +1836,7 @@ class CategoricalRV(RandomVariable):
""" """
name = "categorical" name = "categorical"
ndim_supp = 0 signature = "(p)->()"
ndims_params = [1]
dtype = "int64" dtype = "int64"
_print_name = ("Categorical", "\\operatorname{Categorical}") _print_name = ("Categorical", "\\operatorname{Categorical}")
...@@ -1948,8 +1891,7 @@ class RandIntRV(RandomVariable): ...@@ -1948,8 +1891,7 @@ class RandIntRV(RandomVariable):
""" """
name = "randint" name = "randint"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "int64" dtype = "int64"
_print_name = ("randint", "\\operatorname{randint}") _print_name = ("randint", "\\operatorname{randint}")
...@@ -2001,8 +1943,7 @@ class IntegersRV(RandomVariable): ...@@ -2001,8 +1943,7 @@ class IntegersRV(RandomVariable):
""" """
name = "integers" name = "integers"
ndim_supp = 0 signature = "(),()->()"
ndims_params = [0, 0]
dtype = "int64" dtype = "int64"
_print_name = ("integers", "\\operatorname{integers}") _print_name = ("integers", "\\operatorname{integers}")
...@@ -2174,17 +2115,23 @@ def choice(a, size=None, replace=True, p=None, rng=None): ...@@ -2174,17 +2115,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
a_ndim = a.type.ndim a_ndim = a.type.ndim
dtype = a.type.dtype dtype = a.type.dtype
a_dims = [f"a{i}" for i in range(a_ndim)]
a_sig = ",".join(a_dims)
idx_dims = [f"s{i}" for i in range(core_shape_length)]
if a_ndim == 0:
p_sig = "a"
out_dims = idx_dims
else:
p_sig = a_dims[0]
out_dims = idx_dims + a_dims[1:]
out_sig = ",".join(out_dims)
if p is None: if p is None:
ndims_params = [a_ndim, 1] signature = f"({a_sig}),({core_shape_length})->({out_sig})"
else: else:
ndims_params = [a_ndim, 1, 1] signature = f"({a_sig}),({p_sig}),({core_shape_length})->({out_sig})"
ndim_supp = max(a_ndim - 1, 0) + core_shape_length
op = ChoiceWithoutReplacement( op = ChoiceWithoutReplacement(signature=signature, dtype=dtype)
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
)
params = (a, core_shape) if p is None else (a, p, core_shape) params = (a, core_shape) if p is None else (a, p, core_shape)
return op(*params, size=None, rng=rng) return op(*params, size=None, rng=rng)
...@@ -2247,10 +2194,12 @@ def permutation(x, **kwargs): ...@@ -2247,10 +2194,12 @@ def permutation(x, **kwargs):
x_dtype = x.type.dtype x_dtype = x.type.dtype
# PermutationRV has a signature () -> (x) if x is a scalar # PermutationRV has a signature () -> (x) if x is a scalar
# and (*x) -> (*x) otherwise, with as many entries as the dimensionality of x # and (*x) -> (*x) otherwise, with as many entries as the dimensionality of x
ndim_supp = max(x_ndim, 1) if x_ndim == 0:
return PermutationRV(ndim_supp=ndim_supp, ndims_params=[x_ndim], dtype=x_dtype)( signature = "()->(x)"
x, **kwargs else:
) arg_sig = ",".join(f"x{i}" for i in range(x_ndim))
signature = f"({arg_sig})->({arg_sig})"
return PermutationRV(signature=signature, dtype=x_dtype)(x, **kwargs)
__all__ = [ __all__ = [
......
import warnings
from collections.abc import Sequence from collections.abc import Sequence
from copy import copy from copy import copy
from typing import cast from typing import cast
...@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import ( ...@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import (
from pytensor.tensor.shape import shape_tuple from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -42,7 +44,7 @@ class RandomVariable(Op): ...@@ -42,7 +44,7 @@ class RandomVariable(Op):
_output_type_depends_on_input_value = True _output_type_depends_on_input_value = True
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace") __props__ = ("name", "signature", "dtype", "inplace")
default_output = 1 default_output = 1
def __init__( def __init__(
...@@ -50,8 +52,9 @@ class RandomVariable(Op): ...@@ -50,8 +52,9 @@ class RandomVariable(Op):
name=None, name=None,
ndim_supp=None, ndim_supp=None,
ndims_params=None, ndims_params=None,
dtype=None, dtype: str | None = None,
inplace=None, inplace=None,
signature: str | None = None,
): ):
"""Create a random variable `Op`. """Create a random variable `Op`.
...@@ -59,44 +62,63 @@ class RandomVariable(Op): ...@@ -59,44 +62,63 @@ class RandomVariable(Op):
---------- ----------
name: str name: str
The `Op`'s display name. The `Op`'s display name.
ndim_supp: int signature: str
Total number of dimensions for a single draw of the random variable Numpy-like vectorized signature of the random variable.
(e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
ndims_params: list of int
Number of dimensions for each distribution parameter when the
parameters only specify a single drawn of the random variable
(e.g. a multivariate normal's mean is 1D and covariance is 2D, so
``ndims_params = [1, 2]``).
dtype: str (optional) dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is The dtype of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when ``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called. `RandomVariable.make_node` is called.
inplace: boolean (optional) inplace: boolean (optional)
Determine whether or not the underlying rng state is updated Determine whether the underlying rng state is mutated or copied.
in-place or not (i.e. copied).
""" """
super().__init__() super().__init__()
self.name = name or getattr(self, "name") self.name = name or getattr(self, "name")
self.ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp") ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp", None)
) )
self.ndims_params = ( if ndim_supp is not None:
ndims_params if ndims_params is not None else getattr(self, "ndims_params") warnings.warn(
"ndim_supp is deprecated. Provide signature instead.", FutureWarning
)
self.ndim_supp = ndim_supp
ndims_params = (
ndims_params
if ndims_params is not None
else getattr(self, "ndims_params", None)
) )
if ndims_params is not None:
warnings.warn(
"ndims_params is deprecated. Provide signature instead.", FutureWarning
)
if not isinstance(ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(ndims_params)
self.signature = signature or getattr(self, "signature", None)
if self.signature is not None:
# Assume a single output. Several methods need to be updated to handle multiple outputs.
self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature)
self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig]
self.ndim_supp = len(self.output_sig)
else:
if (
getattr(self, "ndim_supp", None) is None
or getattr(self, "ndims_params", None) is None
):
raise ValueError("signature must be provided")
else:
self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
self.dtype = dtype or getattr(self, "dtype", None) self.dtype = dtype or getattr(self, "dtype", None)
self.inplace = ( self.inplace = (
inplace if inplace is not None else getattr(self, "inplace", False) inplace if inplace is not None else getattr(self, "inplace", False)
) )
if not isinstance(self.ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(self.ndims_params)
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
...@@ -120,8 +142,31 @@ class RandomVariable(Op): ...@@ -120,8 +142,31 @@ class RandomVariable(Op):
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`, values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`. might have `support_shape=(steps,)`.
""" """
if self.signature is not None:
# Signature could indicate fixed numerical shapes
# As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html
output_sig = self.output_sig
core_out_shape = {
dim: int(dim) if str.isnumeric(dim) else None for dim in self.output_sig
}
# Try to infer missing support dims from signature of params
for param, param_sig, ndim_params in zip(
dist_params, self.inputs_sig, self.ndims_params
):
if ndim_params == 0:
continue
for param_dim, dim in zip(param.shape[-ndim_params:], param_sig):
if dim in core_out_shape and core_out_shape[dim] is None:
core_out_shape[dim] = param_dim
if all(dim is not None for dim in core_out_shape.values()):
# We have all we need
return [core_out_shape[dim] for dim in output_sig]
raise NotImplementedError( raise NotImplementedError(
"`_supp_shape_from_params` must be implemented for multivariate RVs" "`_supp_shape_from_params` must be implemented for multivariate RVs "
"when signature is not sufficient to infer the support shape"
) )
def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray: def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray:
...@@ -129,7 +174,24 @@ class RandomVariable(Op): ...@@ -129,7 +174,24 @@ class RandomVariable(Op):
return getattr(rng, self.name)(*args, **kwargs) return getattr(rng, self.name)(*args, **kwargs)
def __str__(self): def __str__(self):
props_str = ", ".join(f"{getattr(self, prop)}" for prop in self.__props__[1:]) # Only show signature from core props
if signature := self.signature:
# inp, out = signature.split("->")
# extended_signature = f"[rng],[size],{inp}->[rng],{out}"
# core_props = [extended_signature]
core_props = [f'"{signature}"']
else:
# For backwards compat # For backwards compat
core_props = [str(self.ndim_supp), str(self.ndims_params)]
# Add any extra props that the subclass may have
extra_props = [
str(getattr(self, prop))
for prop in self.__props__
if prop not in RandomVariable.__props__
]
props_str = ", ".join(core_props + extra_props)
return f"{self.name}_rv{{{props_str}}}" return f"{self.name}_rv{{{props_str}}}"
def _infer_shape( def _infer_shape(
...@@ -298,11 +360,11 @@ class RandomVariable(Op): ...@@ -298,11 +360,11 @@ class RandomVariable(Op):
dtype_idx = constant(all_dtypes.index(dtype), dtype="int64") dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
else: else:
dtype_idx = constant(dtype, dtype="int64") dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]
outtype = TensorType(dtype=dtype, shape=static_shape) dtype = all_dtypes[dtype_idx.data]
out_var = outtype()
inputs = (rng, size, dtype_idx, *dist_params) inputs = (rng, size, dtype_idx, *dist_params)
out_var = TensorType(dtype=dtype, shape=static_shape)()
outputs = (rng.type(), out_var) outputs = (rng.type(), out_var)
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -395,9 +457,8 @@ def vectorize_random_variable( ...@@ -395,9 +457,8 @@ def vectorize_random_variable(
# We extend it to accommodate the new input batch dimensions. # We extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values # Otherwise, we assume the new size already has the right values
# Need to make parameters implicit broadcasting explicit original_dist_params = op.dist_params(node)
original_dist_params = node.inputs[3:] old_size = op.size_param(node)
old_size = node.inputs[1]
len_old_size = get_vector_length(old_size) len_old_size = get_vector_length(old_size)
original_expanded_dist_params = explicit_expand_dims( original_expanded_dist_params = explicit_expand_dims(
......
import re
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
...@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): ...@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
a_vector_param = arange(a_scalar_param) a_vector_param = arange(a_scalar_param)
new_props_dict = op._props_dict().copy() new_props_dict = op._props_dict().copy()
new_ndims_params = list(op.ndims_params) # Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
new_ndims_params[0] += 1 # I.e., we substitute the first `()` by `(a)`
new_props_dict["ndims_params"] = new_ndims_params new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1)
new_op = type(op)(**new_props_dict) new_op = type(op)(**new_props_dict)
return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs
......
...@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params): ...@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
def explicit_expand_dims( def explicit_expand_dims(
params: Sequence[TensorVariable], params: Sequence[TensorVariable],
ndim_params: tuple[int], ndim_params: Sequence[int],
size_length: int = 0, size_length: int = 0,
) -> list[TensorVariable]: ) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
...@@ -137,7 +137,7 @@ def explicit_expand_dims( ...@@ -137,7 +137,7 @@ def explicit_expand_dims(
# See: https://github.com/pymc-devs/pytensor/issues/568 # See: https://github.com/pymc-devs/pytensor/issues/568
max_batch_dims = size_length max_batch_dims = size_length
else: else:
max_batch_dims = max(batch_dims) max_batch_dims = max(batch_dims, default=0)
new_params = [] new_params = []
for new_param, batch_dim in zip(params, batch_dims): for new_param, batch_dim in zip(params, batch_dims):
...@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape( ...@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape(
out: tuple out: tuple
Representing the support shape for a `RandomVariable` with the given `dist_params`. Representing the support shape for a `RandomVariable` with the given `dist_params`.
Notes
_____
This helper is no longer necessary when using signatures in `RandomVariable` subclasses.
""" """
if ndim_supp <= 0: if ndim_supp <= 0:
raise ValueError("ndim_supp must be greater than 0") raise ValueError("ndim_supp must be greater than 0")
......
...@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+" ...@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+"
_CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?" _CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?"
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)" _ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
_ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*" _ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*"
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$" # Allow no inputs
_SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature( def _parse_gufunc_signature(
...@@ -200,6 +201,8 @@ def _parse_gufunc_signature( ...@@ -200,6 +201,8 @@ def _parse_gufunc_signature(
tuple(re.findall(_DIMENSION_NAME, arg)) tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list) for arg in re.findall(_ARGUMENT, arg_list)
] ]
if arg_list # ignore no inputs
else []
for arg_list in signature.split("->") for arg_list in signature.split("->")
) )
......
...@@ -771,8 +771,7 @@ def test_random_unimplemented(): ...@@ -771,8 +771,7 @@ def test_random_unimplemented():
class NonExistentRV(RandomVariable): class NonExistentRV(RandomVariable):
name = "non-existent" name = "non-existent"
ndim_supp = 0 signature = "->()"
ndims_params = []
dtype = "floatX" dtype = "floatX"
def __call__(self, size=None, **kwargs): def __call__(self, size=None, **kwargs):
...@@ -798,8 +797,7 @@ def test_random_custom_implementation(): ...@@ -798,8 +797,7 @@ def test_random_custom_implementation():
class CustomRV(RandomVariable): class CustomRV(RandomVariable):
name = "non-existent" name = "non-existent"
ndim_supp = 0 signature = "->()"
ndims_params = []
dtype = "floatX" dtype = "floatX"
def __call__(self, size=None, **kwargs): def __call__(self, size=None, **kwargs):
......
...@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv( ...@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv(
return new_out, f_inputs, dist_st, f_rewritten return new_out, f_inputs, dist_st, f_rewritten
def test_inplace_rewrites(): class TestRVExpraProps(RandomVariable):
out = normal(0, 1) name = "test"
out.owner.inputs[0].default_update = out.owner.outputs[0] signature = "()->()"
__props__ = ("name", "signature", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("TestExtraProps", "\\operatorname{TestExtra_props}")
assert out.owner.op.inplace is False def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
f = function( def rng_fn(self, rng, dtype, sigma, size):
[], return rng.normal(scale=sigma, size=size)
out,
mode="FAST_RUN",
)
(new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op))
assert new_out.owner.op.inplace is True
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
)
assert np.array_equal(new_out.owner.inputs[1].data, [])
def test_inplace_rewrites_extra_props():
class Test(RandomVariable):
name = "test"
ndim_supp = 0
ndims_params = [0]
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("Test", "\\operatorname{Test}")
def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
def make_node(self, rng, size, dtype, sigma):
return super().make_node(rng, size, dtype, sigma)
def rng_fn(self, rng, sigma, size):
return rng.normal(scale=sigma, size=size)
out = Test(extra="some value")(1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
assert out.owner.op.inplace is False @pytest.mark.parametrize("rv_op", [normal, TestRVExpraProps(extra="some value")])
def test_inplace_rewrites(rv_op):
out = rv_op(np.e)
node = out.owner
op = node.op
node.inputs[0].default_update = node.outputs[0]
assert op.inplace is False
f = function( f = function(
[], [],
...@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props(): ...@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props():
(new_out, new_rng) = f.maker.fgraph.outputs (new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op)) new_node = new_out.owner
assert new_out.owner.op.inplace is True new_op = new_node.op
assert new_out.owner.op.extra == out.owner.op.extra assert isinstance(new_op, type(op))
assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
assert all( assert all(
np.array_equal(a.data, b.data) np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:]) for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
......
...@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester(
rng = shared(rng_ctor()) rng = shared(rng_ctor())
# Batched a implicit size # Batched a implicit size
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(1)->(s0,a1)",
ndims_params=[a_core_ndim, core_shape_len],
dtype="int64", dtype="int64",
) )
...@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester( ...@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10)) assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
# Explicit size broadcasts beyond a # Explicit size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 2
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(2)->(s0,s1,a1)",
ndims_params=[a_core_ndim, len(core_shape)],
dtype="int64", dtype="int64",
) )
...@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester(
""" """
rng = shared(rng_ctor()) rng = shared(rng_ctor())
# 3 ndims params indicates p is passed
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(a0),(1)->(s0,a1)",
ndims_params=[a_core_ndim, 1, 1],
dtype="int64", dtype="int64",
) )
...@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester(
# p and a are batched # p and a are batched
# Test implicit arange # Test implicit arange
a_core_ndim = 0
core_shape_len = 2
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(),(a),(2)->(s0,s1)",
ndims_params=[a_core_ndim, 1, 1],
dtype="int64", dtype="int64",
) )
a = 6 a = 6
...@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester( ...@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester(
assert set(draw) == set(range(i, 6, 2)) assert set(draw) == set(range(i, 6, 2))
# Size broadcasts beyond a # Size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement( rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len, signature="(a0,a1),(a0),(1)->(s0,a1)",
ndims_params=[a_core_ndim, 1, 1],
dtype="int64", dtype="int64",
) )
a = np.arange(4 * 5 * 2).reshape((4, 5, 2)) a = np.arange(4 * 5 * 2).reshape((4, 5, 2))
......
...@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags):
str_res = str( str_res = str(
RandomVariable( RandomVariable(
"normal", "normal",
0, signature="(),()->()",
[0, 0], dtype="float32",
"float32", inplace=False,
inplace=True,
) )
) )
assert str_res == "normal_rv{0, (0, 0), float32, True}" assert str_res == 'normal_rv{"(),()->()"}'
# `ndims_params` should be a `Sequence` type # `ndims_params` should be a `Sequence` type
with pytest.raises(TypeError, match="^Parameter ndims_params*"): with pytest.raises(TypeError, match="^Parameter ndims_params*"):
...@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# Confirm that `inplace` works # Confirm that `inplace` works
rv = RandomVariable( rv = RandomVariable(
"normal", "normal",
0, signature="(),()->()",
[0, 0],
"normal",
inplace=True, inplace=True,
) )
...@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
assert rv.destroy_map == {0: [0]} assert rv.destroy_map == {0: [0]}
# A no-params `RandomVariable` # A no-params `RandomVariable`
rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=()) rv = RandomVariable(name="test_rv", signature="->()")
with pytest.raises(TypeError): with pytest.raises(TypeError):
rv.make_node(rng=1) rv.make_node(rng=1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论