提交 61c15af3 作者: Ricardo Vieira 提交者: Ricardo Vieira

Handle implicit broadcasting correctly in RandomVariable vectorization

上级 e8273115
......@@ -8,7 +8,7 @@ import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node
from pytensor.graph.replace import _vectorize_node, vectorize_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import (
......@@ -20,7 +20,10 @@ from pytensor.tensor.basic import (
infer_static_shape,
)
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from pytensor.tensor.random.utils import broadcast_params, normalize_size_param
from pytensor.tensor.random.utils import (
explicit_expand_dims,
normalize_size_param,
)
from pytensor.tensor.shape import shape_tuple
from pytensor.tensor.type import TensorType, all_dtypes
from pytensor.tensor.type_other import NoneConst
......@@ -387,10 +390,26 @@ def vectorize_random_variable(
# If size was provided originally and a new size hasn't been provided,
# We extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values
# Need to make parameters implicit broadcasting explicit
original_dist_params = node.inputs[3:]
old_size = node.inputs[1]
len_old_size = get_vector_length(old_size)
original_expanded_dist_params = explicit_expand_dims(
original_dist_params, op.ndims_params, len_old_size
)
# We call vectorize_graph to automatically handle any new explicit expand_dims
dist_params = vectorize_graph(
original_expanded_dist_params, dict(zip(original_dist_params, dist_params))
)
if len_old_size and equal_computations([old_size], [size]):
bcasted_param = broadcast_params(dist_params, op.ndims_params)[0]
# If the original RV had a size variable and a new one has not been provided,
# we need to define a new size as the concatenation of the original size dimensions
# and the novel ones implied by new broadcasted batched parameters dimensions.
# We use the first broadcasted batch dimension for reference.
bcasted_param = explicit_expand_dims(dist_params, op.ndims_params)[0]
new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
if new_param_ndim >= 0:
new_size_dims = bcasted_param.shape[:new_param_ndim]
......
......@@ -13,7 +13,7 @@ from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant
from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import maximum
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.shape import shape_padleft, specify_shape
from pytensor.tensor.type import int_dtypes
from pytensor.tensor.variable import TensorVariable
......@@ -121,6 +121,34 @@ def broadcast_params(params, ndims_params):
return bcast_params
def explicit_expand_dims(
    params: Sequence[TensorVariable],
    ndim_params: Sequence[int],
    size_length: int = 0,
) -> list[TensorVariable]:
    """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size.

    Each parameter's number of batch dimensions is computed as
    ``param.type.ndim - ndim_param``. Parameters with fewer batch dimensions
    than the target are left-padded with ``shape_padleft`` until they match.

    Parameters
    ----------
    params
        The RV parameters. Only ``param.type.ndim`` is read here.
    ndim_params
        The core number of dimensions of each parameter, paired positionally
        with ``params``.
    size_length
        Length of the ``size`` argument. When truthy it is used as the target
        number of batch dimensions; otherwise the target is the largest number
        of batch dimensions found among ``params``.

    Returns
    -------
    list of TensorVariable
        One entry per parameter, in the same order. Parameters that already
        have the target number of batch dimensions are returned unchanged.
    """
    batch_dims = [
        param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
    ]

    if size_length:
        # NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does
        # See: https://github.com/pymc-devs/pytensor/issues/568
        max_batch_dims = size_length
    else:
        # `default=0` keeps an empty `params` sequence from raising ValueError;
        # with no parameters there is nothing to expand.
        max_batch_dims = max(batch_dims, default=0)

    new_params = []
    for new_param, batch_dim in zip(params, batch_dims):
        missing_dims = max_batch_dims - batch_dim
        if missing_dims:
            new_param = shape_padleft(new_param, missing_dims)
        new_params.append(new_param)

    return new_params
def normalize_size_param(
size: Optional[Union[int, np.ndarray, Variable, Sequence]],
) -> Variable:
......
......@@ -248,7 +248,7 @@ def test_vectorize_node():
# Test without size
node = normal(vec).owner
new_inputs = node.inputs.copy()
new_inputs[3] = mat
new_inputs[3] = mat # mu
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.inputs[3] is mat
......@@ -256,8 +256,8 @@ def test_vectorize_node():
# Test with size, new size provided
node = normal(vec, size=(3,)).owner
new_inputs = node.inputs.copy()
new_inputs[1] = (2, 3)
new_inputs[3] = mat
new_inputs[1] = (2, 3) # size
new_inputs[3] = mat # mu
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert tuple(vect_node.inputs[1].eval()) == (2, 3)
......@@ -266,10 +266,37 @@ def test_vectorize_node():
# Test with size, new size not provided
node = normal(vec, size=(3,)).owner
new_inputs = node.inputs.copy()
new_inputs[3] = mat
new_inputs[3] = mat # mu
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.inputs[3] is mat
assert tuple(
vect_node.inputs[1].eval({mat: np.zeros((2, 3), dtype=config.floatX)})
) == (2, 3)
# Test parameter broadcasting
node = normal(vec).owner
new_inputs = node.inputs.copy()
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5)
# Test parameter broadcasting with non-expanding size
node = normal(vec, size=(5,)).owner
new_inputs = node.inputs.copy()
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 5)
# Test parameter broadcasting with expanding size
node = normal(vec, size=(2, 5)).owner
new_inputs = node.inputs.copy()
new_inputs[3] = tensor("mu", shape=(10, 5)) # mu
new_inputs[4] = tensor("sigma", shape=(10,)) # sigma
vect_node = vectorize_node(node, *new_inputs)
assert vect_node.op is normal
assert vect_node.default_output().type.shape == (10, 2, 5)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论