提交 ee53f7c9 authored 作者: danhphan's avatar danhphan 提交者: Brandon T. Willard

Add type hints for default_shape_from_params

上级 4e2fc615
......@@ -26,8 +26,11 @@ from aesara.tensor.var import TensorVariable
def default_shape_from_params(
ndim_supp, dist_params, rep_param_idx=0, param_shapes=None
):
ndim_supp: int,
dist_params: Sequence[Variable],
rep_param_idx: Optional[int] = 0,
param_shapes: Optional[Sequence[Tuple[ScalarVariable]]] = None,
) -> Tuple[int, ...]:
"""Infer the dimensions for the output of a `RandomVariable`.
This is a function that derives a random variable's support
......@@ -50,14 +53,14 @@ def default_shape_from_params(
(e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
dist_params: list of `aesara.graph.basic.Variable`
The distribution parameters.
param_shapes: list of tuple of `ScalarVariable` (optional)
Symbolic shapes for each distribution parameter. These will
be used in place of distribution parameter-generated shapes.
rep_param_idx: int (optional)
The index of the distribution parameter to use as a reference
In other words, a parameter in `dist_param` with a shape corresponding
to the support's shape.
The default is the first parameter (i.e. the value 0).
param_shapes: list of tuple of `ScalarVariable` (optional)
Symbolic shapes for each distribution parameter. These will
be used in place of distribution parameter-generated shapes.
Results
-------
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论