提交 7e351b93 authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Rename `*shape_from_params` to `*supp_shape_from_params` for clarity

上级 e30d7b90
......@@ -6,7 +6,7 @@ import scipy.stats as stats
import aesara
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from aesara.tensor.random.op import RandomVariable, default_supp_shape_from_params
from aesara.tensor.random.type import RandomGeneratorType, RandomStateType
from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.random.var import (
......@@ -591,8 +591,8 @@ class MultinomialRV(RandomVariable):
dtype = "int64"
_print_name = ("MN", "\\operatorname{MN}")
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_shape_from_params(
def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_supp_shape_from_params(
self.ndim_supp, dist_params, rep_param_idx, param_shapes
)
......@@ -713,7 +713,7 @@ class ChoiceRV(RandomVariable):
def rng_fn(cls, rng, a, p, replace, size):
return rng.choice(a, size, replace, p)
def _shape_from_params(self, *args, **kwargs):
def _supp_shape_from_params(self, *args, **kwargs):
raise NotImplementedError()
def _infer_shape(self, size, dist_params, param_shapes=None):
......
......@@ -24,7 +24,7 @@ from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable
def default_shape_from_params(
def default_supp_shape_from_params(
ndim_supp: int,
dist_params: Sequence[Variable],
rep_param_idx: Optional[int] = 0,
......@@ -151,14 +151,15 @@ class RandomVariable(Op):
if self.inplace:
self.destroy_map = {0: [0]}
def _shape_from_params(self, dist_params, **kwargs):
"""Determine the shape of a `RandomVariable`'s output given its parameters.
def _supp_shape_from_params(self, dist_params, **kwargs):
"""Determine the support shape of a `RandomVariable`'s output given its parameters.
This does *not* consider the extra dimensions added by the `size` parameter.
This does *not* consider the extra dimensions added by the `size` parameter
or independent (batched) parameters.
Defaults to `param_supp_shape_fn`.
"""
return default_shape_from_params(self.ndim_supp, dist_params, **kwargs)
return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs)
def rng_fn(self, rng, *args, **kwargs):
"""Sample a numeric random variate."""
......@@ -196,7 +197,7 @@ class RandomVariable(Op):
if self.ndim_supp == 0:
return size
else:
supp_shape = self._shape_from_params(
supp_shape = self._supp_shape_from_params(
dist_params, param_shapes=param_shapes
)
return tuple(size) + tuple(supp_shape)
......@@ -256,7 +257,7 @@ class RandomVariable(Op):
ndim_reps = len(shape_reps)
else:
shape_supp = self._shape_from_params(
shape_supp = self._supp_shape_from_params(
dist_params,
param_shapes=param_shapes,
)
......
......@@ -1227,7 +1227,7 @@ def test_integers_samples():
def test_choice_samples():
with pytest.raises(NotImplementedError):
choice._shape_from_params(np.asarray(5))
choice._supp_shape_from_params(np.asarray(5))
rv_numpy_tester(choice, np.asarray([5]))
rv_numpy_tester(choice, np.array([1.0, 5.0], dtype=config.floatX))
......
......@@ -10,7 +10,7 @@ from aesara.tensor.random.op import (
RandomState,
RandomVariable,
default_rng,
default_shape_from_params,
default_supp_shape_from_params,
)
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import all_dtypes, iscalar, tensor
......@@ -22,20 +22,24 @@ def set_aesara_flags():
yield
def test_default_shape_from_params():
def test_default_supp_shape_from_params():
with pytest.raises(ValueError, match="^ndim_supp*"):
default_shape_from_params(0, (np.array([1, 2]), 0))
default_supp_shape_from_params(0, (np.array([1, 2]), 0))
res = default_shape_from_params(1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0)
res = default_supp_shape_from_params(
1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0
)
assert res == (2,)
res = default_shape_from_params(1, (np.array([1, 2]), 0), param_shapes=((2,), ()))
res = default_supp_shape_from_params(
1, (np.array([1, 2]), 0), param_shapes=((2,), ())
)
assert res == (2,)
with pytest.raises(ValueError, match="^Reference parameter*"):
default_shape_from_params(1, (np.array(1),), rep_param_idx=0)
default_supp_shape_from_params(1, (np.array(1),), rep_param_idx=0)
res = default_shape_from_params(
res = default_supp_shape_from_params(
2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1
)
assert res == (3, 4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论