提交 ecd9c3b8 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify RandomVariable._infer_shape

上级 b2b7e287
...@@ -191,6 +191,8 @@ class RandomVariable(Op): ...@@ -191,6 +191,8 @@ class RandomVariable(Op):
""" """
from pytensor.tensor.extra_ops import broadcast_shape_iter
size_len = get_vector_length(size) size_len = get_vector_length(size)
if size_len > 0: if size_len > 0:
...@@ -216,57 +218,52 @@ class RandomVariable(Op): ...@@ -216,57 +218,52 @@ class RandomVariable(Op):
# Broadcast the parameters # Broadcast the parameters
param_shapes = params_broadcast_shapes( param_shapes = params_broadcast_shapes(
param_shapes or [shape_tuple(p) for p in dist_params], self.ndims_params param_shapes or [shape_tuple(p) for p in dist_params],
self.ndims_params,
) )
def slice_ind_dims(p, ps, n): def extract_batch_shape(p, ps, n):
shape = tuple(ps) shape = tuple(ps)
if n == 0: if n == 0:
return (p, shape) return shape
ind_slice = (slice(None),) * (p.ndim - n) + (0,) * n batch_shape = [
ind_shape = [
s if b is False else constant(1, "int64") s if b is False else constant(1, "int64")
for s, b in zip(shape[:-n], p.broadcastable[:-n]) for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
] ]
return ( return batch_shape
p[ind_slice],
ind_shape,
)
# These are versions of our actual parameters with the anticipated # These are versions of our actual parameters with the anticipated
# dimensions (i.e. support dimensions) removed so that only the # dimensions (i.e. support dimensions) removed so that only the
# independent variate dimensions are left. # independent variate dimensions are left.
params_ind_slice = tuple( params_batch_shape = tuple(
slice_ind_dims(p, ps, n) extract_batch_shape(p, ps, n)
for p, ps, n in zip(dist_params, param_shapes, self.ndims_params) for p, ps, n in zip(dist_params, param_shapes, self.ndims_params)
) )
if len(params_ind_slice) == 1: if len(params_batch_shape) == 1:
_, shape_ind = params_ind_slice[0] [batch_shape] = params_batch_shape
elif len(params_ind_slice) > 1: elif len(params_batch_shape) > 1:
# If there are multiple parameters, the dimensions of their # If there are multiple parameters, the dimensions of their
# independent variates should broadcast together. # independent variates should broadcast together.
p_slices, p_shapes = zip(*params_ind_slice) batch_shape = broadcast_shape_iter(
params_batch_shape,
shape_ind = pytensor.tensor.extra_ops.broadcast_shape_iter( arrays_are_shapes=True,
p_shapes, arrays_are_shapes=True
) )
else: else:
# Distribution has no parameters # Distribution has no parameters
shape_ind = () batch_shape = ()
if self.ndim_supp == 0: if self.ndim_supp == 0:
shape_supp = () supp_shape = ()
else: else:
shape_supp = self._supp_shape_from_params( supp_shape = self._supp_shape_from_params(
dist_params, dist_params,
param_shapes=param_shapes, param_shapes=param_shapes,
) )
shape = tuple(shape_ind) + tuple(shape_supp) shape = tuple(batch_shape) + tuple(supp_shape)
if not shape: if not shape:
shape = constant([], dtype="int64") shape = constant([], dtype="int64")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论