提交 b8e939e9 作者: Ricardo Vieira 提交者: Ricardo Vieira

Fix `local_subtensor_rv_lift` rewrite bug with vector parameters

Also allow the rewrite to work with multivariate variables when indexing does not act on support dimensions.
上级 bfeabc82
from itertools import zip_longest
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.op import compute_test_value from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length from pytensor.tensor.basic import constant, get_vector_length
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_to
...@@ -17,6 +20,7 @@ from pytensor.tensor.subtensor import ( ...@@ -17,6 +20,7 @@ from pytensor.tensor.subtensor import (
get_idx_list, get_idx_list,
indexed_result_shape, indexed_result_shape,
) )
from pytensor.tensor.type_other import SliceType
def is_rv_used_in_graph(base_rv, node, fgraph): def is_rv_used_in_graph(base_rv, node, fgraph):
...@@ -196,37 +200,11 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -196,37 +200,11 @@ def local_dimshuffle_rv_lift(fgraph, node):
def local_subtensor_rv_lift(fgraph, node): def local_subtensor_rv_lift(fgraph, node):
"""Lift a ``*Subtensor`` through ``RandomVariable`` inputs. """Lift a ``*Subtensor`` through ``RandomVariable`` inputs.
In a fashion similar to ``local_dimshuffle_rv_lift``, the indexed dimensions For example, ``normal(mu, std)[0] == normal(mu[0], std[0])``.
need to be separated into distinct replication-space and (independent)
parameter-space ``*Subtensor``s.
The replication-space ``*Subtensor`` can be used to determine a
sub/super-set of the replication-space and, thus, a "smaller"/"larger"
``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and
applied to the distribution parameters.
Consider the following example graph:
``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The
``*Subtensor`` ``Op`` requests indices ``idx1``, ``idx2``, and ``idx3``,
which correspond to all three ``size`` dimensions. Now, depending on the
broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` ``Op``
could be reducing the ``size`` parameter and/or sub-setting the independent
``mu`` and ``std`` parameters. Only once the dimensions are properly
separated into the two replication/parameter subspaces can we determine how
the ``*Subtensor`` indices are distributed.
For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``
could become
``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))``
if ``mu.shape == std.shape == ()``
``normal`` is a rather simple case, because it's univariate. Multivariate
cases require a mapping between the parameter space and the image of the
random variable. This may not always be possible, but for many common
distributions it is. For example, the dimensions of the multivariate
normal's image can be mapped directly to each dimension of its parameters.
We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]``
into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``.
This rewrite also applies to multivariate distributions as long
as indexing does not happen within core dimensions, such as in
``mvnormal(mu, cov, size=(2,))[0, 0]``.
""" """
st_op = node.op st_op = node.op
...@@ -234,103 +212,92 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -234,103 +212,92 @@ def local_subtensor_rv_lift(fgraph, node):
if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
return False return False
base_rv = node.inputs[0] rv = node.inputs[0]
rv_node = rv.owner
rv_node = base_rv.owner
if not (rv_node and isinstance(rv_node.op, RandomVariable)): if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False return False
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
return False
rv_op = rv_node.op rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs rng, size, dtype, *dist_params = rv_node.inputs
# TODO: Remove this once the multi-dimensional changes described below are # Parse indices
# in place.
if rv_op.ndim_supp > 0:
return False
rv_op = base_rv.owner.op
rng, size, dtype, *dist_params = base_rv.owner.inputs
idx_list = getattr(st_op, "idx_list", None) idx_list = getattr(st_op, "idx_list", None)
if idx_list: if idx_list:
cdata = get_idx_list(node.inputs, idx_list) cdata = get_idx_list(node.inputs, idx_list)
else: else:
cdata = node.inputs[1:] cdata = node.inputs[1:]
st_indices, st_is_bool = zip( st_indices, st_is_bool = zip(
*tuple( *tuple(
(as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
) )
) )
# We need to separate dimensions into replications and independents # Check that indexing does not act on support dims
num_ind_dims = None batched_ndims = rv.ndim - rv_op.ndim_supp
if len(dist_params) == 1: if len(st_indices) > batched_ndims:
num_ind_dims = dist_params[0].ndim # If the last indexes are just dummy `slice(None)` we discard them
else: st_is_bool = st_is_bool[:batched_ndims]
# When there is more than one distribution parameter, assume that all st_indices, supp_indices = (
# of them will broadcast to the maximum number of dimensions st_indices[:batched_ndims],
num_ind_dims = max(d.ndim for d in dist_params) st_indices[batched_ndims:],
)
reps_ind_split_idx = base_rv.ndim - (num_ind_dims + rv_op.ndim_supp) for index in supp_indices:
if not (
if len(st_indices) > reps_ind_split_idx: isinstance(index.type, SliceType)
# These are the indices that need to be applied to the parameters and all(NoneConst.equals(i) for i in index.owner.inputs)
ind_indices = tuple(st_indices[reps_ind_split_idx:]) ):
return False
# We need to broadcast the parameters before applying the `*Subtensor*`
# with these indices, because the indices could be referencing broadcast
# dimensions that don't exist (yet)
bcast_dist_params = broadcast_params(dist_params, rv_op.ndims_params)
# TODO: For multidimensional distributions, we need a map that tells us
# which dimensions of the parameters need to be indexed.
#
# For example, `multivariate_normal` would have the following:
# `RandomVariable.param_to_image_dims = ((0,), (0, 1))`
#
# I.e. the first parameter's (i.e. mean's) first dimension maps directly to
# the dimension of the RV's image, and its second parameter's
# (i.e. covariance's) first and second dimensions map directly to the
# dimension of the RV's image.
args_lifted = tuple(p[ind_indices] for p in bcast_dist_params)
else:
# In this case, no indexing is applied to the parameters; only the
# `size` parameter is affected.
args_lifted = dist_params
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(rv, node, fgraph):
return False
# Update the size to reflect the indexed dimensions
# TODO: Could use `ShapeFeature` info. We would need to be sure that # TODO: Could use `ShapeFeature` info. We would need to be sure that
# `node` isn't in the results, though. # `node` isn't in the results, though.
# if hasattr(fgraph, "shape_feature"): # if hasattr(fgraph, "shape_feature"):
# output_shape = fgraph.shape_feature.shape_of(node.outputs[0]) # output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
# else: # else:
output_shape = indexed_result_shape(base_rv.shape, st_indices) output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
new_size_ignoring_boolean = (
size_lifted = ( output_shape_ignoring_bool
output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp] if rv_op.ndim_supp == 0
else output_shape_ignoring_bool[: -rv_op.ndim_supp]
) )
# Boolean indices can actually change the `size` value (compared to just # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
# *which* dimensions of `size` are used). # The `indexed_result_shape` helper does not consider this
if any(st_is_bool): if any(st_is_bool):
size_lifted = tuple( new_size = tuple(
at_sum(idx) if is_bool else s at_sum(idx) if is_bool else s
for s, is_bool, idx in zip( for s, is_bool, idx in zip_longest(
size_lifted, st_is_bool, st_indices[: (reps_ind_split_idx + 1)] new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
) )
) )
else:
new_size = new_size_ignoring_boolean
# Update the parameters to reflect the indexed dimensions
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Apply indexing on the batched dimensions of the parameter
batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
batched_param = shape_padleft(param, batched_param_dims_missing)
batched_st_indices = []
for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
# If we have a degenerate dimension indexing it should always do the job
if batched_param_shape == 1:
batched_st_indices.append(0)
else:
batched_st_indices.append(st_index)
new_dist_params.append(batched_param[tuple(batched_st_indices)])
new_node = rv_op.make_node(rng, size_lifted, dtype, *args_lifted) # Create new RV
_, new_rv = new_node.outputs new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_rv = new_node.default_output()
# Calling `Op.make_node` directly circumvents test value computations, so
# we need to compute the test values manually
if config.compute_test_value != "off": if config.compute_test_value != "off":
compute_test_value(new_node) compute_test_value(new_node)
......
...@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery ...@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import ( from pytensor.tensor.random.basic import (
categorical,
dirichlet, dirichlet,
multinomial, multinomial,
multivariate_normal, multivariate_normal,
...@@ -36,8 +37,8 @@ def apply_local_rewrite_to_rv( ...@@ -36,8 +37,8 @@ def apply_local_rewrite_to_rv(
rewrite, op_fn, dist_op, dist_params, size, rng, name=None rewrite, op_fn, dist_op, dist_params, size, rng, name=None
): ):
dist_params_at = [] dist_params_at = []
for p in dist_params: for i, p in enumerate(dist_params):
p_at = at.as_tensor(p).type() p_at = at.as_tensor(p).type(f"p_{i}")
p_at.tag.test_value = p p_at.tag.test_value = p
dist_params_at.append(p_at) dist_params_at.append(p_at)
...@@ -495,8 +496,79 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -495,8 +496,79 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
), ),
(3, 2, 2), (3, 2, 2),
), ),
# A multi-dimensional case # Only one distribution parameter
(
(0,),
True,
poisson,
(np.array([[1, 2], [3, 4]], dtype=config.floatX),),
(3, 2, 2),
),
# Univariate distribution with vector parameters
(
(np.array([0, 2]),),
True,
categorical,
(np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
(4,),
),
(
(np.array([True, False, True, True]),),
True,
categorical,
(np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
(4,),
),
(
(np.array([True, False, True]),),
True,
categorical,
(
np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=config.floatX,
),
),
(),
),
(
(
slice(None),
np.array([True, False, True]),
),
True,
categorical,
(
np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=config.floatX,
),
),
(4, 3),
),
# Boolean indexing where output is empty
(
(np.array([False, False]),),
True,
normal,
(np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),),
(2, 3),
),
( (
(np.array([False, False]),),
True,
categorical,
(
np.array(
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
dtype=config.floatX,
),
),
(2, 3),
),
# Multivariate cases, indexing only supported if it does not affect core dimensions
(
# Indexing dips into core dimension
(np.array([1]), 0), (np.array([1]), 0),
False, False,
multivariate_normal, multivariate_normal,
...@@ -506,13 +578,30 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -506,13 +578,30 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
), ),
(), (),
), ),
# Only one distribution parameter
( (
(0,), (np.array([0, 2]),),
True, True,
poisson, multivariate_normal,
(np.array([[1, 2], [3, 4]], dtype=config.floatX),), (
(3, 2, 2), np.array(
[[-100, -125, -150], [0, 0, 0], [200, 225, 250]],
dtype=config.floatX,
),
np.eye(3, dtype=config.floatX) * 1e-6,
),
(),
),
(
(np.array([True, False, True]), slice(None)),
True,
multivariate_normal,
(
np.array([200, 250], dtype=config.floatX),
# Second covariance is invalid, to test it is not chosen
np.dstack([np.eye(2), np.eye(2) * 0, np.eye(2)]).T.astype(config.floatX)
* 1e-6,
),
(3,),
), ),
], ],
) )
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论