提交 591c47e6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Distinguish between size=None and size=() in RandomVariables

上级 98d73d78
...@@ -12,6 +12,7 @@ from pytensor.graph import Constant ...@@ -12,6 +12,7 @@ from pytensor.graph import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.type_other import NoneTypeT
try: try:
...@@ -93,7 +94,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): ...@@ -93,7 +94,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
rv = node.outputs[1] rv = node.outputs[1]
out_dtype = rv.type.dtype out_dtype = rv.type.dtype
static_shape = rv.type.shape static_shape = rv.type.shape
batch_ndim = op.batch_ndim(node) batch_ndim = op.batch_ndim(node)
# Try to pass static size directly to JAX # Try to pass static size directly to JAX
...@@ -102,11 +102,10 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): ...@@ -102,11 +102,10 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
# Sometimes size can be constant folded during rewrites, # Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types # without the RandomVariable node being updated with new static types
size_param = op.size_param(node) size_param = op.size_param(node)
if isinstance(size_param, Constant): if isinstance(size_param, Constant) and not isinstance(
size_tuple = tuple(size_param.data) size_param.type, NoneTypeT
# PyTensor uses empty size to represent size = None ):
if len(size_tuple): static_size = tuple(size_param.data)
static_size = tuple(size_param.data)
# If one dimension has unknown size, either the size is determined # If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is # by a `Shape` operator in which case JAX will compile, or it is
...@@ -115,9 +114,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): ...@@ -115,9 +114,6 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
assert_size_argument_jax_compatible(node) assert_size_argument_jax_compatible(node)
def sample_fn(rng, size, *parameters): def sample_fn(rng, size, *parameters):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters) return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
else: else:
......
...@@ -21,6 +21,7 @@ from pytensor.link.utils import ( ...@@ -21,6 +21,7 @@ from pytensor.link.utils import (
) )
from pytensor.tensor.basic import get_vector_length from pytensor.tensor.basic import get_vector_length
from pytensor.tensor.random.type import RandomStateType from pytensor.tensor.random.type import RandomStateType
from pytensor.tensor.type_other import NoneTypeT
class RandomStateNumbaType(types.Type): class RandomStateNumbaType(types.Type):
...@@ -101,9 +102,13 @@ def make_numba_random_fn(node, np_random_func): ...@@ -101,9 +102,13 @@ def make_numba_random_fn(node, np_random_func):
if not isinstance(rng_param.type, RandomStateType): if not isinstance(rng_param.type, RandomStateType):
raise TypeError("Numba does not support NumPy `Generator`s") raise TypeError("Numba does not support NumPy `Generator`s")
tuple_size = int(get_vector_length(op.size_param(node))) size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
dist_params = op.dist_params(node) dist_params = op.dist_params(node)
size_dims = tuple_size - max(i.ndim for i in dist_params)
# Make a broadcast-capable version of the Numba supported scalar sampling # Make a broadcast-capable version of the Numba supported scalar sampling
# function # function
...@@ -119,7 +124,7 @@ def make_numba_random_fn(node, np_random_func): ...@@ -119,7 +124,7 @@ def make_numba_random_fn(node, np_random_func):
"np_random_func", "np_random_func",
"numba_vectorize", "numba_vectorize",
"to_fixed_tuple", "to_fixed_tuple",
"tuple_size", "size_len",
"size_dims", "size_dims",
"rng", "rng",
"size", "size",
...@@ -155,10 +160,12 @@ def {bcast_fn_name}({bcast_fn_input_names}): ...@@ -155,10 +160,12 @@ def {bcast_fn_name}({bcast_fn_input_names}):
"out_dtype": out_dtype, "out_dtype": out_dtype,
} }
if tuple_size > 0: if size_len is not None:
size_dims = size_len - max(i.ndim for i in dist_params)
random_fn_body = dedent( random_fn_body = dedent(
f""" f"""
size = to_fixed_tuple(size, tuple_size) size = to_fixed_tuple(size, size_len)
data = np.empty(size, dtype=out_dtype) data = np.empty(size, dtype=out_dtype)
for i in np.ndindex(size[:size_dims]): for i in np.ndindex(size[:size_dims]):
...@@ -170,7 +177,7 @@ def {bcast_fn_name}({bcast_fn_input_names}): ...@@ -170,7 +177,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
{ {
"np": np, "np": np,
"to_fixed_tuple": numba_ndarray.to_fixed_tuple, "to_fixed_tuple": numba_ndarray.to_fixed_tuple,
"tuple_size": tuple_size, "size_len": size_len,
"size_dims": size_dims, "size_dims": size_dims,
} }
) )
...@@ -305,19 +312,24 @@ def numba_funcify_BernoulliRV(op, node, **kwargs): ...@@ -305,19 +312,24 @@ def numba_funcify_BernoulliRV(op, node, **kwargs):
@numba_funcify.register(ptr.CategoricalRV) @numba_funcify.register(ptr.CategoricalRV)
def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs): def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype out_dtype = node.outputs[1].type.numpy_dtype
size_len = int(get_vector_length(op.size_param(node))) size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
p_ndim = node.inputs[-1].ndim p_ndim = node.inputs[-1].ndim
@numba_basic.numba_njit @numba_basic.numba_njit
def categorical_rv(rng, size, p): def categorical_rv(rng, size, p):
if not size_len: if size_len is None:
size_tpl = p.shape[:-1] size_tpl = p.shape[:-1]
else: else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
p = np.broadcast_to(p, size_tpl + p.shape[-1:]) p = np.broadcast_to(p, size_tpl + p.shape[-1:])
# Workaround https://github.com/numba/numba/issues/8975 # Workaround https://github.com/numba/numba/issues/8975
if not size_len and p_ndim == 1: if size_len is None and p_ndim == 1:
unif_samples = np.asarray(np.random.uniform(0, 1)) unif_samples = np.asarray(np.random.uniform(0, 1))
else: else:
unif_samples = np.random.uniform(0, 1, size_tpl) unif_samples = np.random.uniform(0, 1, size_tpl)
...@@ -336,13 +348,20 @@ def numba_funcify_DirichletRV(op, node, **kwargs): ...@@ -336,13 +348,20 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
out_dtype = node.outputs[1].type.numpy_dtype out_dtype = node.outputs[1].type.numpy_dtype
alphas_ndim = op.dist_params(node)[0].type.ndim alphas_ndim = op.dist_params(node)[0].type.ndim
neg_ind_shape_len = -alphas_ndim + 1 neg_ind_shape_len = -alphas_ndim + 1
size_len = int(get_vector_length(op.size_param(node))) size_param = op.size_param(node)
size_len = (
None
if isinstance(size_param.type, NoneTypeT)
else int(get_vector_length(size_param))
)
if alphas_ndim > 1: if alphas_ndim > 1:
@numba_basic.numba_njit @numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas): def dirichlet_rv(rng, size, alphas):
if size_len > 0: if size_len is None:
samples_shape = alphas.shape
else:
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len) size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
if ( if (
0 < alphas.ndim - 1 <= len(size_tpl) 0 < alphas.ndim - 1 <= len(size_tpl)
...@@ -350,8 +369,6 @@ def numba_funcify_DirichletRV(op, node, **kwargs): ...@@ -350,8 +369,6 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
): ):
raise ValueError("Parameters shape and size do not match.") raise ValueError("Parameters shape and size do not match.")
samples_shape = size_tpl + alphas.shape[-1:] samples_shape = size_tpl + alphas.shape[-1:]
else:
samples_shape = alphas.shape
res = np.empty(samples_shape, dtype=out_dtype) res = np.empty(samples_shape, dtype=out_dtype)
alphas_bcast = np.broadcast_to(alphas, samples_shape) alphas_bcast = np.broadcast_to(alphas, samples_shape)
...@@ -365,7 +382,8 @@ def numba_funcify_DirichletRV(op, node, **kwargs): ...@@ -365,7 +382,8 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def dirichlet_rv(rng, size, alphas): def dirichlet_rv(rng, size, alphas):
size = numba_ndarray.to_fixed_tuple(size, size_len) if size_len is not None:
size = numba_ndarray.to_fixed_tuple(size, size_len)
return (rng, np.random.dirichlet(alphas, size)) return (rng, np.random.dirichlet(alphas, size))
return dirichlet_rv return dirichlet_rv
...@@ -404,8 +422,7 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs): ...@@ -404,8 +422,7 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
@numba_funcify.register(ptr.PermutationRV) @numba_funcify.register(ptr.PermutationRV)
def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs): def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
# PyTensor uses size=() to represent size=None size_is_none = isinstance(op.size_param(node).type, NoneTypeT)
size_is_none = op.size_param(node).type.shape == (0,)
batch_ndim = op.batch_ndim(node) batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0] x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
......
...@@ -914,12 +914,11 @@ class MvNormalRV(RandomVariable): ...@@ -914,12 +914,11 @@ class MvNormalRV(RandomVariable):
# multivariate normals (or any other multivariate distributions), # multivariate normals (or any other multivariate distributions),
# so we need to implement that here # so we need to implement that here
size = tuple(size or ()) if size is None:
if size: mean, cov = broadcast_params([mean, cov], [1, 2])
else:
mean = np.broadcast_to(mean, size + mean.shape[-1:]) mean = np.broadcast_to(mean, size + mean.shape[-1:])
cov = np.broadcast_to(cov, size + cov.shape[-2:]) cov = np.broadcast_to(cov, size + cov.shape[-2:])
else:
mean, cov = broadcast_params([mean, cov], [1, 2])
res = np.empty(mean.shape) res = np.empty(mean.shape)
for idx in np.ndindex(mean.shape[:-1]): for idx in np.ndindex(mean.shape[:-1]):
...@@ -1800,13 +1799,11 @@ class MultinomialRV(RandomVariable): ...@@ -1800,13 +1799,11 @@ class MultinomialRV(RandomVariable):
@classmethod @classmethod
def rng_fn(cls, rng, n, p, size): def rng_fn(cls, rng, n, p, size):
if n.ndim > 0 or p.ndim > 1: if n.ndim > 0 or p.ndim > 1:
size = tuple(size or ()) if size is None:
n, p = broadcast_params([n, p], [0, 1])
if size: else:
n = np.broadcast_to(n, size) n = np.broadcast_to(n, size)
p = np.broadcast_to(p, size + p.shape[-1:]) p = np.broadcast_to(p, size + p.shape[-1:])
else:
n, p = broadcast_params([n, p], [0, 1])
res = np.empty(p.shape, dtype=cls.dtype) res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]): for idx in np.ndindex(p.shape[:-1]):
...@@ -2155,7 +2152,7 @@ class PermutationRV(RandomVariable): ...@@ -2155,7 +2152,7 @@ class PermutationRV(RandomVariable):
def rng_fn(self, rng, x, size): def rng_fn(self, rng, x, size):
# We don't have access to the node in rng_fn :( # We don't have access to the node in rng_fn :(
x_batch_ndim = x.ndim - self.ndims_params[0] x_batch_ndim = x.ndim - self.ndims_params[0]
batch_ndim = max(x_batch_ndim, len(size or ())) batch_ndim = max(x_batch_ndim, 0 if size is None else len(size))
if batch_ndim: if batch_ndim:
# rng.permutation has no concept of batch dims # rng.permutation has no concept of batch dims
......
...@@ -16,7 +16,6 @@ from pytensor.tensor.basic import ( ...@@ -16,7 +16,6 @@ from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
concatenate, concatenate,
constant, constant,
get_underlying_scalar_constant_value,
get_vector_length, get_vector_length,
infer_static_shape, infer_static_shape,
) )
...@@ -28,7 +27,7 @@ from pytensor.tensor.random.utils import ( ...@@ -28,7 +27,7 @@ from pytensor.tensor.random.utils import (
) )
from pytensor.tensor.shape import shape_tuple from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst, NoneTypeT
from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -196,10 +195,10 @@ class RandomVariable(Op): ...@@ -196,10 +195,10 @@ class RandomVariable(Op):
def _infer_shape( def _infer_shape(
self, self,
size: TensorVariable, size: TensorVariable | Variable,
dist_params: Sequence[TensorVariable], dist_params: Sequence[TensorVariable],
param_shapes: Sequence[tuple[Variable, ...]] | None = None, param_shapes: Sequence[tuple[Variable, ...]] | None = None,
) -> TensorVariable | tuple[ScalarVariable, ...]: ) -> tuple[ScalarVariable | TensorVariable, ...]:
"""Compute the output shape given the size and distribution parameters. """Compute the output shape given the size and distribution parameters.
Parameters Parameters
...@@ -225,9 +224,9 @@ class RandomVariable(Op): ...@@ -225,9 +224,9 @@ class RandomVariable(Op):
self._supp_shape_from_params(dist_params, param_shapes=param_shapes) self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
) )
size_len = get_vector_length(size) if not isinstance(size.type, NoneTypeT):
size_len = get_vector_length(size)
if size_len > 0:
# Fail early when size is incompatible with parameters # Fail early when size is incompatible with parameters
for i, (param, param_ndim_supp) in enumerate( for i, (param, param_ndim_supp) in enumerate(
zip(dist_params, self.ndims_params) zip(dist_params, self.ndims_params)
...@@ -281,21 +280,11 @@ class RandomVariable(Op): ...@@ -281,21 +280,11 @@ class RandomVariable(Op):
shape = batch_shape + supp_shape shape = batch_shape + supp_shape
if not shape:
shape = constant([], dtype="int64")
return shape return shape
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
_, size, *dist_params = node.inputs _, size, *dist_params = node.inputs
_, size_shape, *param_shapes = input_shapes _, _, *param_shapes = input_shapes
try:
size_len = get_vector_length(size)
except ValueError:
size_len = get_underlying_scalar_constant_value(size_shape[0])
size = tuple(size[n] for n in range(size_len))
shape = self._infer_shape(size, dist_params, param_shapes=param_shapes) shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)
...@@ -367,8 +356,8 @@ class RandomVariable(Op): ...@@ -367,8 +356,8 @@ class RandomVariable(Op):
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType" "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
) )
shape = self._infer_shape(size, dist_params) inferred_shape = self._infer_shape(size, dist_params)
_, static_shape = infer_static_shape(shape) _, static_shape = infer_static_shape(inferred_shape)
inputs = (rng, size, *dist_params) inputs = (rng, size, *dist_params)
out_type = TensorType(dtype=self.dtype, shape=static_shape) out_type = TensorType(dtype=self.dtype, shape=static_shape)
...@@ -396,21 +385,14 @@ class RandomVariable(Op): ...@@ -396,21 +385,14 @@ class RandomVariable(Op):
rng, size, *args = inputs rng, size, *args = inputs
# If `size == []`, that means no size is enforced, and NumPy is trusted # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
# to draw the appropriate number of samples, NumPy uses `size=None` to
# represent that. Otherwise, NumPy expects a tuple.
if np.size(size) == 0:
size = None
else:
size = tuple(size)
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng`
# otherwise.
if not self.inplace: if not self.inplace:
rng = copy(rng) rng = copy(rng)
rng_var_out[0] = rng rng_var_out[0] = rng
if size is not None:
size = tuple(size)
smpl_val = self.rng_fn(rng, *([*args, size])) smpl_val = self.rng_fn(rng, *([*args, size]))
if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype: if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
...@@ -473,7 +455,9 @@ def vectorize_random_variable( ...@@ -473,7 +455,9 @@ def vectorize_random_variable(
original_dist_params = op.dist_params(node) original_dist_params = op.dist_params(node)
old_size = op.size_param(node) old_size = op.size_param(node)
len_old_size = get_vector_length(old_size) len_old_size = (
None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size)
)
original_expanded_dist_params = explicit_expand_dims( original_expanded_dist_params = explicit_expand_dims(
original_dist_params, op.ndims_params, len_old_size original_dist_params, op.ndims_params, len_old_size
......
...@@ -7,7 +7,7 @@ from pytensor.graph.op import compute_test_value ...@@ -7,7 +7,7 @@ from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scalar import integer_types from pytensor.scalar import integer_types
from pytensor.tensor import NoneConst from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
...@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import ( ...@@ -20,7 +20,7 @@ from pytensor.tensor.subtensor import (
as_index_variable, as_index_variable,
get_idx_list, get_idx_list,
) )
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import NoneTypeT, SliceType
def is_rv_used_in_graph(base_rv, node, fgraph): def is_rv_used_in_graph(base_rv, node, fgraph):
...@@ -83,27 +83,27 @@ def local_rv_size_lift(fgraph, node): ...@@ -83,27 +83,27 @@ def local_rv_size_lift(fgraph, node):
rng, size, *dist_params = node.inputs rng, size, *dist_params = node.inputs
if isinstance(size.type, NoneTypeT):
return
dist_params = broadcast_params(dist_params, node.op.ndims_params) dist_params = broadcast_params(dist_params, node.op.ndims_params)
if get_vector_length(size) > 0: dist_params = [
dist_params = [ broadcast_to(
broadcast_to( p,
p, (
( tuple(size)
tuple(size) + (
+ ( tuple(p.shape)[-node.op.ndims_params[i] :]
tuple(p.shape)[-node.op.ndims_params[i] :] if node.op.ndims_params[i] > 0
if node.op.ndims_params[i] > 0 else ()
else ()
)
) )
if node.op.ndim_supp > 0
else size,
) )
for i, p in enumerate(dist_params) if node.op.ndim_supp > 0
] else size,
else: )
return for i, p in enumerate(dist_params)
]
new_node = node.op.make_node(rng, None, *dist_params) new_node = node.op.make_node(rng, None, *dist_params)
...@@ -159,11 +159,10 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -159,11 +159,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
batched_dims = rv.ndim - rv_op.ndim_supp batched_dims = rv.ndim - rv_op.ndim_supp
batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims) batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
# Make size explicit if isinstance(size.type, NoneTypeT):
missing_size_dims = batched_dims - get_vector_length(size) # Make size explicit
if missing_size_dims > 0: shape = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape) size = shape[:batched_dims]
size = full_size[:missing_size_dims] + tuple(size)
# Update the size to reflect the DimShuffled dimensions # Update the size to reflect the DimShuffled dimensions
new_size = [ new_size = [
......
...@@ -158,7 +158,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): ...@@ -158,7 +158,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
# No need to materialize arange # No need to materialize arange
return None return None
rng, size, dtype, a_scalar_param, *other_params = node.inputs rng, size, a_scalar_param, *other_params = node.inputs
if a_scalar_param.type.ndim > 0: if a_scalar_param.type.ndim > 0:
# Automatic vectorization could have made this parameter batched, # Automatic vectorization could have made this parameter batched,
# there is no nice way to materialize a batched arange # there is no nice way to materialize a batched arange
...@@ -170,7 +170,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node): ...@@ -170,7 +170,7 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
# I.e., we substitute the first `()` by `(a)` # I.e., we substitute the first `()` by `(a)`
new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1) new_props_dict["signature"] = re.sub(r"\(\)", "(a)", op.signature, 1)
new_op = type(op)(**new_props_dict) new_op = type(op)(**new_props_dict)
return new_op.make_node(rng, size, dtype, a_vector_param, *other_params).outputs return new_op.make_node(rng, size, a_vector_param, *other_params).outputs
random_vars_opt = SequenceDB() random_vars_opt = SequenceDB()
......
...@@ -9,8 +9,8 @@ import numpy as np ...@@ -9,8 +9,8 @@ import numpy as np
from pytensor.compile.sharedvalue import shared from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant, Variable
from pytensor.scalar import ScalarVariable from pytensor.scalar import ScalarVariable
from pytensor.tensor import get_vector_length from pytensor.tensor import NoneConst, get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant from pytensor.tensor.basic import as_tensor_variable, cast
from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
from pytensor.tensor.shape import shape_padleft, specify_shape from pytensor.tensor.shape import shape_padleft, specify_shape
...@@ -124,7 +124,7 @@ def broadcast_params(params, ndims_params): ...@@ -124,7 +124,7 @@ def broadcast_params(params, ndims_params):
def explicit_expand_dims( def explicit_expand_dims(
params: Sequence[TensorVariable], params: Sequence[TensorVariable],
ndim_params: Sequence[int], ndim_params: Sequence[int],
size_length: int = 0, size_length: int | None = None,
) -> list[TensorVariable]: ) -> list[TensorVariable]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.""" """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
...@@ -132,9 +132,7 @@ def explicit_expand_dims( ...@@ -132,9 +132,7 @@ def explicit_expand_dims(
param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params) param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
] ]
if size_length: if size_length is not None:
# NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does
# See: https://github.com/pymc-devs/pytensor/issues/568
max_batch_dims = size_length max_batch_dims = size_length
else: else:
max_batch_dims = max(batch_dims, default=0) max_batch_dims = max(batch_dims, default=0)
...@@ -159,30 +157,30 @@ def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable: ...@@ -159,30 +157,30 @@ def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
def normalize_size_param( def normalize_size_param(
size: int | np.ndarray | Variable | Sequence | None, shape: int | np.ndarray | Variable | Sequence | None,
) -> Variable: ) -> Variable:
"""Create an PyTensor value for a ``RandomVariable`` ``size`` parameter.""" """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
if size is None: if shape is None or NoneConst.equals(shape):
size = constant([], dtype="int64") return NoneConst
elif isinstance(size, int): elif isinstance(shape, int):
size = as_tensor_variable([size], ndim=1) shape = as_tensor_variable([shape], ndim=1)
elif not isinstance(size, np.ndarray | Variable | Sequence): elif not isinstance(shape, np.ndarray | Variable | Sequence):
raise TypeError( raise TypeError(
"Parameter size must be None, an integer, or a sequence with integers." "Parameter size must be None, an integer, or a sequence with integers."
) )
else: else:
size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64") shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64")
if not isinstance(size, Constant): if not isinstance(shape, Constant):
# This should help ensure that the length of non-constant `size`s # This should help ensure that the length of non-constant `size`s
# will be available after certain types of cloning (e.g. the kind # will be available after certain types of cloning (e.g. the kind
# `Scan` performs) # `Scan` performs)
size = specify_shape(size, (get_vector_length(size),)) shape = specify_shape(shape, (get_vector_length(shape),))
assert not any(s is None for s in size.type.shape) assert not any(s is None for s in shape.type.shape)
assert size.dtype in int_dtypes assert shape.dtype in int_dtypes
return size return shape
class RandomStream: class RandomStream:
......
...@@ -30,6 +30,7 @@ from pytensor.tensor.random.rewriting import ( ...@@ -30,6 +30,7 @@ from pytensor.tensor.random.rewriting import (
from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer
from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
from pytensor.tensor.type_other import NoneConst
no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[])) no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
...@@ -44,20 +45,25 @@ def apply_local_rewrite_to_rv( ...@@ -44,20 +45,25 @@ def apply_local_rewrite_to_rv(
p_pt.tag.test_value = p p_pt.tag.test_value = p
dist_params_pt.append(p_pt) dist_params_pt.append(p_pt)
size_pt = [] if size is None:
for s in size: size_pt = NoneConst
# To test DimShuffle with dropping dims we need that size dimension to be constant else:
if s == 1: size_pt = []
s_pt = constant(np.array(1, dtype="int32")) for s in size:
else: # To test DimShuffle with dropping dims we need that size dimension to be constant
s_pt = iscalar() if s == 1:
s_pt.tag.test_value = s s_pt = constant(np.array(1, dtype="int32"))
size_pt.append(s_pt) else:
s_pt = iscalar()
s_pt.tag.test_value = s
size_pt.append(s_pt)
dist_st = op_fn(dist_op(*dist_params_pt, size=size_pt, rng=rng, name=name)) dist_st = op_fn(dist_op(*dist_params_pt, size=size_pt, rng=rng, name=name))
f_inputs = [ f_inputs = [
p for p in dist_params_pt + size_pt if not isinstance(p, slice | Constant) p
for p in dist_params_pt + ([] if size is None else size_pt)
if not isinstance(p, slice | Constant)
] ]
mode = Mode( mode = Mode(
...@@ -135,7 +141,7 @@ def test_inplace_rewrites(rv_op): ...@@ -135,7 +141,7 @@ def test_inplace_rewrites(rv_op):
np.array([0.0, 1.0], dtype=config.floatX), np.array([0.0, 1.0], dtype=config.floatX),
np.array(5.0, dtype=config.floatX), np.array(5.0, dtype=config.floatX),
], ],
[], None,
), ),
( (
normal, normal,
...@@ -180,7 +186,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -180,7 +186,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
rng, rng,
) )
assert pt.get_vector_length(new_out.owner.inputs[1]) == 0 assert new_out.owner.op.size_param(new_out.owner).data is None
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -194,7 +200,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -194,7 +200,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.array([0.0, -100.0], dtype=np.float64), np.array([0.0, -100.0], dtype=np.float64),
np.array(1e-6, dtype=np.float64), np.array(1e-6, dtype=np.float64),
), ),
(), None,
1e-7, 1e-7,
), ),
( (
...@@ -205,7 +211,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -205,7 +211,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.array(-10.0, dtype=np.float64), np.array(-10.0, dtype=np.float64),
np.array(1e-6, dtype=np.float64), np.array(1e-6, dtype=np.float64),
), ),
(), None,
1e-7, 1e-7,
), ),
( (
...@@ -216,7 +222,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -216,7 +222,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.array(-10.0, dtype=np.float64), np.array(-10.0, dtype=np.float64),
np.array(1e-6, dtype=np.float64), np.array(1e-6, dtype=np.float64),
), ),
(), None,
1e-7, 1e-7,
), ),
( (
...@@ -227,7 +233,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -227,7 +233,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
np.arange(2 * 2 * 2).reshape((2, 2, 2)).astype(config.floatX), np.arange(2 * 2 * 2).reshape((2, 2, 2)).astype(config.floatX),
np.array(1e-6).astype(config.floatX), np.array(1e-6).astype(config.floatX),
), ),
(), None,
1e-3, 1e-3,
), ),
( (
...@@ -440,7 +446,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -440,7 +446,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30, dtype=config.floatX).reshape(3, 5, 2), np.arange(30, dtype=config.floatX).reshape(3, 5, 2),
np.full((1, 5, 1), 1e-6), np.full((1, 5, 1), 1e-6),
), ),
(), None,
), ),
( (
# `size`-only slice # `size`-only slice
...@@ -462,7 +468,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -462,7 +468,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30, dtype=config.floatX).reshape(3, 5, 2), np.arange(30, dtype=config.floatX).reshape(3, 5, 2),
np.full((1, 5, 1), 1e-6), np.full((1, 5, 1), 1e-6),
), ),
(), None,
), ),
( (
# `size`-only slice # `size`-only slice
...@@ -484,7 +490,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -484,7 +490,7 @@ def rand_bool_mask(shape, rng=None):
(0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX), (0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX),
0.1 * np.arange(4).astype(dtype=config.floatX), 0.1 * np.arange(4).astype(dtype=config.floatX),
), ),
(), None,
), ),
# 5 # 5
( (
...@@ -570,7 +576,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -570,7 +576,7 @@ def rand_bool_mask(shape, rng=None):
dtype=config.floatX, dtype=config.floatX,
), ),
), ),
(), None,
), ),
( (
# Univariate distribution with core-vector parameters # Univariate distribution with core-vector parameters
...@@ -627,7 +633,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -627,7 +633,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2), np.arange(30).reshape(5, 3, 2),
1e-6, 1e-6,
), ),
(), None,
), ),
( (
# Multidimensional boolean indexing # Multidimensional boolean indexing
...@@ -638,7 +644,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -638,7 +644,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2), np.arange(30).reshape(5, 3, 2),
1e-6, 1e-6,
), ),
(), None,
), ),
( (
# Multidimensional boolean indexing # Multidimensional boolean indexing
...@@ -649,7 +655,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -649,7 +655,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2), np.arange(30).reshape(5, 3, 2),
1e-6, 1e-6,
), ),
(), None,
), ),
# 20 # 20
( (
...@@ -661,7 +667,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -661,7 +667,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2), np.arange(30).reshape(5, 3, 2),
1e-6, 1e-6,
), ),
(), None,
), ),
( (
# Multidimensional boolean indexing # Multidimensional boolean indexing
...@@ -687,7 +693,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -687,7 +693,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2), np.arange(30).reshape(5, 3, 2),
1e-6, 1e-6,
), ),
(), None,
), ),
( (
# Multidimensional boolean indexing, # Multidimensional boolean indexing,
...@@ -703,7 +709,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -703,7 +709,7 @@ def rand_bool_mask(shape, rng=None):
np.arange(30).reshape(5, 3, 2), np.arange(30).reshape(5, 3, 2),
1e-6, 1e-6,
), ),
(), None,
), ),
( (
# Multivariate distribution: indexing dips into core dimension # Multivariate distribution: indexing dips into core dimension
...@@ -714,7 +720,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -714,7 +720,7 @@ def rand_bool_mask(shape, rng=None):
np.array([[-1, 20], [300, -4000]], dtype=config.floatX), np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6, np.eye(2).astype(config.floatX) * 1e-6,
), ),
(), None,
), ),
# 25 # 25
( (
...@@ -726,7 +732,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -726,7 +732,7 @@ def rand_bool_mask(shape, rng=None):
np.array([[-1, 20], [300, -4000]], dtype=config.floatX), np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6, np.eye(2).astype(config.floatX) * 1e-6,
), ),
(), None,
), ),
( (
# Multivariate distribution: advanced integer indexing # Multivariate distribution: advanced integer indexing
...@@ -740,7 +746,7 @@ def rand_bool_mask(shape, rng=None): ...@@ -740,7 +746,7 @@ def rand_bool_mask(shape, rng=None):
), ),
np.eye(3, dtype=config.floatX) * 1e-6, np.eye(3, dtype=config.floatX) * 1e-6,
), ),
(), None,
), ),
( (
# Multivariate distribution: dummy slice "dips" into core dimension # Multivariate distribution: dummy slice "dips" into core dimension
......
...@@ -212,7 +212,7 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX) ...@@ -212,7 +212,7 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, sd, size", "M, sd, size",
[ [
(pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_pt, ()), (pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), sd_pt, None),
( (
pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)), pt.as_tensor_variable(np.array(1.0, dtype=config.floatX)),
sd_pt, sd_pt,
...@@ -223,10 +223,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX) ...@@ -223,10 +223,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
sd_pt, sd_pt,
(2, M_pt), (2, M_pt),
), ),
(pt.zeros((M_pt,)), sd_pt, ()), (pt.zeros((M_pt,)), sd_pt, None),
(pt.zeros((M_pt,)), sd_pt, (M_pt,)), (pt.zeros((M_pt,)), sd_pt, (M_pt,)),
(pt.zeros((M_pt,)), sd_pt, (2, M_pt)), (pt.zeros((M_pt,)), sd_pt, (2, M_pt)),
(pt.zeros((M_pt,)), pt.ones((M_pt,)), ()), (pt.zeros((M_pt,)), pt.ones((M_pt,)), None),
(pt.zeros((M_pt,)), pt.ones((M_pt,)), (2, M_pt)), (pt.zeros((M_pt,)), pt.ones((M_pt,)), (2, M_pt)),
( (
create_pytensor_param( create_pytensor_param(
...@@ -244,9 +244,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX) ...@@ -244,9 +244,10 @@ sd_pt.tag.test_value = np.array(1.0, dtype=config.floatX)
) )
def test_normal_infer_shape(M, sd, size): def test_normal_infer_shape(M, sd, size):
rv = normal(M, sd, size=size) rv = normal(M, sd, size=size)
rv_shape = list(normal._infer_shape(size or (), [M, sd], None)) size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(normal._infer_shape(size_pt, [M, sd], None))
all_args = (M, sd, *size) all_args = (M, sd, *(() if size is None else size))
fn_inputs = [ fn_inputs = [
i i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)]) for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
...@@ -525,8 +526,8 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None): ...@@ -525,8 +526,8 @@ def mvnormal_test_fn(mean=None, cov=None, size=None, random_state=None):
mean = np.array([0.0], dtype=config.floatX) mean = np.array([0.0], dtype=config.floatX)
if cov is None: if cov is None:
cov = np.array([[1.0]], dtype=config.floatX) cov = np.array([[1.0]], dtype=config.floatX)
if size is None: if size is not None:
size = () size = tuple(size)
return multivariate_normal.rng_fn(random_state, mean, cov, size) return multivariate_normal.rng_fn(random_state, mean, cov, size)
...@@ -713,19 +714,20 @@ M_pt.tag.test_value = 3 ...@@ -713,19 +714,20 @@ M_pt.tag.test_value = 3
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, size", "M, size",
[ [
(pt.ones((M_pt,)), ()), (pt.ones((M_pt,)), None),
(pt.ones((M_pt,)), (M_pt + 1,)), (pt.ones((M_pt,)), (M_pt + 1,)),
(pt.ones((M_pt,)), (2, M_pt)), (pt.ones((M_pt,)), (2, M_pt)),
(pt.ones((M_pt, M_pt + 1)), ()), (pt.ones((M_pt, M_pt + 1)), None),
(pt.ones((M_pt, M_pt + 1)), (M_pt + 2, M_pt)), (pt.ones((M_pt, M_pt + 1)), (M_pt + 2, M_pt)),
(pt.ones((M_pt, M_pt + 1)), (2, M_pt + 2, M_pt + 3, M_pt)), (pt.ones((M_pt, M_pt + 1)), (2, M_pt + 2, M_pt + 3, M_pt)),
], ],
) )
def test_dirichlet_infer_shape(M, size): def test_dirichlet_infer_shape(M, size):
rv = dirichlet(M, size=size) rv = dirichlet(M, size=size)
rv_shape = list(dirichlet._infer_shape(size or (), [M], None)) size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(dirichlet._infer_shape(size_pt, [M], None))
all_args = (M, *size) all_args = (M, *(() if size is None else size))
fn_inputs = [ fn_inputs = [
i i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)]) for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
...@@ -1620,8 +1622,7 @@ def test_unnatural_batched_dims(batch_dims_tester): ...@@ -1620,8 +1622,7 @@ def test_unnatural_batched_dims(batch_dims_tester):
@config.change_flags(compute_test_value="off") @config.change_flags(compute_test_value="off")
def test_pickle(): def test_pickle():
# This is an interesting `Op` case, because it has `None` types and a # This is an interesting `Op` case, because it has a conditional dtype
# conditional dtype
sample_a = choice(5, replace=False, size=(2, 3)) sample_a = choice(5, replace=False, size=(2, 3))
a_pkl = pickle.dumps(sample_a) a_pkl = pickle.dumps(sample_a)
......
...@@ -69,7 +69,7 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -69,7 +69,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# `RandomVariable._infer_shape` should handle no parameters # `RandomVariable._infer_shape` should handle no parameters
rv_shape = rv._infer_shape(pt.constant([]), (), []) rv_shape = rv._infer_shape(pt.constant([]), (), [])
assert rv_shape.equals(pt.constant([], dtype="int64")) assert rv_shape == ()
# `dtype` is respected # `dtype` is respected
rv = RandomVariable("normal", signature="(),()->()", dtype="int32") rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
...@@ -299,3 +299,16 @@ def test_vectorize(): ...@@ -299,3 +299,16 @@ def test_vectorize():
vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner vect_node = vectorize_graph(out, {mu: new_mu, sigma: new_sigma}).owner
assert isinstance(vect_node.op, NormalRV) assert isinstance(vect_node.op, NormalRV)
assert vect_node.default_output().type.shape == (10, 2, 5) assert vect_node.default_output().type.shape == (10, 2, 5)
def test_size_none_vs_empty():
rv = RandomVariable(
"normal",
signature="(),()->()",
)
assert rv([0], [1], size=None).type.shape == (1,)
with pytest.raises(
ValueError, match="Size length is incompatible with batched dimensions"
):
rv([0], [1], size=())
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论