提交 80bfde1e authored 作者: Ricardo's avatar Ricardo 提交者: Ricardo Vieira

Simplify `RandomVariable._infer_shape`

上级 265c0d94
......@@ -232,9 +232,7 @@ class RandomVariable(Op):
)
if len(params_ind_slice) == 1:
ind_param, ind_shape = params_ind_slice[0]
ndim_ind = len(ind_shape)
shape_ind = ind_shape
_, shape_ind = params_ind_slice[0]
elif len(params_ind_slice) > 1:
# If there are multiple parameters, the dimensions of their
# independent variates should broadcast together.
......@@ -244,36 +242,21 @@ class RandomVariable(Op):
p_shapes, arrays_are_shapes=True
)
ndim_ind = len(shape_ind)
else:
ndim_ind = 0
# Distribution has no parameters
shape_ind = ()
if self.ndim_supp == 0:
shape_supp = tuple()
shape_reps = tuple(size)
if ndim_ind > 0:
shape_reps = shape_reps[:-ndim_ind]
ndim_reps = len(shape_reps)
shape_supp = ()
else:
shape_supp = self._supp_shape_from_params(
dist_params,
param_shapes=param_shapes,
)
ndim_reps = size_len
shape_reps = size
ndim_shape = self.ndim_supp + ndim_ind + ndim_reps
if ndim_shape == 0:
shape = tuple(shape_ind) + tuple(shape_supp)
if not shape:
shape = constant([], dtype="int64")
else:
shape = tuple(shape_reps) + tuple(shape_ind) + tuple(shape_supp)
# if shape is None:
# raise ShapeError()
return shape
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论