提交 591c47e6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Distinguish between size=None and size=() in RandomVariables

上级 98d73d78
......@@ -12,6 +12,7 @@ from pytensor.graph import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.type_other import NoneTypeT
try:
......@@ -93,7 +94,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
rv = node.outputs[1]
out_dtype = rv.type.dtype
static_shape = rv.type.shape
batch_ndim = op.batch_ndim(node)
# Try to pass static size directly to JAX
......@@ -102,10 +102,9 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
# Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types
size_param = op.size_param(node)
if isinstance(size_param, Constant):
size_tuple = tuple(size_param.data)
# PyTensor uses empty size to represent size = None
if len(size_tuple):
if isinstance(size_param, Constant) and not isinstance(
size_param.type, NoneTypeT
):
static_size = tuple(size_param.data)
# If one dimension has unknown size, either the size is determined
......@@ -115,9 +114,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, *parameters):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
else:
......
......@@ -21,6 +21,7 @@ from pytensor.link.utils import (
)
from pytensor.tensor.basic import get_vector_length
from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT
class RandomStateNumbaType(types.Type):
......@@ -101,9 +102,13 @@ def make_numba_random_fn(node, np_random_func):
if not isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `Generator`s")
tuple_size = int(get_vector_length(op.size_param(node)))
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
dist_params = op.dist_params(node)
size_dims = tuple_size - max(i.ndim for i in dist_params)
# Make a broadcast-capable version of the Numba supported scalar sampling
# function
......@@ -119,7 +124,7 @@ def make_numba_random_fn(node, np_random_func):
"np_random_func",
"numba_vectorize",
"to_fixed_tuple",
"tuple_size",
"size_len",
"size_dims",
"rng",
"size",
......@@ -155,10 +160,12 @@ def {bcast_fn_name}({bcast_fn_input_names}):
"out_dtype": out_dtype,
}
if tuple_size > 0:
if size_len is not None:
size_dims = size_len - max(i.ndim for i in dist_params)
random_fn_body = dedent(
f"""
size = to_fixed_tuple(size, tuple_size)
size = to_fixed_tuple(size, size_len)
data = np.empty(size, dtype=out_dtype)
for i in np.ndindex(size[:size_dims]):
......@@ -170,7 +177,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
{
"np": np,
"to_fixed_tuple": numba_ndarray.to_fixed_tuple,
"tuple_size": tuple_size,
"size_len": size_len,
"size_dims": size_dims,
}
)
......@@ -305,19 +312,24 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
@numba_funcify.register(ptr.CategoricalRV)
def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
size_len = int(get_vector_length(op.size_param(node)))
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit
def categorical_rv(rng, size, p):
if not size_len:
if size_len is None:
size_tpl = p.shape[:-1]
else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
p = np.broadcast_to(p, size_tpl + p.shape[-1:])
# Workaround https://github.com/numba/numba/issues/8975
if not size_len and p_ndim == 1:
if size_len is None and p_ndim == 1:
unif_samples = np.asarray(np.random.uniform(0, 1))
else:
unif_samples = np.random.uniform(0, 1, size_tpl)
......@@ -336,13 +348,20 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype
alphas_ndim = op.dist_params(node)[0].type.ndim
neg_ind_shape_len = -alphas_ndim + 1
size_len = int(get_vector_length(op.size_param(node)))
size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
if alphas_ndim > 1:
@numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas):
if size_len > 0:
if size_len is None:
samples_shape = alphas.shape
else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
if (
0 < alphas.ndim - 1 <= len(size_tpl)
......@@ -350,8 +369,6 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
):
raise ValueError("Parameters shape and size do not match.")
samples_shape = size_tpl + alphas.shape[-1:]
else:
samples_shape = alphas.shape
res = np.empty(samples_shape, dtype=out_dtype)
alphas_bcast = np.broadcast_to(alphas, samples_shape)
......@@ -365,6 +382,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
@numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas):
if size_len is not None:
size = numba_ndarray.to_fixed_tuple(size, size_len)
return (rng, np.random.dirichlet(alphas, size))
......@@ -404,8 +422,7 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
@numba_funcify.register(ptr.PermutationRV)
def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
# PyTensor uses size=() to represent size=None
size_is_none = op.size_param(node).type.shape == (0,)
size_is_none = isinstance(op.size_param(node).type, NoneTypeT)
batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
......
......@@ -914,12 +914,11 @@ class MvNormalRV(RandomVariable):
# multivariate normals (or any other multivariate distributions),
# so we need to implement that here
size = tuple(size or ())
if size:
if size is None:
mean, cov = broadcast_params([mean, cov], [1, 2])
else:
mean = np.broadcast_to(mean, size + mean.shape[-1:])
cov = np.broadcast_to(cov, size + cov.shape[-2:])
else:
mean, cov = broadcast_params([mean, cov], [1, 2])
res = np.empty(mean.shape)
for idx in np.ndindex(mean.shape[:-1]):
......@@ -1800,13 +1799,11 @@ class MultinomialRV(RandomVariable):
@classmethod
def rng_fn(cls, rng, n, p, size):
if n.ndim > 0 or p.ndim > 1:
size = tuple(size or ())
if size:
if size is None:
n, p = broadcast_params([n, p], [0, 1])
else:
n = np.broadcast_to(n, size)
p = np.broadcast_to(p, size + p.shape[-1:])
else:
n, p = broadcast_params([n, p], [0, 1])
res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]):
......@@ -2155,7 +2152,7 @@ class PermutationRV(RandomVariable):
def rng_fn(self, rng, x, size):
# We don't have access to the node in rng_fn :(
x_batch_ndim = x.ndim - self.ndims_params[0]
batch_ndim = max(x_batch_ndim, len(size or ()))
batch_ndim = max(x_batch_ndim, 0 if size is None else len(size))
if batch_ndim:
# rng.permutation has no concept of batch dims
......
......@@ -16,7 +16,6 @@ from pytensor.tensor.basic import (
as_tensor_variable,
concatenate,
constant,
get_underlying_scalar_constant_value,
get_vector_length,
infer_static_shape,
)
......@@ -28,7 +27,7 @@ from pytensor.tensor.random.utils import (
)
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.type_other import NoneConst, NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable
......@@ -196,10 +195,10 @@ class RandomVariable(Op):
def _infer_shape(
self,
size: TensorVariable,
size: TensorVariable | Variable,
dist_params: Sequence[TensorVariable],
param_shapes: Sequence[tuple[Variable, ...]] | None = None,
) -> TensorVariable | tuple[ScalarVariable, ...]:
) -> tuple[ScalarVariable | TensorVariable, ...]:
"""Compute the output shape given the size and distribution parameters.
Parameters
......@@ -225,9 +224,9 @@ class RandomVariable(Op):
self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
)
if not isinstance(size.type, NoneTypeT):
size_len = get_vector_length(size)
if size_len > 0:
# Fail early when size is incompatible with parameters
for i, (param, param_ndim_supp) in enumerate(
zip(dist_params, self.ndims_params)
......@@ -281,21 +280,11 @@ class RandomVariable(Op):
shape = batch_shape + supp_shape
if not shape:
shape = constant([], dtype="int64")
return shape
def infer_shape(self, fgraph, node, input_shapes):
_, size, *dist_params = node.inputs
_, size_shape, *param_shapes = input_shapes
try:
size_len = get_vector_length(size)
except ValueError:
size_len = get_underlying_scalar_constant_value(size_shape[0])
size = tuple(size[n] for n in range(size_len))
_, _, *param_shapes = input_shapes
shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)
......@@ -367,8 +356,8 @@ class RandomVariable(Op):
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)
shape = self._infer_shape(size, dist_params)
_, static_shape = infer_static_shape(shape)
inferred_shape = self._infer_shape(size, dist_params)
_, static_shape = infer_static_shape(inferred_shape)
inputs = (rng, size, *dist_params)
out_type = TensorType(dtype=self.dtype, shape=static_shape)
......@@ -396,21 +385,14 @@ class RandomVariable(Op):
rng, size, *args = inputs
# If `size == []`, that means no size is enforced, and NumPy is trusted
# to draw the appropriate number of samples, NumPy uses `size=None` to
# represent that. Otherwise, NumPy expects a tuple.
if np.size(size) == 0:
size = None
else:
size = tuple(size)
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng`
# otherwise.
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
if not self.inplace:
rng = copy(rng)
rng_var_out[0] = rng
if size is not None:
size = tuple(size)
smpl_val = self.rng_fn(rng, *([*args, size]))
if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
......@@ -473,7 +455,9 @@ def vectorize_random_variable(
original_dist_params = op.dist_params(node)
old_size = op.size_param(node)
len_old_size = get_vector_length(old_size)
len_old_size = (
None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size)
)
original_expanded_dist_params = explicit_expand_dims(
original_dist_params, op.ndims_params, len_old_size
......
......@@ -7,7 +7,7 @@ from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scalar import integer_types
from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length
from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.random.op import RandomVariable
......@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import (
as_index_variable,
get_idx_list,
)
from pytensor.tensor.type_other import SliceType
from pytensor.tensor.type_other import NoneTypeT, SliceType
def is_rv_used_in_graph(base_rv, node, fgraph):
......@@ -83,9 +83,11 @@ def local_rv_size_lift(fgraph, node):
rng, size, *dist_params = node.inputs
if isinstance(size.type, NoneTypeT):
return
dist_params = broadcast_params(dist_params, node.op.ndims_params)
if get_vector_length(size) > 0:
dist_params = [
broadcast_to(
p,
......@@ -102,8 +104,6 @@ def local_rv_size_lift(fgraph, node):
)
for i, p in enumerate(dist_params)
]
else:
return
new_node = node.op.make_node(rng, None, *dist_params)
......@@ -159,11 +159,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
batched_dims = rv.ndim - rv_op.ndim_supp
batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
if isinstance(size.type, NoneTypeT):
# Make size explicit
missing_size_dims = batched_dims - get_vector_length(size)
if missing_size_dims > 0:
full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
size = full_size[:missing_size_dims] + tuple(size)
shape = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
size = shape[:batched_dims]
# Update the size to reflect the DimShuffled dimensions
new_size = [
......
......@@ -158,7 +158,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
# No need to materialize arange
return None
rng, size, dtype, a_scalar_param, *other_params = node.inputs
rng, size, a_scalar_param, *other_params = node.inputs
if a_scalar_param.type.ndim > 0:
# Automatic vectorization could have made this parameter batched,
# there is no nice way to materialize a batched arange
......@@ -170,7 +170,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
# I.e., we substitute the first `()` by `(a)`
new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1)
new_op = type(op)(**new_props_dict)
return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs
return new_op.make_node(rng, size, a_vector_param, *other_params).outputs
random_vars_opt = SequenceDB()
......
......@@ -9,8 +9,8 @@ import numpy as np
from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable
from pytensor.scalar import ScalarVariable
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant
from pytensor.tensor import NoneConst, get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast
from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from pytensor.tensor.math import maximum
from pytensor.tensor.shape import shape_padleft, specify_shape
......@@ -124,7 +124,7 @@ def broadcast_params(params, ndims_params):
def explicit_expand_dims(
params: Sequence[TensorVariable],
ndim_params: Sequence[int],
size_length: int = 0,
size_length: int | None = None,
) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
......@@ -132,9 +132,7 @@ def explicit_expand_dims(
param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
]
if size_length:
# NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does
# See: https://github.com/pymc-devs/pytensor/issues/568
if size_length is not None:
max_batch_dims = size_length
else:
max_batch_dims = max(batch_dims, default=0)
......@@ -159,30 +157,30 @@ def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
def normalize_size_param(
size: int | np.ndarray | Variable | Sequence | None,
shape: int | np.ndarray | Variable | Sequence | None,
) -> Variable:
"""Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
if size is None:
size = constant([], dtype="int64")
elif isinstance(size, int):
size = as_tensor_variable([size], ndim=1)
elif not isinstance(size, np.ndarray | Variable | Sequence):
if shape is None or NoneConst.equals(shape):
return NoneConst
elif isinstance(shape, int):
shape = as_tensor_variable([shape], ndim=1)
elif not isinstance(shape, np.ndarray | Variable | Sequence):
raise TypeError(
"Parameter size must be None, an integer, or a sequence with integers."
)
else:
size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")
shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64")
if not isinstance(size, Constant):
if not isinstance(shape, Constant):
# This should help ensure that the length of non-constant `size`s
# will be available after certain types of cloning (e.g. the kind
# `Scan` performs)
size = specify_shape(size, (get_vector_length(size),))
shape = specify_shape(shape, (get_vector_length(shape),))
assert not any(s is None for s in size.type.shape)
assert size.dtype in int_dtypes
assert not any(s is None for s in shape.type.shape)
assert shape.dtype in int_dtypes
return size
return shape
class RandomStream:
......
......@@ -30,6 +30,7 @@ from pytensor.tensor.random.rewriting import (
from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer
from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from pytensor.tensor.type import iscalar, vector
from pytensor.tensor.type_other import NoneConst
no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
......@@ -44,6 +45,9 @@ def apply_local_rewrite_to_rv(
p_pt.tag.test_value = p
dist_params_pt.append(p_pt)
if size is None:
size_pt = NoneConst
else:
size_pt = []
for s in size:
# To test DimShuffle with dropping dims we need that size dimension to be constant
......@@ -57,7 +61,9 @@ def apply_local_rewrite_to_rv(
dist_st = op_fn(dist_op(*dist_params_pt, size=size_pt, rng=rng, name=name))
f_inputs = [
p for p in dist_params_pt + size_pt if not isinstance(p, slice | Constant)
p
for p in dist_params_pt + ([] if size is None else size_pt)
if not isinstance(p, slice | Constant)
]
mode = Mode(
......@@ -135,7 +141,7 @@ def test_inplace_rewrites(rv_op):
np.array([0.0, 1.0], dtype=config.floatX),
np.array(5.0, dtype=config.floatX),
],
[],
None,
),
(
normal,
......@@ -180,7 +186,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
rng,
)
assert pt.get_vector_length(new_out.owner.inputs[1]) == 0
assert new_out.owner.op.size_param(new_out.owner).data is None
@pytest.mark.parametrize(
......@@ -194,7 +200,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.array([0.0, -100.0], dtype=np.float64),
np.array(1e-6, dtype=np.float64),
),
(),
None,
1e-7,
),
(
......@@ -205,7 +211,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.array(-10.0, dtype=np.float64),
np.array(1e-6, dtype=np.float64),
),
(),
None,
1e-7,
),
(
......@@ -216,7 +222,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.array(-10.0, dtype=np.float64),
np.array(1e-6, dtype=np.float64),
),
(),
None,
1e-7,
),
(
......@@ -227,7 +233,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.arange(2 * 2 * 2).reshape((2, 2, 2)).astype(config.floatX),
np.array(1e-6).astype(config.floatX),
),
(),
None,
1e-3,
),
(
......@@ -440,7 +446,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30, dtype=config.floatX).reshape(3, 5, 2),
np.full((1, 5, 1), 1e-6),
),
(),
None,
),
(
# `size`-only slice
......@@ -462,7 +468,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30, dtype=config.floatX).reshape(3, 5, 2),
np.full((1, 5, 1), 1e-6),
),
(),
None,
),
(
# `size`-only slice
......@@ -484,7 +490,7 @@ def rand_bool_mask(shape, rng=None):
(0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX),
0.1 * np.arange(4).astype(dtype=config.floatX),
),
(),
None,
),
# 5
(
......@@ -570,7 +576,7 @@ def rand_bool_mask(shape, rng=None):
dtype=config.floatX,
),
),
(),
None,
),
(
# Univariate distribution with core-vector parameters
......@@ -627,7 +633,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
None,
),
(
# Multidimensional boolean indexing
......@@ -638,7 +644,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
None,
),
(
# Multidimensional boolean indexing
......@@ -649,7 +655,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
None,
),
# 20
(
......@@ -661,7 +667,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
None,
),
(
# Multidimensional boolean indexing
......@@ -687,7 +693,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
None,
),
(
# Multidimensional boolean indexing,
......@@ -703,7 +709,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
None,
),
(
# Multivariate distribution: indexing dips into core dimension
......@@ -714,7 +720,7 @@ def rand_bool_mask(shape, rng=None):
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(),
None,
),
# 25
(
......@@ -726,7 +732,7 @@ def rand_bool_mask(shape, rng=None):
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(),
None,
),
(
# Multivariate distribution: advanced integer indexing
......@@ -740,7 +746,7 @@ def rand_bool_mask(shape, rng=None):
),
np.eye(3, dtype=config.floatX) * 1e-6,
),
(),
None,
),
(
# Multivariate distribution: dummy slice "dips" into core dimension
......
......@@ -212,7 +212,7 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
@pytest.mark.parametrize(
"M, sd, size",
[
(pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_pt, ()),
(pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_pt, None),
(
pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)),
sd_pt,
......@@ -223,10 +223,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
sd_pt,
(2, M_pt),
),
(pt.zeros((M_pt,)), sd_pt, ()),
(pt.zeros((M_pt,)), sd_pt, None),
(pt.zeros((M_pt,)), sd_pt, (M_pt,)),
(pt.zeros((M_pt,)), sd_pt, (2, M_pt)),
(pt.zeros((M_pt,)), pt.ones((M_pt,)), ()),
(pt.zeros((M_pt,)), pt.ones((M_pt,)), None),
(pt.zeros((M_pt,)), pt.ones((M_pt,)), (2, M_pt)),
(
create_pytensor_param(
......@@ -244,9 +244,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
)
def test_normal_infer_shape(M, sd, size):
rv = normal(M, sd, size=size)
rv_shape = list(normal._infer_shape(size or (), [M, sd], None))
size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(normal._infer_shape(size_pt, [M, sd], None))
all_args = (M, sd, *size)
all_args = (M, sd, *(() if size is None else size))
fn_inputs = [
i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
......@@ -525,8 +526,8 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
mean = np.array([0.0], dtype=config.floatX)
if cov is None:
cov = np.array([[1.0]], dtype=config.floatX)
if size is None:
size = ()
if size is not None:
size = tuple(size)
return multivariate_normal.rng_fn(random_state, mean, cov, size)
......@@ -713,19 +714,20 @@ M_pt.tag.test_value = 3
@pytest.mark.parametrize(
"M, size",
[
(pt.ones((M_pt,)), ()),
(pt.ones((M_pt,)), None),
(pt.ones((M_pt,)), (M_pt + 1,)),
(pt.ones((M_pt,)), (2, M_pt)),
(pt.ones((M_pt, M_pt + 1)), ()),
(pt.ones((M_pt, M_pt + 1)), None),
(pt.ones((M_pt, M_pt + 1)), (M_pt + 2, M_pt)),
(pt.ones((M_pt, M_pt + 1)), (2, M_pt + 2, M_pt + 3, M_pt)),
],
)
def test_dirichlet_infer_shape(M, size):
rv = dirichlet(M, size=size)
rv_shape = list(dirichlet._infer_shape(size or (), [M], None))
size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(dirichlet._infer_shape(size_pt, [M], None))
all_args = (M, *size)
all_args = (M, *(() if size is None else size))
fn_inputs = [
i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
......@@ -1620,8 +1622,7 @@ def test_unnatural_batched_dims(batch_dims_tester):
@config.change_flags(compute_test_value="off")
def test_pickle():
# This is an interesting `Op` case, because it has `None` types and a
# conditional dtype
# This is an interesting `Op` case, because it has a conditional dtype
sample_a = choice(5, replace=False, size=(2, 3))
a_pkl = pickle.dumps(sample_a)
......
......@@ -69,7 +69,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# `RandomVariable._infer_shape` should handle no parameters
rv_shape = rv._infer_shape(pt.constant([]), (), [])
assert rv_shape.equals(pt.constant([], dtype="int64"))
assert rv_shape == ()
# `dtype` is respected
rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
......@@ -299,3 +299,16 @@ def test_vectorize():
vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert isinstance(vect_node.op, NormalRV)
assert vect_node.default_output().type.shape == (10, 2, 5)
def test_size_none_vs_empty():
    """``size=None`` and ``size=()`` must be treated differently.

    This pins the new contract introduced by this commit: ``size=None``
    means "no size constraint" (the output shape is taken from the
    parameters — here length-1 vectors, so the draw has shape ``(1,)``),
    while an explicit empty tuple ``size=()`` demands a scalar draw,
    which is incompatible with batched (vector) parameters and must
    raise.
    """
    # Univariate RV with two scalar parameters (e.g. loc/scale).
    rv = RandomVariable(
        "normal",
        signature="(),()->()",
    )
    # size=None: static shape is inferred from the (1,)-shaped parameters.
    assert rv([0], [1], size=None).type.shape == (1,)
    # size=(): a scalar draw cannot accommodate the batched parameters,
    # so construction must fail instead of silently behaving like None.
    with pytest.raises(
        ValueError, match="Size length is incompatible with batched dimensions"
    ):
        rv([0], [1], size=())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论