提交 7a82a3f4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix advanced indexing in subtensor_rv_lift

Also excludes the following cases: 1. expand_dims via broadcasting 2. multi-dimensional integer indexing (could lead to duplicates which is inconsitent with the lifted RV graph)
上级 9df54a54
from itertools import zip_longest from itertools import chain
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph.op import compute_test_value from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scalar import integer_types
from pytensor.tensor import NoneConst from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length from pytensor.tensor.basic import constant, get_vector_length
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i, shape_padleft from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
...@@ -18,7 +19,6 @@ from pytensor.tensor.subtensor import ( ...@@ -18,7 +19,6 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
as_index_variable, as_index_variable,
get_idx_list, get_idx_list,
indexed_result_shape,
) )
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import SliceType
...@@ -207,47 +207,76 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -207,47 +207,76 @@ def local_subtensor_rv_lift(fgraph, node):
``mvnormal(mu, cov, size=(2,))[0, 0]``. ``mvnormal(mu, cov, size=(2,))[0, 0]``.
""" """
st_op = node.op def is_nd_advanced_idx(idx, dtype):
if isinstance(dtype, str):
return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
else:
return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)
if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): subtensor_op = node.op
return False
old_subtensor = node.outputs[0]
rv = node.inputs[0] rv = node.inputs[0]
rv_node = rv.owner rv_node = rv.owner
if not (rv_node and isinstance(rv_node.op, RandomVariable)): if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False return False
shape_feature = getattr(fgraph, "shape_feature", None)
if not shape_feature:
return None
# Use shape_feature to facilitate inferring final shape.
# Check that neither the RV nor the old Subtensor are in the shape graph.
output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
return None
rv_op = rv_node.op rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs rng, size, dtype, *dist_params = rv_node.inputs
# Parse indices # Parse indices
idx_list = getattr(st_op, "idx_list", None) idx_list = getattr(subtensor_op, "idx_list", None)
if idx_list: if idx_list:
cdata = get_idx_list(node.inputs, idx_list) idx_vars = get_idx_list(node.inputs, idx_list)
else: else:
cdata = node.inputs[1:] idx_vars = node.inputs[1:]
st_indices, st_is_bool = zip( indices = tuple(as_index_variable(idx) for idx in idx_vars)
*tuple(
(as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
) # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
) # If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
integer_dtypes = {type.dtype for type in integer_types}
if any(
is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
for idx in indices
):
return False
# Check that indexing does not act on support dims # Check that indexing does not act on support dims
batched_ndims = rv.ndim - rv_op.ndim_supp batch_ndims = rv.ndim - rv_op.ndim_supp
if len(st_indices) > batched_ndims: # We decompose the boolean indexes, which makes it clear whether they act on support dims or not
# If the last indexes are just dummy `slice(None)` we discard them non_bool_indices = tuple(
st_is_bool = st_is_bool[:batched_ndims] chain.from_iterable(
st_indices, supp_indices = ( idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,)
st_indices[:batched_ndims], for idx in indices
st_indices[batched_ndims:], )
) )
for index in supp_indices: if len(non_bool_indices) > batch_ndims:
# If the last indexes are just dummy `slice(None)` we discard them instead of quitting
non_bool_indices, supp_indices = (
non_bool_indices[:batch_ndims],
non_bool_indices[batch_ndims:],
)
for idx in supp_indices:
if not ( if not (
isinstance(index.type, SliceType) isinstance(idx.type, SliceType)
and all(NoneConst.equals(i) for i in index.owner.inputs) and all(NoneConst.equals(i) for i in idx.owner.inputs)
): ):
return False return False
n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs]
# If no one else is using the underlying `RandomVariable`, then we can # If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent. # do this; otherwise, the graph would be internally inconsistent.
...@@ -255,50 +284,71 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -255,50 +284,71 @@ def local_subtensor_rv_lift(fgraph, node):
return False return False
# Update the size to reflect the indexed dimensions # Update the size to reflect the indexed dimensions
# TODO: Could use `ShapeFeature` info. We would need to be sure that new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
# `node` isn't in the results, though.
# if hasattr(fgraph, "shape_feature"):
# output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
# else:
output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
new_size_ignoring_boolean = (
output_shape_ignoring_bool
if rv_op.ndim_supp == 0
else output_shape_ignoring_bool[: -rv_op.ndim_supp]
)
# Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used). # Propagate indexing to the parameters' batch dims.
# The `indexed_result_shape` helper does not consider this # We try to avoid broadcasting the parameters together (and with size), by only indexing
if any(st_is_bool): # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
new_size = tuple( # should still correctly broadcast any degenerate parameter dims.
at_sum(idx) if is_bool else s new_dist_params = []
for s, is_bool, idx in zip_longest( for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
batch_param = (
shape_padleft(param, batch_param_dims_missing)
if batch_param_dims_missing
else param
)
# Check which dims are actually broadcasted
bcast_batch_param_dims = tuple(
dim
for dim, (param_dim, output_dim) in enumerate(
zip(batch_param.type.shape, rv.type.shape)
) )
if (param_dim == 1) and (output_dim != 1)
) )
batch_indices = []
curr_dim = 0
for idx in indices:
# Advanced boolean indexing
if is_nd_advanced_idx(idx, "bool"):
# Check if any broadcasted dim overlaps with advanced boolean indexing.
# If not, we use that directly, instead of the more inefficient `nonzero` form
bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
# There's an overlap, we have to decompose the boolean mask as a `nonzero`
if set(bool_dims) & set(bcast_batch_param_dims):
int_indices = list(idx.nonzero())
# Indexing by 0 drops the degenerate dims
for bool_dim in bool_dims:
if bool_dim in bcast_batch_param_dims:
int_indices[bool_dim - curr_dim] = 0
batch_indices.extend(int_indices)
# No overlap, use index as is
else: else:
new_size = new_size_ignoring_boolean batch_indices.append(idx)
curr_dim += len(bool_dims)
# Update the parameters to reflect the indexed dimensions # Basic-indexing (slice or integer)
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Apply indexing on the batched dimensions of the parameter
batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
batched_param = shape_padleft(param, batched_param_dims_missing)
batched_st_indices = []
for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
# If we have a degenerate dimension indexing it should always do the job
if batched_param_shape == 1:
batched_st_indices.append(0)
else: else:
batched_st_indices.append(st_index) # Broadcasted dim
new_dist_params.append(batched_param[tuple(batched_st_indices)]) if curr_dim in bcast_batch_param_dims:
# Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx.type, SliceType):
batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing
else:
batch_indices.append(0)
# Non-broadcasted dim
else:
# Use index as is
batch_indices.append(idx)
curr_dim += 1
new_dist_params.append(batch_param[tuple(batch_indices)])
# Create new RV # Create new RV
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_rv = new_node.default_output() new_rv = new_node.default_output()
if config.compute_test_value != "off": copy_stack_trace(rv, new_rv)
compute_test_value(new_node)
return [new_rv] return [new_rv]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论