提交 c38eea06 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix shape inference for multivariate random Ops

When `size` is not provided, the batch shapes of the parameters were broadcast twice. The second broadcast was incorrect because it mixed the static shapes of the original parameters with the (potentially larger) shapes of the already-broadcast parameters.
上级 205da7f9
......@@ -20,11 +20,7 @@ from pytensor.tensor.basic import (
infer_static_shape,
)
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import (
broadcast_params,
normalize_size_param,
params_broadcast_shapes,
)
from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst
......@@ -156,6 +152,13 @@ class RandomVariable(Op):
from pytensor.tensor.extra_ops import broadcast_shape_iter
if self.ndim_supp == 0:
supp_shape = ()
else:
supp_shape = tuple(
self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
)
size_len = get_vector_length(size)
if size_len > 0:
......@@ -171,19 +174,11 @@ class RandomVariable(Op):
f"Size length must be 0 or >= {param_batched_dims}"
)
if self.ndim_supp == 0:
return size
else:
supp_shape = self._supp_shape_from_params(
dist_params, param_shapes=param_shapes
)
return tuple(size) + tuple(supp_shape)
return tuple(size) + supp_shape
# Broadcast the parameters
param_shapes = params_broadcast_shapes(
param_shapes or [shape_tuple(p) for p in dist_params],
self.ndims_params,
)
# Size was not provided, we must infer it from the shape of the parameters
if param_shapes is None:
param_shapes = [shape_tuple(p) for p in dist_params]
def extract_batch_shape(p, ps, n):
shape = tuple(ps)
......@@ -191,10 +186,10 @@ class RandomVariable(Op):
if n == 0:
return shape
batch_shape = [
batch_shape = tuple(
s if not b else constant(1, "int64")
for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
]
)
return batch_shape
# These are versions of our actual parameters with the anticipated
......@@ -218,15 +213,8 @@ class RandomVariable(Op):
# Distribution has no parameters
batch_shape = ()
if self.ndim_supp == 0:
supp_shape = ()
else:
supp_shape = self._supp_shape_from_params(
dist_params,
param_shapes=param_shapes,
)
shape = batch_shape + supp_shape
shape = tuple(batch_shape) + tuple(supp_shape)
if not shape:
shape = constant([], dtype="int64")
......
......@@ -206,6 +206,42 @@ def test_RandomVariable_incompatible_size():
rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
class MultivariateRandomVariable(RandomVariable):
    """Minimal multivariate random Op used to exercise static shape inference.

    Declares a vector support (``ndim_supp = 1``) with a vector first
    parameter and a matrix second parameter (``ndims_params = (1, 2)``).
    """

    name = "MultivariateRandomVariable"
    ndim_supp = 1
    ndims_params = (1, 2)
    dtype = "floatX"

    def _supp_shape_from_params(self, dist_params, param_shapes=None):
        # The support shape is the trailing dimension of the first parameter.
        first_param = dist_params[0]
        return [first_param.shape[-1]]
@config.change_flags(compute_test_value="off")
def test_multivariate_rv_infer_static_shape():
    """Test that infer shape for multivariate random variable works when a parameter must be broadcasted."""
    mv_op = MultivariateRandomVariable()

    # (shape of param1, shape of param2, size kwarg, expected static shape)
    cases = [
        ((10, 2, 3), (10, 2, 3, 3), None, (10, 2, 3)),
        ((2, 3), (10, 2, 3, 3), None, (10, 2, 3)),
        ((10, 2, 3), (2, 3, 3), None, (10, 2, 3)),
        ((10, 1, 3), (2, 3, 3), None, (10, 2, 3)),
        ((2, 3), (2, 3, 3), (10, 2), (10, 2, 3)),
    ]
    for shape1, shape2, size, expected in cases:
        first = tensor(shape=shape1)
        second = tensor(shape=shape2)
        if size is None:
            rv = mv_op(first, second)
        else:
            rv = mv_op(first, second, size=size)
        assert rv.type.shape == expected
def test_vectorize_node():
vec = tensor(shape=(None,))
vec.tag.test_value = [0, 0, 0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论