提交 7e351b93 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Rename `*shape_from_params` to `*supp_shape_from_params` for clarity

上级 e30d7b90
...@@ -6,7 +6,7 @@ import scipy.stats as stats ...@@ -6,7 +6,7 @@ import scipy.stats as stats
import aesara import aesara
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.random.op import RandomVariable, default_shape_from_params from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
from aesara.tensor.random.type import RandomGeneratorType, RandomStateType from aesara.tensor.random.type import RandomGeneratorType, RandomStateType
from aesara.tensor.random.utils import broadcast_params from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.random.var import ( from aesara.tensor.random.var import (
...@@ -591,8 +591,8 @@ class MultinomialRV(RandomVariable): ...@@ -591,8 +591,8 @@ class MultinomialRV(RandomVariable):
dtype = "int64" dtype = "int64"
_print_name = ("MN", "\\operatorname{MN}") _print_name = ("MN", "\\operatorname{MN}")
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None): def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_shape_from_params( return default_supp_shape_from_params(
self.ndim_supp, dist_params, rep_param_idx, param_shapes self.ndim_supp, dist_params, rep_param_idx, param_shapes
) )
...@@ -713,7 +713,7 @@ class ChoiceRV(RandomVariable): ...@@ -713,7 +713,7 @@ class ChoiceRV(RandomVariable):
def rng_fn(cls, rng, a, p, replace, size): def rng_fn(cls, rng, a, p, replace, size):
return rng.choice(a, size, replace, p) return rng.choice(a, size, replace, p)
def _shape_from_params(self, *args, **kwargs): def _supp_shape_from_params(self, *args, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def _infer_shape(self, size, dist_params, param_shapes=None): def _infer_shape(self, size, dist_params, param_shapes=None):
......
...@@ -24,7 +24,7 @@ from aesara.tensor.type_other import NoneConst ...@@ -24,7 +24,7 @@ from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable from aesara.tensor.var import TensorVariable
def default_shape_from_params( def default_supp_shape_from_params(
ndim_supp: int, ndim_supp: int,
dist_params: Sequence[Variable], dist_params: Sequence[Variable],
rep_param_idx: Optional[int] = 0, rep_param_idx: Optional[int] = 0,
...@@ -151,14 +151,15 @@ class RandomVariable(Op): ...@@ -151,14 +151,15 @@ class RandomVariable(Op):
if self.inplace: if self.inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
def _shape_from_params(self, dist_params, **kwargs): def _supp_shape_from_params(self, dist_params, **kwargs):
"""Determine the shape of a `RandomVariable`'s output given its parameters. """Determine the support shape of a `RandomVariable`'s output given its parameters.
This does *not* consider the extra dimensions added by the `size` parameter. This does *not* consider the extra dimensions added by the `size` parameter
or independent (batched) parameters.
Defaults to `param_supp_shape_fn`. Defaults to `param_supp_shape_fn`.
""" """
return default_shape_from_params(self.ndim_supp, dist_params, **kwargs) return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs)
def rng_fn(self, rng, *args, **kwargs): def rng_fn(self, rng, *args, **kwargs):
"""Sample a numeric random variate.""" """Sample a numeric random variate."""
...@@ -196,7 +197,7 @@ class RandomVariable(Op): ...@@ -196,7 +197,7 @@ class RandomVariable(Op):
if self.ndim_supp == 0: if self.ndim_supp == 0:
return size return size
else: else:
supp_shape = self._shape_from_params( supp_shape = self._supp_shape_from_params(
dist_params, param_shapes=param_shapes dist_params, param_shapes=param_shapes
) )
return tuple(size) + tuple(supp_shape) return tuple(size) + tuple(supp_shape)
...@@ -256,7 +257,7 @@ class RandomVariable(Op): ...@@ -256,7 +257,7 @@ class RandomVariable(Op):
ndim_reps = len(shape_reps) ndim_reps = len(shape_reps)
else: else:
shape_supp = self._shape_from_params( shape_supp = self._supp_shape_from_params(
dist_params, dist_params,
param_shapes=param_shapes, param_shapes=param_shapes,
) )
......
...@@ -1227,7 +1227,7 @@ def test_integers_samples(): ...@@ -1227,7 +1227,7 @@ def test_integers_samples():
def test_choice_samples(): def test_choice_samples():
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
choice._shape_from_params(np.asarray(5)) choice._supp_shape_from_params(np.asarray(5))
rv_numpy_tester(choice, np.asarray([5])) rv_numpy_tester(choice, np.asarray([5]))
rv_numpy_tester(choice, np.array([1.0, 5.0], dtype=config.floatX)) rv_numpy_tester(choice, np.array([1.0, 5.0], dtype=config.floatX))
......
...@@ -10,7 +10,7 @@ from aesara.tensor.random.op import ( ...@@ -10,7 +10,7 @@ from aesara.tensor.random.op import (
RandomState, RandomState,
RandomVariable, RandomVariable,
default_rng, default_rng,
default_shape_from_params, default_supp_shape_from_params,
) )
from aesara.tensor.shape import specify_shape from aesara.tensor.shape import specify_shape
from aesara.tensor.type import all_dtypes, iscalar, tensor from aesara.tensor.type import all_dtypes, iscalar, tensor
...@@ -22,20 +22,24 @@ def set_aesara_flags(): ...@@ -22,20 +22,24 @@ def set_aesara_flags():
yield yield
def test_default_shape_from_params(): def test_default_supp_shape_from_params():
with pytest.raises(ValueError, match="^ndim_supp*"): with pytest.raises(ValueError, match="^ndim_supp*"):
default_shape_from_params(0, (np.array([1, 2]), 0)) default_supp_shape_from_params(0, (np.array([1, 2]), 0))
res = default_shape_from_params(1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0) res = default_supp_shape_from_params(
1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0
)
assert res == (2,) assert res == (2,)
res = default_shape_from_params(1, (np.array([1, 2]), 0), param_shapes=((2,), ())) res = default_supp_shape_from_params(
1, (np.array([1, 2]), 0), param_shapes=((2,), ())
)
assert res == (2,) assert res == (2,)
with pytest.raises(ValueError, match="^Reference parameter*"): with pytest.raises(ValueError, match="^Reference parameter*"):
default_shape_from_params(1, (np.array(1),), rep_param_idx=0) default_supp_shape_from_params(1, (np.array(1),), rep_param_idx=0)
res = default_shape_from_params( res = default_supp_shape_from_params(
2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1 2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1
) )
assert res == (3, 4) assert res == (3, 4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论