提交 b8e939e9 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix `local_subtensor_rv_lift` rewrite bug with vector parameters

Also allow rewrite to work with multivariate variables, when indexing does not act on support dims.
上级 bfeabc82
from itertools import zip_longest
from pytensor.compile import optdb
from pytensor.configdefaults import config
from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to
......@@ -17,6 +20,7 @@ from pytensor.tensor.subtensor import (
get_idx_list,
indexed_result_shape,
)
from pytensor.tensor.type_other import SliceType
def is_rv_used_in_graph(base_rv, node, fgraph):
......@@ -196,37 +200,11 @@ def local_dimshuffle_rv_lift(fgraph, node):
def local_subtensor_rv_lift(fgraph, node):
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions
need to be separated into distinct replication-space and (independent)
parameter-space ``*Subtensor``s.
The replication-space ``*Subtensor`` can be used to determine a
sub/super-set of the replication-space and, thus, a "smaller"/"larger"
``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and
applied to the distribution parameters.
Consider the following example graph:
``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The
``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``,
which correspond to all three ``size`` dimensions. Now, depending on the
broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op``
could be reducing the ``size`` parameter and/or sub-setting the independent
``mu`` and ``std`` parameters. Only once the dimensions are properly
separated into the two replication/parameter subspaces can we determine how
the ``*Subtensor`` indices are distributed.
For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``
could become
``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))``
if ``mu.shape == std.shape == ()``
``normal`` is a rather simple case, because it's univariate. Multivariate
cases require a mapping between the parameter space and the image of the
random variable. This may not always be possible, but for many common
distributions it is. For example, the dimensions of the multivariate
normal's image can be mapped directly to each dimension of its parameters.
We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]``
into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``.
For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``.
This rewrite also applies to multivariate distributions as long
as indexing does not happen within core dimensions, such as in
``mvnormal(mu, cov, size=(2,))[0, 0]``.
"""
st_op = node.op
......@@ -234,103 +212,92 @@ def local_subtensor_rv_lift(fgraph, node):
if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
return False
base_rv = node.inputs[0]
rv = node.inputs[0]
rv_node = rv.owner
rv_node = base_rv.owner
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
return False
rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs
# TODO: Remove this once the multi-dimensional changes described below are
# in place.
if rv_op.ndim_supp > 0:
return False
rv_op = base_rv.owner.op
rng, size, dtype, *dist_params = base_rv.owner.inputs
# Parse indices
idx_list = getattr(st_op, "idx_list", None)
if idx_list:
cdata = get_idx_list(node.inputs, idx_list)
else:
cdata = node.inputs[1:]
st_indices, st_is_bool = zip(
*tuple(
(as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
)
)
# We need to separate dimensions into replications and independents
num_ind_dims = None
if len(dist_params) == 1:
num_ind_dims = dist_params[0].ndim
else:
# When there is more than one distribution parameter, assume that all
# of them will broadcast to the maximum number of dimensions
num_ind_dims = max(d.ndim for d in dist_params)
reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp)
if len(st_indices) > reps_ind_split_idx:
# These are the indices that need to be applied to the parameters
ind_indices = tuple(st_indices[reps_ind_split_idx:])
# We need to broadcast the parameters before applying the `*Subtensor*`
# with these indices, because the indices could be referencing broadcast
# dimensions that don't exist (yet)
bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params)
# TODO: For multidimensional distributions, we need a map that tells us
# which dimensions of the parameters need to be indexed.
#
# For example, `multivariate_normal` would have the following:
# `RandomVariable.param_to_image_dims = ((0,), (0, 1))`
#
# I.e. the first parameter's (i.e. mean's) first dimension maps directly to
# the dimension of the RV's image, and its second parameter's
# (i.e. covariance's) first and second dimensions map directly to the
# dimension of the RV's image.
args_lifted = tuple(p[ind_indices] for p in bcast_dist_params)
else:
# In this case, no indexing is applied to the parameters; only the
# `size` parameter is affected.
args_lifted = dist_params
# Check that indexing does not act on support dims
batched_ndims = rv.ndim - rv_op.ndim_supp
if len(st_indices) > batched_ndims:
# If the last indexes are just dummy `slice(None)` we discard them
st_is_bool = st_is_bool[:batched_ndims]
st_indices, supp_indices = (
st_indices[:batched_ndims],
st_indices[batched_ndims:],
)
for index in supp_indices:
if not (
isinstance(index.type, SliceType)
and all(NoneConst.equals(i) for i in index.owner.inputs)
):
return False
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(rv, node, fgraph):
return False
# Update the size to reflect the indexed dimensions
# TODO: Could use `ShapeFeature` info. We would need to be sure that
# `node` isn't in the results, though.
# if hasattr(fgraph, "shape_feature"):
# output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
# else:
output_shape = indexed_result_shape(base_rv.shape, st_indices)
size_lifted = (
output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp]
output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
new_size_ignoring_boolean = (
output_shape_ignoring_bool
if rv_op.ndim_supp == 0
else output_shape_ignoring_bool[: -rv_op.ndim_supp]
)
# Boolean indices can actually change the `size` value (compared to just
# *which* dimensions of `size` are used).
# Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
# The `indexed_result_shape` helper does not consider this
if any(st_is_bool):
size_lifted = tuple(
new_size = tuple(
at_sum(idx) if is_bool else s
for s, is_bool, idx in zip(
size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)]
for s, is_bool, idx in zip_longest(
new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
)
)
else:
new_size = new_size_ignoring_boolean
# Update the parameters to reflect the indexed dimensions
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Apply indexing on the batched dimensions of the parameter
batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
batched_param = shape_padleft(param, batched_param_dims_missing)
batched_st_indices = []
for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
# If we have a degenerate dimension indexing it should always do the job
if batched_param_shape == 1:
batched_st_indices.append(0)
else:
batched_st_indices.append(st_index)
new_dist_params.append(batched_param[tuple(batched_st_indices)])
new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted)
_, new_rv = new_node.outputs
# Create new RV
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_rv = new_node.default_output()
# Calling `Op.make_node` directly circumvents test value computations, so
# we need to compute the test values manually
if config.compute_test_value != "off":
compute_test_value(new_node)
......
......@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
categorical,
dirichlet,
multinomial,
multivariate_normal,
......@@ -36,8 +37,8 @@ def apply_local_rewrite_to_rv(
rewrite, op_fn, dist_op, dist_params, size, rng, name=None
):
dist_params_at = []
for p in dist_params:
p_at = at.as_tensor(p).type()
for i, p in enumerate(dist_params):
p_at = at.as_tensor(p).type(f"p_{i}")
p_at.tag.test_value = p
dist_params_at.append(p_at)
......@@ -495,8 +496,79 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
),
(3, 2, 2),
),
# A multi-dimensional case
# Only one distribution parameter
(
(0,),
True,
poisson,
(np.array([[1, 2], [3, 4]], dtype=config.floatX),),
(3, 2, 2),
),
# Univariate distribution with vector parameters
(
(np.array([0, 2]),),
True,
categorical,
(np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
(4,),
),
(
(np.array([True, False, True, True]),),
True,
categorical,
(np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
(4,),
),
(
(np.array([True, False, True]),),
True,
categorical,
(
np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=config.floatX,
),
),
(),
),
(
(
slice(None),
np.array([True, False, True]),
),
True,
categorical,
(
np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=config.floatX,
),
),
(4, 3),
),
# Boolean indexing where output is empty
(
(np.array([False, False]),),
True,
normal,
(np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),),
(2, 3),
),
(
(np.array([False, False]),),
True,
categorical,
(
np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=config.floatX,
),
),
(2, 3),
),
# Multivariate cases, indexing only supported if it does not affect core dimensions
(
# Indexing dips into core dimension
(np.array([1]), 0),
False,
multivariate_normal,
......@@ -506,13 +578,30 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
),
(),
),
# Only one distribution parameter
(
(0,),
(np.array([0, 2]),),
True,
poisson,
(np.array([[1, 2], [3, 4]], dtype=config.floatX),),
(3, 2, 2),
multivariate_normal,
(
np.array(
[[-100, -125, -150], [0, 0, 0], [200, 225, 250]],
dtype=config.floatX,
),
np.eye(3, dtype=config.floatX) * 1e-6,
),
(),
),
(
(np.array([True, False, True]), slice(None)),
True,
multivariate_normal,
(
np.array([200, 250], dtype=config.floatX),
# Second covariance is invalid, to test it is not chosen
np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
* 1e-6,
),
(3,),
),
],
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论