提交 2ebfbf1c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Expand and simplify `local_dimshuffle_rv_lift`

* The rewrite no longer bails out when dimshuffle affects both unique param dimensions and repeated param dimensions from the size argument. This requires: 1) Adding broadcastable dimensions to the parameters, which should be "cost-free" and would need to be done in the `perform` method anyway. 2) Extend size to incorporate implicit batch dimensions coming from the parameters. This requires computing the shape resulting from broadcasting the parameters. It's unclear whether this is less performant, because the `perform` method can now simply broadcast each parameter to the size, instead of having to broadcast the parameters together. * The rewrite now works with Multivariate RVs * The rewrite bails out when dimensions are dropped by the Dimshuffle. This case was not correctly handled by the previous rewrite
上级 8e61224a
......@@ -8,7 +8,7 @@ from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
from pytensor.tensor.subtensor import (
AdvancedSubtensor,
AdvancedSubtensor1,
......@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
For example, ``normal(mu, std).T == normal(mu.T, std.T)``.
The basic idea behind this rewrite is that we need to separate the
``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two
distinct sub-spaces: the (set of independent) parameters and ``size``
(i.e. replications) sub-spaces.
If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
don't do anything.
Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
those sub-spaces, we can break it apart and apply the parameter-space
``DimShuffle`` to the distribution parameters, and then apply the
replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
particularly simple rearranging of a tuple, but the former requires a
little more work.
TODO: Currently, multivariate support for this rewrite is disabled.
This rewrite is only applicable when the Dimshuffle operation does
not affect support dimensions.
TODO: Support dimension dropping
"""
ds_op = node.op
......@@ -142,118 +129,59 @@ def local_dimshuffle_rv_lift(fgraph, node):
base_rv = node.inputs[0]
rv_node = base_rv.owner
if not (
rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
):
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
# Dimshuffle which drop dimensions not supported yet
if ds_op.drop:
return False
rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs
rv = rv_node.default_output()
# We need to know the dimensions that were *not* added by the `size`
# parameter (i.e. the dimensions corresponding to independent variates with
# different parameter values)
num_ind_dims = None
if len(dist_params) == 1:
num_ind_dims = dist_params[0].ndim
else:
# When there is more than one distribution parameter, assume that all
# of them will broadcast to the maximum number of dimensions
num_ind_dims = max(d.ndim for d in dist_params)
# If the indices in `ds_new_order` are entirely within the replication
# indices group or the independent variates indices group, then we can apply
# this rewrite.
ds_new_order = ds_op.new_order
# Create a map from old index order to new/`DimShuffled` index order
dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]
# Find the index at which the replications/independents split occurs
reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)
# Check that Dimshuffle does not affect support dims
supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
if (shuffled_dims | augmented_dims) & supp_dims:
return False
ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
ds_in_ind_space = ds_ind_new_dims and all(
d >= reps_ind_split_idx for n, d in ds_ind_new_dims
)
# If no one else is using the underlying RandomVariable, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
return False
if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims):
batched_dims = rv.ndim - rv_op.ndim_supp
batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
# Update the `size` array to reflect the `DimShuffle`d dimensions,
# since the trailing dimensions in `size` represent the independent
# variates dimensions (for univariate distributions, at least)
has_size = get_vector_length(size) > 0
new_size = (
[constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order]
if has_size
else size
)
# Make size explicit
missing_size_dims = batched_dims - get_vector_length(size)
if missing_size_dims > 0:
full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
size = full_size[:missing_size_dims] + tuple(size)
# Compute the new axes parameter(s) for the `DimShuffle` that will be
# applied to the `RandomVariable` parameters (they need to be offset)
if ds_ind_new_dims:
rv_params_new_order = [
d - reps_ind_split_idx if isinstance(d, int) else d
for d in ds_new_order[ds_ind_new_dims[0][0] :]
# Update the size to reflect the DimShuffled dimensions
new_size = [
constant(1, dtype="int64") if o == "x" else size[o]
for o in batched_dims_ds_order
]
if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
# Additional broadcast dimensions need to be added to the
# independent dimensions (i.e. parameters), since there's no
# `size` to which they can be added
rv_params_new_order = (
list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
)
else:
# This case is reached when, for example, `ds_new_order` only
# consists of new broadcastable dimensions (i.e. `"x"`s)
rv_params_new_order = ds_new_order
# Lift the `DimShuffle`s into the parameters
# NOTE: The parameters might not be broadcasted against each other, so
# we can only apply the parts of the `DimShuffle` that are relevant.
# Updates the params to reflect the Dimshuffled dimensions
new_dist_params = []
for d in dist_params:
if d.ndim < len(ds_ind_new_dims):
_rv_params_new_order = [
o
for o in rv_params_new_order
if (isinstance(o, int) and o < d.ndim) or o == "x"
]
else:
_rv_params_new_order = rv_params_new_order
new_dist_params.append(
type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Add broadcastable dimensions to the parameters that would have been expanded by the size
padleft = batched_dims - (param.ndim - param_ndim_supp)
if padleft > 0:
param = shape_padleft(param, padleft)
# Add the parameter support dimension indexes to the batched dimensions Dimshuffle
param_new_order = batched_dims_ds_order + tuple(
range(batched_dims, batched_dims + param_ndim_supp)
)
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
if config.compute_test_value != "off":
compute_test_value(new_node)
new_dist_params.append(param.dimshuffle(param_new_order))
out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
ds_in_reps_space = ds_reps_new_dims and all(
d < reps_ind_split_idx for n, d in ds_reps_new_dims
)
if ds_in_reps_space:
# Update the `size` array to reflect the `DimShuffle`d dimensions.
# There should be no need to `DimShuffle` now.
new_size = [
constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
]
new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
if config.compute_test_value != "off":
compute_test_value(new_node)
......@@ -263,8 +191,6 @@ def local_dimshuffle_rv_lift(fgraph, node):
out.name = f"{base_rv.name}_lifted"
return [out]
return False
@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
def local_subtensor_rv_lift(fgraph, node):
......
......@@ -9,6 +9,7 @@ from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import (
dirichlet,
......@@ -42,6 +43,10 @@ def apply_local_rewrite_to_rv(
size_at = []
for s in size:
# To test DimShuffle with dropping dims we need that size dimension to be constant
if s == 1:
s_at = constant(np.array(1, dtype="int32"))
else:
s_at = iscalar()
s_at.tag.test_value = s
size_at.append(s_at)
......@@ -314,7 +319,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
),
(
("x", 1, 0, 2, "x"),
False,
True,
normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
......@@ -332,7 +337,30 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
(3, 2, 2),
1,
),
# A multi-dimensional case
# Supported multi-dimensional cases
(
(1, 0, 2),
True,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
(
(1, 0, "x", 2),
True,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
# Not supported multi-dimensional cases where dimshuffle affects the support dimensionality
(
(0, 2, 1),
False,
......@@ -344,6 +372,35 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
(3, 2),
1e-3,
),
(
(0, 1, 2, "x"),
False,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
pytest.param(
(1,),
True,
normal,
(0, 1),
(1, 2),
1e-3,
marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
),
pytest.param(
(1,),
True,
normal,
([[0, 0]], 1),
(1, 2),
1e-3,
marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
),
],
)
@config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论