提交 3e9c6a4f authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Introduce signature instead of ndim_supp and ndims_params

上级 a576fa2c
......@@ -13,7 +13,6 @@ from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import (
broadcast_params,
normalize_size_param,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
......@@ -91,8 +90,7 @@ class UniformRV(RandomVariable):
"""
name = "uniform"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Uniform", "\\operatorname{Uniform}")
......@@ -146,8 +144,7 @@ class TriangularRV(RandomVariable):
"""
name = "triangular"
ndim_supp = 0
ndims_params = [0, 0, 0]
signature = "(),(),()->()"
dtype = "floatX"
_print_name = ("Triangular", "\\operatorname{Triangular}")
......@@ -202,8 +199,7 @@ class BetaRV(RandomVariable):
"""
name = "beta"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Beta", "\\operatorname{Beta}")
......@@ -249,8 +245,7 @@ class NormalRV(RandomVariable):
"""
name = "normal"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Normal", "\\operatorname{Normal}")
......@@ -316,8 +311,7 @@ class HalfNormalRV(ScipyRandomVariable):
"""
name = "halfnormal"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("HalfNormal", "\\operatorname{HalfNormal}")
......@@ -382,8 +376,7 @@ class LogNormalRV(RandomVariable):
"""
name = "lognormal"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("LogNormal", "\\operatorname{LogNormal}")
......@@ -434,8 +427,7 @@ class GammaRV(RandomVariable):
"""
name = "gamma"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Gamma", "\\operatorname{Gamma}")
......@@ -567,8 +559,7 @@ class ParetoRV(ScipyRandomVariable):
"""
name = "pareto"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Pareto", "\\operatorname{Pareto}")
......@@ -618,8 +609,7 @@ class GumbelRV(ScipyRandomVariable):
"""
name = "gumbel"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Gumbel", "\\operatorname{Gumbel}")
......@@ -680,8 +670,7 @@ class ExponentialRV(RandomVariable):
"""
name = "exponential"
ndim_supp = 0
ndims_params = [0]
signature = "()->()"
dtype = "floatX"
_print_name = ("Exponential", "\\operatorname{Exponential}")
......@@ -724,8 +713,7 @@ class WeibullRV(RandomVariable):
"""
name = "weibull"
ndim_supp = 0
ndims_params = [0]
signature = "()->()"
dtype = "floatX"
_print_name = ("Weibull", "\\operatorname{Weibull}")
......@@ -769,8 +757,7 @@ class LogisticRV(RandomVariable):
"""
name = "logistic"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Logistic", "\\operatorname{Logistic}")
......@@ -818,8 +805,7 @@ class VonMisesRV(RandomVariable):
"""
name = "vonmises"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("VonMises", "\\operatorname{VonMises}")
......@@ -886,19 +872,10 @@ class MvNormalRV(RandomVariable):
"""
name = "multivariate_normal"
ndim_supp = 1
ndims_params = [1, 2]
signature = "(n),(n,n)->(n)"
dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Return the support (event) shape of a single draw.

    The multivariate normal's support shape ``(n,)`` is read off the core
    shape of the mean parameter (``ref_param_idx=0``), either from the
    parameter values in ``dist_params`` or from ``param_shapes`` when given.
    """
    return supp_shape_from_ref_param_shape(
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
        ref_param_idx=0,
    )
def __call__(self, mean=None, cov=None, size=None, **kwargs):
r"""Draw samples from a multivariate normal distribution.
......@@ -942,7 +919,7 @@ class MvNormalRV(RandomVariable):
mean = np.broadcast_to(mean, size + mean.shape[-1:])
cov = np.broadcast_to(cov, size + cov.shape[-2:])
else:
mean, cov = broadcast_params([mean, cov], cls.ndims_params)
mean, cov = broadcast_params([mean, cov], [1, 2])
res = np.empty(mean.shape)
for idx in np.ndindex(mean.shape[:-1]):
......@@ -973,19 +950,10 @@ class DirichletRV(RandomVariable):
"""
name = "dirichlet"
ndim_supp = 1
ndims_params = [1]
signature = "(a)->(a)"
dtype = "floatX"
_print_name = ("Dirichlet", "\\operatorname{Dirichlet}")
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Return the support (event) shape of a single draw.

    A Dirichlet draw has the same core shape ``(a,)`` as its concentration
    parameter, so the shape is taken from parameter index 0.
    """
    return supp_shape_from_ref_param_shape(
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
        ref_param_idx=0,
    )
def __call__(self, alphas, size=None, **kwargs):
r"""Draw samples from a dirichlet distribution.
......@@ -1047,8 +1015,7 @@ class PoissonRV(RandomVariable):
"""
name = "poisson"
ndim_supp = 0
ndims_params = [0]
signature = "()->()"
dtype = "int64"
_print_name = ("Poisson", "\\operatorname{Poisson}")
......@@ -1093,8 +1060,7 @@ class GeometricRV(RandomVariable):
"""
name = "geometric"
ndim_supp = 0
ndims_params = [0]
signature = "()->()"
dtype = "int64"
_print_name = ("Geometric", "\\operatorname{Geometric}")
......@@ -1136,8 +1102,7 @@ class HyperGeometricRV(RandomVariable):
"""
name = "hypergeometric"
ndim_supp = 0
ndims_params = [0, 0, 0]
signature = "(),(),()->()"
dtype = "int64"
_print_name = ("HyperGeometric", "\\operatorname{HyperGeometric}")
......@@ -1185,8 +1150,7 @@ class CauchyRV(ScipyRandomVariable):
"""
name = "cauchy"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Cauchy", "\\operatorname{Cauchy}")
......@@ -1236,8 +1200,7 @@ class HalfCauchyRV(ScipyRandomVariable):
"""
name = "halfcauchy"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("HalfCauchy", "\\operatorname{HalfCauchy}")
......@@ -1291,8 +1254,7 @@ class InvGammaRV(ScipyRandomVariable):
"""
name = "invgamma"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("InverseGamma", "\\operatorname{InverseGamma}")
......@@ -1342,8 +1304,7 @@ class WaldRV(RandomVariable):
"""
name = "wald"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name_ = ("Wald", "\\operatorname{Wald}")
......@@ -1390,8 +1351,7 @@ class TruncExponentialRV(ScipyRandomVariable):
"""
name = "truncexpon"
ndim_supp = 0
ndims_params = [0, 0, 0]
signature = "(),(),()->()"
dtype = "floatX"
_print_name = ("TruncatedExponential", "\\operatorname{TruncatedExponential}")
......@@ -1446,8 +1406,7 @@ class StudentTRV(ScipyRandomVariable):
"""
name = "t"
ndim_supp = 0
ndims_params = [0, 0, 0]
signature = "(),(),()->()"
dtype = "floatX"
_print_name = ("StudentT", "\\operatorname{StudentT}")
......@@ -1506,8 +1465,7 @@ class BernoulliRV(ScipyRandomVariable):
"""
name = "bernoulli"
ndim_supp = 0
ndims_params = [0]
signature = "()->()"
dtype = "int64"
_print_name = ("Bernoulli", "\\operatorname{Bernoulli}")
......@@ -1554,8 +1512,7 @@ class LaplaceRV(RandomVariable):
"""
name = "laplace"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "floatX"
_print_name = ("Laplace", "\\operatorname{Laplace}")
......@@ -1601,8 +1558,7 @@ class BinomialRV(RandomVariable):
"""
name = "binomial"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "int64"
_print_name = ("Binomial", "\\operatorname{Binomial}")
......@@ -1645,9 +1601,8 @@ class NegBinomialRV(ScipyRandomVariable):
"""
name = "nbinom"
ndim_supp = 0
ndims_params = [0, 0]
name = "negative_binomial"
signature = "(),()->()"
dtype = "int64"
_print_name = ("NegativeBinomial", "\\operatorname{NegativeBinomial}")
......@@ -1702,8 +1657,7 @@ class BetaBinomialRV(ScipyRandomVariable):
"""
name = "beta_binomial"
ndim_supp = 0
ndims_params = [0, 0, 0]
signature = "(),(),()->()"
dtype = "int64"
_print_name = ("BetaBinomial", "\\operatorname{BetaBinomial}")
......@@ -1754,8 +1708,7 @@ class GenGammaRV(ScipyRandomVariable):
"""
name = "gengamma"
ndim_supp = 0
ndims_params = [0, 0, 0]
signature = "(),(),()->()"
dtype = "floatX"
_print_name = ("GeneralizedGamma", "\\operatorname{GeneralizedGamma}")
......@@ -1817,8 +1770,7 @@ class MultinomialRV(RandomVariable):
"""
name = "multinomial"
ndim_supp = 1
ndims_params = [0, 1]
signature = "(),(p)->(p)"
dtype = "int64"
_print_name = ("Multinomial", "\\operatorname{Multinomial}")
......@@ -1845,14 +1797,6 @@ class MultinomialRV(RandomVariable):
"""
return super().__call__(n, p, size=size, **kwargs)
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Return the support (event) shape of a single draw.

    A multinomial draw has the same core shape ``(p,)`` as its probability
    vector, which is parameter index 1 (index 0 is the scalar ``n``).
    """
    return supp_shape_from_ref_param_shape(
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
        ref_param_idx=1,
    )
@classmethod
def rng_fn(cls, rng, n, p, size):
if n.ndim > 0 or p.ndim > 1:
......@@ -1862,7 +1806,7 @@ class MultinomialRV(RandomVariable):
n = np.broadcast_to(n, size)
p = np.broadcast_to(p, size + p.shape[-1:])
else:
n, p = broadcast_params([n, p], cls.ndims_params)
n, p = broadcast_params([n, p], [0, 1])
res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]):
......@@ -1892,8 +1836,7 @@ class CategoricalRV(RandomVariable):
"""
name = "categorical"
ndim_supp = 0
ndims_params = [1]
signature = "(p)->()"
dtype = "int64"
_print_name = ("Categorical", "\\operatorname{Categorical}")
......@@ -1948,8 +1891,7 @@ class RandIntRV(RandomVariable):
"""
name = "randint"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "int64"
_print_name = ("randint", "\\operatorname{randint}")
......@@ -2001,8 +1943,7 @@ class IntegersRV(RandomVariable):
"""
name = "integers"
ndim_supp = 0
ndims_params = [0, 0]
signature = "(),()->()"
dtype = "int64"
_print_name = ("integers", "\\operatorname{integers}")
......@@ -2174,17 +2115,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
a_ndim = a.type.ndim
dtype = a.type.dtype
a_dims = [f"a{i}" for i in range(a_ndim)]
a_sig = ",".join(a_dims)
idx_dims = [f"s{i}" for i in range(core_shape_length)]
if a_ndim == 0:
p_sig = "a"
out_dims = idx_dims
else:
p_sig = a_dims[0]
out_dims = idx_dims + a_dims[1:]
out_sig = ",".join(out_dims)
if p is None:
ndims_params = [a_ndim, 1]
signature = f"({a_sig}),({core_shape_length})->({out_sig})"
else:
ndims_params = [a_ndim, 1, 1]
ndim_supp = max(a_ndim - 1, 0) + core_shape_length
signature = f"({a_sig}),({p_sig}),({core_shape_length})->({out_sig})"
op = ChoiceWithoutReplacement(
ndim_supp=ndim_supp,
ndims_params=ndims_params,
dtype=dtype,
)
op = ChoiceWithoutReplacement(signature=signature, dtype=dtype)
params = (a, core_shape) if p is None else (a, p, core_shape)
return op(*params, size=None, rng=rng)
......@@ -2247,10 +2194,12 @@ def permutation(x, **kwargs):
x_dtype = x.type.dtype
# PermutationRV has a signature () -> (x) if x is a scalar
# and (*x) -> (*x) otherwise, with as many entries as the dimensionality of x
ndim_supp = max(x_ndim, 1)
return PermutationRV(ndim_supp=ndim_supp, ndims_params=[x_ndim], dtype=x_dtype)(
x, **kwargs
)
if x_ndim == 0:
signature = "()->(x)"
else:
arg_sig = ",".join(f"x{i}" for i in range(x_ndim))
signature = f"({arg_sig})->({arg_sig})"
return PermutationRV(signature=signature, dtype=x_dtype)(x, **kwargs)
__all__ = [
......
import warnings
from collections.abc import Sequence
from copy import copy
from typing import cast
......@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import (
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable
......@@ -42,7 +44,7 @@ class RandomVariable(Op):
_output_type_depends_on_input_value = True
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
__props__ = ("name", "signature", "dtype", "inplace")
default_output = 1
def __init__(
......@@ -50,8 +52,9 @@ class RandomVariable(Op):
name=None,
ndim_supp=None,
ndims_params=None,
dtype=None,
dtype: str | None = None,
inplace=None,
signature: str | None = None,
):
"""Create a random variable `Op`.
......@@ -59,44 +62,63 @@ class RandomVariable(Op):
----------
name: str
The `Op`'s display name.
ndim_supp: int
Total number of dimensions for a single draw of the random variable
(e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
ndims_params: list of int
Number of dimensions for each distribution parameter when the
parameters only specify a single drawn of the random variable
(e.g. a multivariate normal's mean is 1D and covariance is 2D, so
``ndims_params = [1, 2]``).
signature: str
Numpy-like vectorized signature of the random variable.
dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called.
inplace: boolean (optional)
Determine whether or not the underlying rng state is updated
in-place or not (i.e. copied).
Determine whether the underlying rng state is mutated or copied.
"""
super().__init__()
self.name = name or getattr(self, "name")
self.ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp")
ndim_supp = (
ndim_supp if ndim_supp is not None else getattr(self, "ndim_supp", None)
)
self.ndims_params = (
ndims_params if ndims_params is not None else getattr(self, "ndims_params")
if ndim_supp is not None:
warnings.warn(
"ndim_supp is deprecated. Provide signature instead.", FutureWarning
)
self.ndim_supp = ndim_supp
ndims_params = (
ndims_params
if ndims_params is not None
else getattr(self, "ndims_params", None)
)
if ndims_params is not None:
warnings.warn(
"ndims_params is deprecated. Provide signature instead.", FutureWarning
)
if not isinstance(ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(ndims_params)
self.signature = signature or getattr(self, "signature", None)
if self.signature is not None:
# Assume a single output. Several methods need to be updated to handle multiple outputs.
self.inputs_sig, [self.output_sig] = _parse_gufunc_signature(self.signature)
self.ndims_params = [len(input_sig) for input_sig in self.inputs_sig]
self.ndim_supp = len(self.output_sig)
else:
if (
getattr(self, "ndim_supp", None) is None
or getattr(self, "ndims_params", None) is None
):
raise ValueError("signature must be provided")
else:
self.signature = safe_signature(self.ndims_params, [self.ndim_supp])
self.dtype = dtype or getattr(self, "dtype", None)
self.inplace = (
inplace if inplace is not None else getattr(self, "inplace", False)
)
if not isinstance(self.ndims_params, Sequence):
raise TypeError("Parameter ndims_params must be sequence type.")
self.ndims_params = tuple(self.ndims_params)
if self.inplace:
self.destroy_map = {0: [0]}
......@@ -120,8 +142,31 @@ class RandomVariable(Op):
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`.
"""
if self.signature is not None:
# Signature could indicate fixed numerical shapes
# As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html
output_sig = self.output_sig
core_out_shape = {
dim: int(dim) if str.isnumeric(dim) else None for dim in self.output_sig
}
# Try to infer missing support dims from signature of params
for param, param_sig, ndim_params in zip(
dist_params, self.inputs_sig, self.ndims_params
):
if ndim_params == 0:
continue
for param_dim, dim in zip(param.shape[-ndim_params:], param_sig):
if dim in core_out_shape and core_out_shape[dim] is None:
core_out_shape[dim] = param_dim
if all(dim is not None for dim in core_out_shape.values()):
# We have all we need
return [core_out_shape[dim] for dim in output_sig]
raise NotImplementedError(
"`_supp_shape_from_params` must be implemented for multivariate RVs"
"`_supp_shape_from_params` must be implemented for multivariate RVs "
"when signature is not sufficient to infer the support shape"
)
def rng_fn(self, rng, *args, **kwargs) -> int | float | np.ndarray:
......@@ -129,7 +174,24 @@ class RandomVariable(Op):
return getattr(rng, self.name)(*args, **kwargs)
def __str__(self):
    """Return a compact representation, e.g. ``normal_rv{"(),()->()"}``.

    Only the gufunc signature is shown among the core props, since
    ``ndim_supp`` and ``ndims_params`` are derived from it and would be
    redundant; extra props declared by subclasses are appended.
    """
    if signature := self.signature:
        core_props = [f'"{signature}"']
    else:
        # Backward-compat repr for instances created without a signature.
        core_props = [str(self.ndim_supp), str(self.ndims_params)]

    # Add any extra props that the subclass may have
    extra_props = [
        str(getattr(self, prop))
        for prop in self.__props__
        if prop not in RandomVariable.__props__
    ]

    props_str = ", ".join(core_props + extra_props)
    return f"{self.name}_rv{{{props_str}}}"
def _infer_shape(
......@@ -298,11 +360,11 @@ class RandomVariable(Op):
dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
else:
dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]
outtype = TensorType(dtype=dtype, shape=static_shape)
out_var = outtype()
dtype = all_dtypes[dtype_idx.data]
inputs = (rng, size, dtype_idx, *dist_params)
out_var = TensorType(dtype=dtype, shape=static_shape)()
outputs = (rng.type(), out_var)
return Apply(self, inputs, outputs)
......@@ -395,9 +457,8 @@ def vectorize_random_variable(
# We extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values
# Need to make parameters implicit broadcasting explicit
original_dist_params = node.inputs[3:]
old_size = node.inputs[1]
original_dist_params = op.dist_params(node)
old_size = op.size_param(node)
len_old_size = get_vector_length(old_size)
original_expanded_dist_params = explicit_expand_dims(
......
import re
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.rewriting.db import SequenceDB
......@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
a_vector_param = arange(a_scalar_param)
new_props_dict = op._props_dict().copy()
new_ndims_params = list(op.ndims_params)
new_ndims_params[0] += 1
new_props_dict["ndims_params"] = new_ndims_params
# Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
# I.e., we substitute the first `()` by `(a)`
new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1)
new_op = type(op)(**new_props_dict)
return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs
......
......@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
def explicit_expand_dims(
params: Sequence[TensorVariable],
ndim_params: tuple[int],
ndim_params: Sequence[int],
size_length: int = 0,
) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
......@@ -137,7 +137,7 @@ def explicit_expand_dims(
# See: https://github.com/pymc-devs/pytensor/issues/568
max_batch_dims = size_length
else:
max_batch_dims = max(batch_dims)
max_batch_dims = max(batch_dims, default=0)
new_params = []
for new_param, batch_dim in zip(params, batch_dims):
......@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape(
out: tuple
Representing the support shape for a `RandomVariable` with the given `dist_params`.
Notes
-----
This helper is no longer necessary when using signatures in `RandomVariable` subclasses.
"""
if ndim_supp <= 0:
raise ValueError("ndim_supp must be greater than 0")
......
......@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+"
_CORE_DIMENSION_LIST = f"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?"
_ARGUMENT = rf"\({_CORE_DIMENSION_LIST}\)"
_ARGUMENT_LIST = f"{_ARGUMENT}(?:,{_ARGUMENT})*"
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
# Allow no inputs
_SIGNATURE = f"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
def _parse_gufunc_signature(
......@@ -200,6 +201,8 @@ def _parse_gufunc_signature(
tuple(re.findall(_DIMENSION_NAME, arg))
for arg in re.findall(_ARGUMENT, arg_list)
]
if arg_list # ignore no inputs
else []
for arg_list in signature.split("->")
)
......
......@@ -771,8 +771,7 @@ def test_random_unimplemented():
class NonExistentRV(RandomVariable):
name = "non-existent"
ndim_supp = 0
ndims_params = []
signature = "->()"
dtype = "floatX"
def __call__(self, size=None, **kwargs):
......@@ -798,8 +797,7 @@ def test_random_custom_implementation():
class CustomRV(RandomVariable):
name = "non-existent"
ndim_supp = 0
ndims_params = []
signature = "->()"
dtype = "floatX"
def __call__(self, size=None, **kwargs):
......
......@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv(
return new_out, f_inputs, dist_st, f_rewritten
def test_inplace_rewrites():
out = normal(0, 1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
class TestRVExpraProps(RandomVariable):
name = "test"
signature = "()->()"
__props__ = ("name", "signature", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("TestExtraProps", "\\operatorname{TestExtra_props}")
assert out.owner.op.inplace is False
def __init__(self, extra, *args, **kwargs):
    # Set `extra` before delegating: it is listed in this class's __props__
    # and must exist as an attribute on the instance.
    self.extra = extra
    super().__init__(*args, **kwargs)
f = function(
[],
out,
mode="FAST_RUN",
)
(new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op))
assert new_out.owner.op.inplace is True
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
)
assert np.array_equal(new_out.owner.inputs[1].data, [])
def test_inplace_rewrites_extra_props():
class Test(RandomVariable):
name = "test"
ndim_supp = 0
ndims_params = [0]
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("Test", "\\operatorname{Test}")
def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
def make_node(self, rng, size, dtype, sigma):
return super().make_node(rng, size, dtype, sigma)
def rng_fn(self, rng, sigma, size):
return rng.normal(scale=sigma, size=size)
def rng_fn(self, rng, dtype, sigma, size):
    """Draw from a zero-mean normal with standard deviation ``sigma``."""
    # `dtype` is accepted to match the node inputs but is not used here.
    return rng.normal(loc=0.0, scale=sigma, size=size)
out = Test(extra="some value")(1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
assert out.owner.op.inplace is False
@pytest.mark.parametrize("rv_op", [normal, TestRVExpraProps(extra="some value")])
def test_inplace_rewrites(rv_op):
out = rv_op(np.e)
node = out.owner
op = node.op
node.inputs[0].default_update = node.outputs[0]
assert op.inplace is False
f = function(
[],
......@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props():
(new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op))
assert new_out.owner.op.inplace is True
assert new_out.owner.op.extra == out.owner.op.extra
new_node = new_out.owner
new_op = new_node.op
assert isinstance(new_op, type(op))
assert new_op._props_dict() == (op._props_dict() | {"inplace": True})
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
......
......@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester(
rng = shared(rng_ctor())
# Batched a implicit size
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, core_shape_len],
signature="(a0,a1),(1)->(s0,a1)",
dtype="int64",
)
......@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester(
assert np.all((draw >= i * 10) & (draw < (i + 1) * 10))
# Explicit size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 2
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, len(core_shape)],
signature="(a0,a1),(2)->(s0,s1,a1)",
dtype="int64",
)
......@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester(
"""
rng = shared(rng_ctor())
# 3 ndims params indicates p is passed
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, 1, 1],
signature="(a0,a1),(a0),(1)->(s0,a1)",
dtype="int64",
)
......@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester(
# p and a are batched
# Test implicit arange
a_core_ndim = 0
core_shape_len = 2
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, 1, 1],
signature="(),(a),(2)->(s0,s1)",
dtype="int64",
)
a = 6
......@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester(
assert set(draw) == set(range(i, 6, 2))
# Size broadcasts beyond a
a_core_ndim = 2
core_shape_len = 1
rv_op = ChoiceWithoutReplacement(
ndim_supp=max(a_core_ndim - 1, 0) + core_shape_len,
ndims_params=[a_core_ndim, 1, 1],
signature="(a0,a1),(a0),(1)->(s0,a1)",
dtype="int64",
)
a = np.arange(4 * 5 * 2).reshape((4, 5, 2))
......
......@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags):
str_res = str(
RandomVariable(
"normal",
0,
[0, 0],
"float32",
inplace=True,
signature="(),()->()",
dtype="float32",
inplace=False,
)
)
assert str_res == "normal_rv{0, (0, 0), float32, True}"
assert str_res == 'normal_rv{"(),()->()"}'
# `ndims_params` should be a `Sequence` type
with pytest.raises(TypeError, match="^Parameter ndims_params*"):
......@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# Confirm that `inplace` works
rv = RandomVariable(
"normal",
0,
[0, 0],
"normal",
signature="(),()->()",
inplace=True,
)
......@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
assert rv.destroy_map == {0: [0]}
# A no-params `RandomVariable`
rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=())
rv = RandomVariable(name="test_rv", signature="->()")
with pytest.raises(TypeError):
rv.make_node(rng=1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论