提交 f8e7fe07 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use inferred shapes for RandomVariable size parameter

上级 a24cd432
from collections.abc import Sequence
from copy import copy
from typing import List, Optional, Tuple
import numpy as np
......@@ -10,12 +11,19 @@ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt_utils import optimize_graph
from aesara.misc.safe_asarray import _asarray
from aesara.tensor.basic import as_tensor_variable, constant, get_vector_length
from aesara.scalar import ScalarVariable
from aesara.tensor.basic import (
as_tensor_variable,
constant,
get_scalar_constant_value,
get_vector_length,
)
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple
from aesara.tensor.type import TensorType, all_dtypes
from aesara.tensor.var import TensorVariable
def default_shape_from_params(
......@@ -159,25 +167,26 @@ class RandomVariable(Op):
props_str = ", ".join((f"{getattr(self, prop)}" for prop in self.__props__[1:]))
return f"{self.name}_rv{{{props_str}}}"
def _infer_shape(self, size, dist_params, param_shapes=None):
def _infer_shape(
self,
size: Tuple[TensorVariable],
dist_params: List[TensorVariable],
param_shapes: Optional[List[Tuple[TensorVariable]]] = None,
) -> Tuple[ScalarVariable]:
"""Compute the output shape given the size and distribution parameters.
Parameters
----------
size : TensorVariable
size
The size parameter specified for this `RandomVariable`.
dist_params : list of TensorVariable
dist_params
The symbolic parameter for this `RandomVariable`'s distribution.
param_shapes : list of tuples of TensorVariable (optional)
param_shapes
The shapes of the `dist_params` as given by `ShapeFeature`'s
via `Op.infer_shape`'s `input_shapes` argument. This parameter's
values are essentially more accurate versions of ``[d.shape for d
in dist_params]``.
Outputs
-------
shape : tuple of `ScalarVariable`
"""
size_len = get_vector_length(size)
......@@ -294,7 +303,14 @@ class RandomVariable(Op):
def infer_shape(self, fgraph, node, input_shapes):
_, size, _, *dist_params = node.inputs
_, _, _, *param_shapes = input_shapes
_, size_shape, _, *param_shapes = input_shapes
try:
size_len = get_vector_length(size)
except ValueError:
size_len = get_scalar_constant_value(size_shape[0])
size = tuple(size[n] for n in range(size_len))
shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论