提交 c38eea06 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix shape inference for multivariate random Ops

When size is not provided, the batch shapes of the parameters were being broadcast twice. The second broadcast was wrong because it mixed the static shapes of the original parameters with the potentially larger shapes of the just-broadcast parameters.
上级 205da7f9
...@@ -20,11 +20,7 @@ from pytensor.tensor.basic import ( ...@@ -20,11 +20,7 @@ from pytensor.tensor.basic import (
infer_static_shape, infer_static_shape,
) )
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
broadcast_params,
normalize_size_param,
params_broadcast_shapes,
)
from pytensor.tensor.shape import shape_tuple from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst
...@@ -156,6 +152,13 @@ class RandomVariable(Op): ...@@ -156,6 +152,13 @@ class RandomVariable(Op):
from pytensor.tensor.extra_ops import broadcast_shape_iter from pytensor.tensor.extra_ops import broadcast_shape_iter
if self.ndim_supp == 0:
supp_shape = ()
else:
supp_shape = tuple(
self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
)
size_len = get_vector_length(size) size_len = get_vector_length(size)
if size_len > 0: if size_len > 0:
...@@ -171,19 +174,11 @@ class RandomVariable(Op): ...@@ -171,19 +174,11 @@ class RandomVariable(Op):
f"Size length must be 0 or >= {param_batched_dims}" f"Size length must be 0 or >= {param_batched_dims}"
) )
if self.ndim_supp == 0: return tuple(size) + supp_shape
return size
else:
supp_shape = self._supp_shape_from_params(
dist_params, param_shapes=param_shapes
)
return tuple(size) + tuple(supp_shape)
# Broadcast the parameters # Size was not provided, we must infer it from the shape of the parameters
param_shapes = params_broadcast_shapes( if param_shapes is None:
param_shapes or [shape_tuple(p) for p in dist_params], param_shapes = [shape_tuple(p) for p in dist_params]
self.ndims_params,
)
def extract_batch_shape(p, ps, n): def extract_batch_shape(p, ps, n):
shape = tuple(ps) shape = tuple(ps)
...@@ -191,10 +186,10 @@ class RandomVariable(Op): ...@@ -191,10 +186,10 @@ class RandomVariable(Op):
if n == 0: if n == 0:
return shape return shape
batch_shape = [ batch_shape = tuple(
s if not b else constant(1, "int64") s if not b else constant(1, "int64")
for s, b in zip(shape[:-n], p.type.broadcastable[:-n]) for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
] )
return batch_shape return batch_shape
# These are versions of our actual parameters with the anticipated # These are versions of our actual parameters with the anticipated
...@@ -218,15 +213,8 @@ class RandomVariable(Op): ...@@ -218,15 +213,8 @@ class RandomVariable(Op):
# Distribution has no parameters # Distribution has no parameters
batch_shape = () batch_shape = ()
if self.ndim_supp == 0: shape = batch_shape + supp_shape
supp_shape = ()
else:
supp_shape = self._supp_shape_from_params(
dist_params,
param_shapes=param_shapes,
)
shape = tuple(batch_shape) + tuple(supp_shape)
if not shape: if not shape:
shape = constant([], dtype="int64") shape = constant([], dtype="int64")
......
...@@ -206,6 +206,42 @@ def test_RandomVariable_incompatible_size(): ...@@ -206,6 +206,42 @@ def test_RandomVariable_incompatible_size():
rv_op(np.zeros((2, 4, 3)), 1, size=(4,)) rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
class MultivariateRandomVariable(RandomVariable):
    """Minimal multivariate ``RandomVariable`` subclass used by the shape tests.

    It sets ``ndim_supp = 1`` and ``ndims_params = (1, 2)``, and takes its
    support shape from the first distribution parameter.
    """

    name = "MultivariateRandomVariable"
    ndim_supp = 1
    ndims_params = (1, 2)
    dtype = "floatX"

    def _supp_shape_from_params(self, dist_params, param_shapes=None):
        """Return the support shape as a one-element list.

        The single entry is the size of the last dimension of the first
        parameter. ``param_shapes`` is accepted for interface compatibility
        but is not consulted here.
        """
        first_param = dist_params[0]
        support_length = first_param.shape[-1]
        return [support_length]
@config.change_flags(compute_test_value="off")
def test_multivariate_rv_infer_static_shape():
    """Check the static output shape of a multivariate RV whose parameters need broadcasting."""
    mv_op = MultivariateRandomVariable()
    expected_shape = (10, 2, 3)

    # Each entry is (static shape of the first parameter, static shape of
    # the second parameter). No explicit size is passed for these cases, so
    # the output shape has to be inferred from the parameters alone.
    shape_pairs = [
        ((10, 2, 3), (10, 2, 3, 3)),
        ((2, 3), (10, 2, 3, 3)),
        ((10, 2, 3), (2, 3, 3)),
        ((10, 1, 3), (2, 3, 3)),
    ]
    for first_shape, second_shape in shape_pairs:
        rv = mv_op(tensor(shape=first_shape), tensor(shape=second_shape))
        assert rv.type.shape == expected_shape

    # Same expected result when the leading dimensions come from an explicit size.
    rv = mv_op(tensor(shape=(2, 3)), tensor(shape=(2, 3, 3)), size=(10, 2))
    assert rv.type.shape == expected_shape
def test_vectorize_node(): def test_vectorize_node():
vec = tensor(shape=(None,)) vec = tensor(shape=(None,))
vec.tag.test_value = [0, 0, 0] vec.tag.test_value = [0, 0, 0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论