提交 2ebfbf1c 作者: Ricardo Vieira 提交者: Ricardo Vieira

Expand and simplify `local_dimshuffle_rv_lift`

* The rewrite no longer bails out when the DimShuffle affects both unique parameter dimensions and repeated parameter dimensions coming from the `size` argument. This requires:
  1) Adding broadcastable dimensions to the parameters, which should be "cost-free" and would need to be done in the `perform` method anyway.
  2) Extending `size` to incorporate implicit batch dimensions coming from the parameters. This requires computing the shape that results from broadcasting the parameters. It's unclear whether this is less performant, because the `perform` method can now simply broadcast each parameter to the size, instead of having to broadcast the parameters together.
* The rewrite now works with multivariate RVs.
* The rewrite bails out when dimensions are dropped by the DimShuffle. This case was not correctly handled by the previous rewrite.
上级 8e61224a
...@@ -8,7 +8,7 @@ from pytensor.tensor.extra_ops import broadcast_to ...@@ -8,7 +8,7 @@ from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedSubtensor, AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
...@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -115,23 +115,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
For example, ``normal(mu, std).T == normal(mu.T, std.T)``. For example, ``normal(mu, std).T == normal(mu.T, std.T)``.
The basic idea behind this rewrite is that we need to separate the This rewrite is only applicable when the Dimshuffle operation does
``DimShuffle``-ing into distinct ``DimShuffle``s that each occur in two not affect support dimensions.
distinct sub-spaces: the (set of independent) parameters and ``size``
(i.e. replications) sub-spaces.
If a ``DimShuffle`` exchanges dimensions across those two sub-spaces, then we
don't do anything.
Otherwise, if the ``DimShuffle`` only exchanges dimensions within each of
those sub-spaces, we can break it apart and apply the parameter-space
``DimShuffle`` to the distribution parameters, and then apply the
replications-space ``DimShuffle`` to the ``size`` tuple. The latter is a
particularly simple rearranging of a tuple, but the former requires a
little more work.
TODO: Currently, multivariate support for this rewrite is disabled.
TODO: Support dimension dropping
""" """
ds_op = node.op ds_op = node.op
...@@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -142,128 +129,67 @@ def local_dimshuffle_rv_lift(fgraph, node):
base_rv = node.inputs[0] base_rv = node.inputs[0]
rv_node = base_rv.owner rv_node = base_rv.owner
if not ( if not (rv_node and isinstance(rv_node.op, RandomVariable)):
rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
):
return False return False
# If no one else is using the underlying `RandomVariable`, then we can # Dimshuffle which drop dimensions not supported yet
# do this; otherwise, the graph would be internally inconsistent. if ds_op.drop:
if is_rv_used_in_graph(base_rv, node, fgraph):
return False return False
rv_op = rv_node.op rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs rng, size, dtype, *dist_params = rv_node.inputs
rv = rv_node.default_output()
# We need to know the dimensions that were *not* added by the `size` # Check that Dimshuffle does not affect support dims
# parameter (i.e. the dimensions corresponding to independent variates with supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
# different parameter values) shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
num_ind_dims = None augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
if len(dist_params) == 1: if (shuffled_dims | augmented_dims) & supp_dims:
num_ind_dims = dist_params[0].ndim return False
else:
# When there is more than one distribution parameter, assume that all
# of them will broadcast to the maximum number of dimensions
num_ind_dims = max(d.ndim for d in dist_params)
# If the indices in `ds_new_order` are entirely within the replication
# indices group or the independent variates indices group, then we can apply
# this rewrite.
ds_new_order = ds_op.new_order
# Create a map from old index order to new/`DimShuffled` index order
dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]
# Find the index at which the replications/independents split occurs
reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)
ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
ds_in_ind_space = ds_ind_new_dims and all(
d >= reps_ind_split_idx for n, d in ds_ind_new_dims
)
if ds_in_ind_space or (not ds_ind_new_dims and not ds_reps_new_dims): # If no one else is using the underlying RandomVariable, then we can
# do this; otherwise, the graph would be internally inconsistent.
if is_rv_used_in_graph(base_rv, node, fgraph):
return False
# Update the `size` array to reflect the `DimShuffle`d dimensions, batched_dims = rv.ndim - rv_op.ndim_supp
# since the trailing dimensions in `size` represent the independent batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)
# variates dimensions (for univariate distributions, at least)
has_size = get_vector_length(size) > 0 # Make size explicit
new_size = ( missing_size_dims = batched_dims - get_vector_length(size)
[constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order] if missing_size_dims > 0:
if has_size full_size = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
else size size = full_size[:missing_size_dims] + tuple(size)
# Update the size to reflect the DimShuffled dimensions
new_size = [
constant(1, dtype="int64") if o == "x" else size[o]
for o in batched_dims_ds_order
]
# Updates the params to reflect the Dimshuffled dimensions
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Add broadcastable dimensions to the parameters that would have been expanded by the size
padleft = batched_dims - (param.ndim - param_ndim_supp)
if padleft > 0:
param = shape_padleft(param, padleft)
# Add the parameter support dimension indexes to the batched dimensions Dimshuffle
param_new_order = batched_dims_ds_order + tuple(
range(batched_dims, batched_dims + param_ndim_supp)
) )
new_dist_params.append(param.dimshuffle(param_new_order))
# Compute the new axes parameter(s) for the `DimShuffle` that will be new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
# applied to the `RandomVariable` parameters (they need to be offset)
if ds_ind_new_dims:
rv_params_new_order = [
d - reps_ind_split_idx if isinstance(d, int) else d
for d in ds_new_order[ds_ind_new_dims[0][0] :]
]
if not has_size and len(ds_new_order[: ds_ind_new_dims[0][0]]) > 0:
# Additional broadcast dimensions need to be added to the
# independent dimensions (i.e. parameters), since there's no
# `size` to which they can be added
rv_params_new_order = (
list(ds_new_order[: ds_ind_new_dims[0][0]]) + rv_params_new_order
)
else:
# This case is reached when, for example, `ds_new_order` only
# consists of new broadcastable dimensions (i.e. `"x"`s)
rv_params_new_order = ds_new_order
# Lift the `DimShuffle`s into the parameters
# NOTE: The parameters might not be broadcasted against each other, so
# we can only apply the parts of the `DimShuffle` that are relevant.
new_dist_params = []
for d in dist_params:
if d.ndim < len(ds_ind_new_dims):
_rv_params_new_order = [
o
for o in rv_params_new_order
if (isinstance(o, int) and o < d.ndim) or o == "x"
]
else:
_rv_params_new_order = rv_params_new_order
new_dist_params.append(
type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
)
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
if config.compute_test_value != "off":
compute_test_value(new_node)
out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
ds_in_reps_space = ds_reps_new_dims and all( if config.compute_test_value != "off":
d < reps_ind_split_idx for n, d in ds_reps_new_dims compute_test_value(new_node)
)
if ds_in_reps_space:
# Update the `size` array to reflect the `DimShuffle`d dimensions.
# There should be no need to `DimShuffle` now.
new_size = [
constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
]
new_node = rv_op.make_node(rng, new_size, dtype, *dist_params)
if config.compute_test_value != "off":
compute_test_value(new_node)
out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
return False out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor]) @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
......
...@@ -9,6 +9,7 @@ from pytensor.graph.basic import Constant ...@@ -9,6 +9,7 @@ from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.tensor import constant
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.random.basic import ( from pytensor.tensor.random.basic import (
dirichlet, dirichlet,
...@@ -42,7 +43,11 @@ def apply_local_rewrite_to_rv( ...@@ -42,7 +43,11 @@ def apply_local_rewrite_to_rv(
size_at = [] size_at = []
for s in size: for s in size:
s_at = iscalar() # To test DimShuffle with dropping dims we need that size dimension to be constant
if s == 1:
s_at = constant(np.array(1, dtype="int32"))
else:
s_at = iscalar()
s_at.tag.test_value = s s_at.tag.test_value = s
size_at.append(s_at) size_at.append(s_at)
...@@ -314,7 +319,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -314,7 +319,7 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
), ),
( (
("x", 1, 0, 2, "x"), ("x", 1, 0, 2, "x"),
False, True,
normal, normal,
( (
np.array([[-1, 20], [300, -4000]], dtype=config.floatX), np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
...@@ -332,7 +337,30 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -332,7 +337,30 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
(3, 2, 2), (3, 2, 2),
1, 1,
), ),
# A multi-dimensional case # Supported multi-dimensional cases
(
(1, 0, 2),
True,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
(
(1, 0, "x", 2),
True,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
# Not supported multi-dimensional cases where dimshuffle affects the support dimensionality
( (
(0, 2, 1), (0, 2, 1),
False, False,
...@@ -344,6 +372,35 @@ def test_local_rv_size_lift(dist_op, dist_params, size): ...@@ -344,6 +372,35 @@ def test_local_rv_size_lift(dist_op, dist_params, size):
(3, 2), (3, 2),
1e-3, 1e-3,
), ),
(
(0, 1, 2, "x"),
False,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(3, 2),
1e-3,
),
pytest.param(
(1,),
True,
normal,
(0, 1),
(1, 2),
1e-3,
marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
),
pytest.param(
(1,),
True,
normal,
([[0, 0]], 1),
(1, 2),
1e-3,
marks=pytest.mark.xfail(reason="Dropping dimensions not supported yet"),
),
], ],
) )
@config.change_flags(compute_test_value_opt="raise", compute_test_value="raise") @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论