提交 17ba075e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Extract RandomVariable size parameter handling to a separate utility function

上级 ccda3ccb
...@@ -11,15 +11,14 @@ from aesara.graph.op import Op ...@@ -11,15 +11,14 @@ from aesara.graph.op import Op
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor.basic import ( from aesara.tensor.basic import (
as_tensor_variable, as_tensor_variable,
cast,
constant, constant,
get_scalar_constant_value, get_scalar_constant_value,
get_vector_length, get_vector_length,
) )
from aesara.tensor.exceptions import NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.random.type import RandomStateType from aesara.tensor.random.type import RandomStateType
from aesara.tensor.random.utils import params_broadcast_shapes from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.type import TensorType, all_dtypes, int_dtypes from aesara.tensor.type import TensorType, all_dtypes
from aesara.tensor.type_other import NoneConst from aesara.tensor.type_other import NoneConst
...@@ -348,18 +347,7 @@ class RandomVariable(Op): ...@@ -348,18 +347,7 @@ class RandomVariable(Op):
`(rng_var, out_var)`. `(rng_var, out_var)`.
""" """
if size is None: size = normalize_size_param(size)
size = constant([], dtype="int64")
elif isinstance(size, int):
size = as_tensor_variable([size], ndim=1)
elif not isinstance(size, (np.ndarray, Variable, Sequence)):
raise TypeError(
"Parameter size must be None, an integer, or a sequence with integers."
)
else:
size = cast(as_tensor_variable(size, ndim=1), "int64")
assert size.dtype in int_dtypes
dist_params = tuple( dist_params = tuple(
as_tensor_variable(p) if not isinstance(p, Variable) else p as_tensor_variable(p) if not isinstance(p, Variable) else p
......
from collections.abc import Sequence
from functools import wraps from functools import wraps
from itertools import zip_longest from itertools import zip_longest
...@@ -5,8 +6,10 @@ import numpy as np ...@@ -5,8 +6,10 @@ import numpy as np
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Variable from aesara.graph.basic import Variable
from aesara.tensor.basic import as_tensor_variable, cast, constant
from aesara.tensor.extra_ops import broadcast_to from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import maximum from aesara.tensor.math import maximum
from aesara.tensor.type import int_dtypes
def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True): def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
...@@ -95,6 +98,24 @@ def broadcast_params(params, ndims_params): ...@@ -95,6 +98,24 @@ def broadcast_params(params, ndims_params):
return bcast_params return bcast_params
def normalize_size_param(size):
"""Create an Aesara value for a ``RandomVariable`` ``size`` parameter."""
if size is None:
size = constant([], dtype="int64")
elif isinstance(size, int):
size = as_tensor_variable([size], ndim=1)
elif not isinstance(size, (np.ndarray, Variable, Sequence)):
raise TypeError(
"Parameter size must be None, an integer, or a sequence with integers."
)
else:
size = cast(as_tensor_variable(size, ndim=1), "int64")
assert size.dtype in int_dtypes
return size
class RandomStream: class RandomStream:
"""Module component with similar interface to `numpy.random.RandomState`. """Module component with similar interface to `numpy.random.RandomState`.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论