提交 caa580bb authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix broadcasting bug in vectorize of RandomVariables

上级 c9f5f656
...@@ -20,6 +20,7 @@ from pytensor.tensor.basic import ( ...@@ -20,6 +20,7 @@ from pytensor.tensor.basic import (
) )
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import ( from pytensor.tensor.random.utils import (
compute_batch_shape,
explicit_expand_dims, explicit_expand_dims,
normalize_size_param, normalize_size_param,
) )
...@@ -403,15 +404,14 @@ def vectorize_random_variable( ...@@ -403,15 +404,14 @@ def vectorize_random_variable(
original_expanded_dist_params, dict(zip(original_dist_params, dist_params)) original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
) )
if len_old_size and equal_computations([old_size], [size]): new_ndim = dist_params[0].type.ndim - original_expanded_dist_params[0].type.ndim
if new_ndim and len_old_size and equal_computations([old_size], [size]):
# If the original RV had a size variable and a new one has not been provided, # If the original RV had a size variable and a new one has not been provided,
# we need to define a new size as the concatenation of the original size dimensions # we need to define a new size as the concatenation of the original size dimensions
# and the novel ones implied by new broadcasted batched parameters dimensions. # and the novel ones implied by new broadcasted batched parameters dimensions.
# We use the first broadcasted batch dimension for reference. broadcasted_batch_shape = compute_batch_shape(dist_params, op.ndims_params)
bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0] new_size_dims = broadcasted_batch_shape[:new_ndim]
new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size size = concatenate([new_size_dims, size])
if new_param_ndim >= 0:
new_size_dims = bcasted_param.shape[:new_param_ndim]
size = concatenate([new_size_dims, size])
return op.make_node(rng, size, dtype, *dist_params) return op.make_node(rng, size, dtype, *dist_params)
...@@ -11,7 +11,7 @@ from pytensor.graph.basic import Constant, Variable ...@@ -11,7 +11,7 @@ from pytensor.graph.basic import Constant, Variable
from pytensor.scalar import ScalarVariable from pytensor.scalar import ScalarVariable
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant from pytensor.tensor.basic import as_tensor_variable, cast, constant
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
from pytensor.tensor.shape import shape_padleft, specify_shape from pytensor.tensor.shape import shape_padleft, specify_shape
from pytensor.tensor.type import int_dtypes from pytensor.tensor.type import int_dtypes
...@@ -149,6 +149,15 @@ def explicit_expand_dims( ...@@ -149,6 +149,15 @@ def explicit_expand_dims(
return new_params return new_params
def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
    """Return the symbolic broadcasted batch shape implied by ``params``.

    Each parameter is first aligned via ``explicit_expand_dims`` so that all
    batch dimensions are explicit, then its core dimensions (the trailing
    ``ndims_params[i]`` axes) are collapsed by indexing a single entry along
    each of them. Broadcasting the resulting batch-only parameters against
    one another yields the common batch shape.
    """
    aligned = explicit_expand_dims(params, ndims_params)
    batch_only = []
    for param, core_ndim in zip(aligned, ndims_params):
        # Select index 0 along every core axis, keeping only batch axes.
        batch_only.append(param[(..., *(0,) * core_ndim)])
    return broadcast_arrays(*batch_only)[0].shape
def normalize_size_param( def normalize_size_param(
size: int | np.ndarray | Variable | Sequence | None, size: int | np.ndarray | Variable | Sequence | None,
) -> Variable: ) -> Variable:
......
...@@ -292,6 +292,14 @@ def test_vectorize_node(): ...@@ -292,6 +292,14 @@ def test_vectorize_node():
assert vect_node.op is normal assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5) assert vect_node.default_output().type.shape == (10, 5)
node = normal(vec, size=(5,)).owner
new_inputs = node.inputs.copy()
new_inputs[3] = tensor("mu", shape=(1, 5)) # mu
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5)
# Test parameter broadcasting with expanding size # Test parameter broadcasting with expanding size
node = normal(vec, size=(2, 5)).owner node = normal(vec, size=(2, 5)).owner
new_inputs = node.inputs.copy() new_inputs = node.inputs.copy()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论