提交 0bfe6d22 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix PermutationRV ambiguous signature

The RV always expects a vector input and `ndims_paramas` is always `[1]`. Size is no longer ignored
上级 a8f952dc
...@@ -5,7 +5,7 @@ import numpy as np ...@@ -5,7 +5,7 @@ import numpy as np
import scipy.stats as stats import scipy.stats as stats
import pytensor import pytensor
from pytensor.tensor.basic import as_tensor_variable, arange from pytensor.tensor.basic import arange, as_tensor_variable
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
...@@ -2072,18 +2072,15 @@ class PermutationRV(RandomVariable): ...@@ -2072,18 +2072,15 @@ class PermutationRV(RandomVariable):
@classmethod @classmethod
def rng_fn(cls, rng, x, size): def rng_fn(cls, rng, x, size):
return rng.permutation(x if x.ndim > 0 else x.item()) return rng.permutation(x)
def _infer_shape(self, size, dist_params, param_shapes=None): def _supp_shape_from_params(self, dist_params, param_shapes=None):
param_shapes = param_shapes or [p.shape for p in dist_params] return supp_shape_from_ref_param_shape(
ndim_supp=self.ndim_supp,
(x,) = dist_params dist_params=dist_params,
(x_shape,) = param_shapes param_shapes=param_shapes,
ref_param_idx=0,
if x.ndim == 0: )
return (x,)
else:
return x_shape
def __call__(self, x, **kwargs): def __call__(self, x, **kwargs):
r"""Randomly permute a sequence or a range of values. r"""Randomly permute a sequence or a range of values.
...@@ -2096,15 +2093,35 @@ class PermutationRV(RandomVariable): ...@@ -2096,15 +2093,35 @@ class PermutationRV(RandomVariable):
Parameters Parameters
---------- ----------
x x
If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence, Elements to be shuffled.
shuffle its elements randomly.
""" """
x = as_tensor_variable(x) x = as_tensor_variable(x)
return super().__call__(x, dtype=x.dtype, **kwargs) return super().__call__(x, dtype=x.dtype, **kwargs)
permutation = PermutationRV() _permutation = PermutationRV()
def permutation(x, **kwargs):
r"""Randomly permute a sequence or a range of values.
Signature
---------
`(x) -> (x)`
Parameters
----------
x
If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence,
shuffle its elements randomly.
"""
x = as_tensor_variable(x)
if x.type.ndim == 0:
x = arange(x)
return _permutation(x, **kwargs)
__all__ = [ __all__ = [
......
...@@ -1413,6 +1413,14 @@ def test_permutation_samples(): ...@@ -1413,6 +1413,14 @@ def test_permutation_samples():
compare_sample_values(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX)) compare_sample_values(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX))
def test_permutation_shape():
assert tuple(permutation(5).shape.eval()) == (5,)
assert tuple(permutation(np.arange(5)).shape.eval()) == (5,)
assert tuple(permutation(np.arange(10).reshape(2, 5)).shape.eval()) == (2, 5)
assert tuple(permutation(5, size=(2, 3)).shape.eval()) == (2, 3, 5)
assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5)
@config.change_flags(compute_test_value="off") @config.change_flags(compute_test_value="off")
def test_pickle(): def test_pickle():
# This is an interesting `Op` case, because it has `None` types and a # This is an interesting `Op` case, because it has `None` types and a
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论