Commit 2d81ccae authored by Ricardo Vieira, committed by Ricardo Vieira

Simplify RV rewrites

Parent 94e9ef06
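For context: the two rewrites touched by this commit, local_dimshuffle_rv_lift and local_subtensor_rv_lift, lift DimShuffle and *Subtensor operations through RandomVariable nodes so that the operation is applied to the distribution parameters rather than to the draws. A rough illustration of the equivalences they encode, following the examples in their docstrings (an illustrative sketch only, not part of the commit):

    import pytensor.tensor as pt

    mu, sigma = pt.matrix("mu"), pt.matrix("sigma")

    # Dimshuffle lift: a transpose of a draw is rewritten into a draw from
    # transposed parameters, i.e. normal(mu, sigma).T -> normal(mu.T, sigma.T)
    transposed_draw = pt.random.normal(mu, sigma).T

    # Subtensor lift: indexing a draw is rewritten into a draw from indexed
    # parameters, as long as only batch dimensions are indexed,
    # e.g. normal(mu, sigma)[1:] -> normal(mu[1:], sigma[1:])
    sliced_draw = pt.random.normal(mu, sigma)[1:]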
@@ -5,21 +5,20 @@ from pytensor.configdefaults import config
 from pytensor.graph import ancestors
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
-from pytensor.scalar import integer_types
-from pytensor.tensor import NoneConst
+from pytensor.tensor import NoneConst, TensorVariable
 from pytensor.tensor.basic import constant
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
-from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
+from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor,
     AdvancedSubtensor1,
     Subtensor,
-    as_index_variable,
     get_idx_list,
 )
+from pytensor.tensor.type import integer_dtypes
 from pytensor.tensor.type_other import NoneTypeT, SliceType
@@ -127,22 +126,23 @@ def local_dimshuffle_rv_lift(fgraph, node):

     ds_op = node.op

-    if not isinstance(ds_op, DimShuffle):
+    # Dimshuffle which drop dimensions not supported yet
+    if ds_op.drop:
         return False

-    base_rv = node.inputs[0]
-    rv_node = base_rv.owner
-
+    rv_node = node.inputs[0].owner
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False

-    # Dimshuffle which drop dimensions not supported yet
-    if ds_op.drop:
-        return False
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
-    rv = rv_node.default_output()
+    next_rng, rv = rv_node.outputs
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False

     # Check that Dimshuffle does not affect support dims
     supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
@@ -153,17 +153,15 @@ def local_dimshuffle_rv_lift(fgraph, node):
     # If no one else is using the underlying RandomVariable, then we can
     # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(base_rv, node, fgraph):
+    if is_rv_used_in_graph(rv, node, fgraph):
         return False

     batched_dims = rv.ndim - rv_op.ndim_supp
     batched_dims_ds_order = tuple(o for o in ds_op.new_order if o not in supp_dims)

     if isinstance(size.type, NoneTypeT):
-        # Make size explicit
-        shape = tuple(broadcast_params(dist_params, rv_op.ndims_params)[0].shape)
-        size = shape[:batched_dims]
-
-    # Update the size to reflect the DimShuffled dimensions
-    new_size = [
-        constant(1, dtype="int64") if o == "x" else size[o]
+        new_size = size
+    else:
+        # Update the size to reflect the DimShuffled dimensions
+        new_size = [
+            constant(1, dtype="int64") if o == "x" else size[o]
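As a concrete reading of the new size handling: when the size is symbolic (not NoneConst), each entry of the DimShuffle's new_order picks the corresponding size entry, and "x" entries become 1. A tiny sketch with plain Python values (illustrative only):

    # new_order ("x", 1, 0) applied to size (2, 3)
    size = (2, 3)
    new_order = ("x", 1, 0)
    new_size = [1 if o == "x" else size[o] for o in new_order]
    assert new_size == [1, 3, 2]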
@@ -173,11 +171,6 @@ def local_dimshuffle_rv_lift(fgraph, node):
     # Updates the params to reflect the Dimshuffled dimensions
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # Add broadcastable dimensions to the parameters that would have been expanded by the size
-        padleft = batched_dims - (param.ndim - param_ndim_supp)
-        if padleft > 0:
-            param = shape_padleft(param, padleft)
-
         # Add the parameter support dimension indexes to the batched dimensions Dimshuffle
         param_new_order = batched_dims_ds_order + tuple(
             range(batched_dims, batched_dims + param_ndim_supp)
@@ -189,10 +182,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
     if config.compute_test_value != "off":
         compute_test_value(new_node)

-    out = new_node.outputs[1]
+    new_rv = new_node.default_output()

-    if base_rv.name:
-        out.name = f"{base_rv.name}_lifted"
+    if rv.name:
+        new_rv.name = f"{rv.name}_lifted"

-    return [out]
+    return [new_rv]


 @node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
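For a quick way to see the dimshuffle lift in action, the rewrite can be applied directly to a graph with the in2out helper imported at the top of the file. A minimal sketch, assuming the file shown is pytensor/tensor/random/rewriting/basic.py (the import path below is an assumption, not stated in the diff):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.graph import FunctionGraph
    from pytensor.graph.rewriting.basic import in2out
    # Assumed import path for the rewrite defined in this file
    from pytensor.tensor.random.rewriting.basic import local_dimshuffle_rv_lift

    mu = pt.matrix("mu")
    sigma = pt.matrix("sigma")
    out = pt.random.normal(mu, sigma).T  # DimShuffle sitting on top of the RV

    fgraph = FunctionGraph(outputs=[out], clone=False)
    in2out(local_dimshuffle_rv_lift).rewrite(fgraph)

    # The graph output should now be a normal RV drawn directly from the
    # transposed parameters, i.e. roughly normal(mu.T, sigma.T)
    pytensor.dprint(fgraph)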
@@ -206,7 +199,9 @@ def local_subtensor_rv_lift(fgraph, node):
     ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """

-    def is_nd_advanced_idx(idx, dtype):
+    def is_nd_advanced_idx(idx, dtype) -> bool:
+        if not isinstance(idx, TensorVariable):
+            return False
         if isinstance(dtype, str):
             return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
         else:
@@ -214,39 +209,28 @@ def local_subtensor_rv_lift(fgraph, node):

     subtensor_op = node.op

-    old_subtensor = node.outputs[0]
-    rv = node.inputs[0]
-    rv_node = rv.owner
+    [indexed_rv] = node.outputs
+    rv_node = node.inputs[0].owner

     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False

-    shape_feature = getattr(fgraph, "shape_feature", None)
-    if not shape_feature:
-        return None
-
-    # Use shape_feature to facilitate inferring final shape.
-    # Check that neither the RV nor the old Subtensor are in the shape graph.
-    output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
-    if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
-        return None
-
     rv_op = rv_node.op
     rng, size, *dist_params = rv_node.inputs
+    rv = rv_node.default_output()
+
+    # If no one else is using the underlying `RandomVariable`, then we can
+    # do this; otherwise, the graph would be internally inconsistent.
+    if is_rv_used_in_graph(rv, node, fgraph):
+        return False

     # Parse indices
-    idx_list = getattr(subtensor_op, "idx_list", None)
-    if idx_list:
-        idx_vars = get_idx_list(node.inputs, idx_list)
-    else:
-        idx_vars = node.inputs[1:]
-    indices = tuple(as_index_variable(idx) for idx in idx_vars)
+    indices = get_idx_list(node.inputs, getattr(subtensor_op, "idx_list", None))

     # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
     # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
     # If we wanted to support that we could rewrite it as subtensor + dimshuffle
     # and make use of the dimshuffle lift rewrite
-    integer_dtypes = {type.dtype for type in integer_types}
     if any(
         is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
         for idx in indices
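As a reading of the bail-out above: integer-array advanced indices and np.newaxis are skipped, while basic indexing (and boolean masks, which are handled further down) remain candidates for lifting. A small illustrative sketch, with made-up shapes and names, not part of the diff:

    import pytensor.tensor as pt

    x = pt.random.normal(0, 1, size=(5, 3))

    x[1:3]                      # basic slice: candidate for lifting
    x[0]                        # integer index: candidate for lifting
    x[pt.as_tensor([0, 0, 2])]  # integer-array index: skipped, since the same
                                # draw could be repeated (duplicates)
    x[None]                     # np.newaxis / expand_dims: skipped for simplicity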
@@ -277,12 +261,20 @@ def local_subtensor_rv_lift(fgraph, node):
         n_discarded_idxs = len(supp_indices)
         indices = indices[:-n_discarded_idxs]

-    # If no one else is using the underlying `RandomVariable`, then we can
-    # do this; otherwise, the graph would be internally inconsistent.
-    if is_rv_used_in_graph(rv, node, fgraph):
-        return False
-
     # Update the size to reflect the indexed dimensions
-    new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
+    if isinstance(size.type, NoneTypeT):
+        new_size = size
+    else:
+        shape_feature = getattr(fgraph, "shape_feature", None)
+        if not shape_feature:
+            return None
+
+        # Use shape_feature to facilitate inferring final shape.
+        # Check that neither the RV nor the old Subtensor are in the shape graph.
+        output_shape = fgraph.shape_feature.shape_of.get(indexed_rv, None)
+        if output_shape is None or {indexed_rv, rv} & set(ancestors(output_shape)):
+            return None
+
+        new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]

     # Propagate indexing to the parameters' batch dims.
@@ -291,20 +283,13 @@ def local_subtensor_rv_lift(fgraph, node):
     # should still correctly broadcast any degenerate parameter dims.
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
-        batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
-        batch_param = (
-            shape_padleft(param, batch_param_dims_missing)
-            if batch_param_dims_missing
-            else param
-        )
-        # Check which dims are actually broadcasted
-        bcast_batch_param_dims = tuple(
+        # Check which dims are broadcasted by either size or other parameters
+        bcast_param_dims = tuple(
             dim
-            for dim, (param_dim, output_dim) in enumerate(
-                zip(batch_param.type.shape, rv.type.shape)
+            for dim, (param_dim_bcast, output_dim_bcast) in enumerate(
+                zip(param.type.broadcastable, rv.type.broadcastable)
             )
-            if (param_dim == 1) and (output_dim != 1)
+            if param_dim_bcast and not output_dim_bcast
         )
         batch_indices = []
         curr_dim = 0
@@ -315,23 +300,23 @@ def local_subtensor_rv_lift(fgraph, node):
                 # If not, we use that directly, instead of the more inefficient `nonzero` form
                 bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
                 # There's an overlap, we have to decompose the boolean mask as a `nonzero`
-                if set(bool_dims) & set(bcast_batch_param_dims):
+                if set(bool_dims) & set(bcast_param_dims):
                     int_indices = list(idx.nonzero())
                     # Indexing by 0 drops the degenerate dims
                     for bool_dim in bool_dims:
-                        if bool_dim in bcast_batch_param_dims:
+                        if bool_dim in bcast_param_dims:
                             int_indices[bool_dim - curr_dim] = 0
                     batch_indices.extend(int_indices)
-                # No overlap, use index as is
+                # No overlap, use boolean index as is
                 else:
                     batch_indices.append(idx)
                 curr_dim += len(bool_dims)
             # Basic-indexing (slice or integer)
             else:
                 # Broadcasted dim
-                if curr_dim in bcast_batch_param_dims:
+                if curr_dim in bcast_param_dims:
                     # Slice indexing, keep degenerate dim by none-slicing
-                    if isinstance(idx.type, SliceType):
+                    if isinstance(idx, slice) or isinstance(idx.type, SliceType):
                         batch_indices.append(slice(None))
                     # Integer indexing, drop degenerate dim by 0-indexing
                     else:
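The boolean-mask branch above decomposes a mask that overlaps a broadcasted parameter dim into integer indices and then indexes the broadcasted dim with 0. A small NumPy sketch of that logic, with made-up shapes standing in for the symbolic graph (illustrative only):

    import numpy as np

    # Parameter broadcast along the first batch dim: mu has shape (1, 2),
    # while the draws have batch shape (3, 2).
    mu = np.array([[0.0, 10.0]])
    draws = np.random.normal(np.broadcast_to(mu, (3, 2)))

    mask = np.array([[True, False], [False, True], [True, True]])

    # mu[mask] would fail (shape mismatch), so the mask is decomposed with
    # nonzero() and the broadcasted dim is indexed by 0 instead.
    rows, cols = mask.nonzero()
    lifted_mu = mu[0, cols]

    assert lifted_mu.shape == draws[mask].shape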
@@ -342,7 +327,7 @@ def local_subtensor_rv_lift(fgraph, node):
                 batch_indices.append(idx)
                 curr_dim += 1

-        new_dist_params.append(batch_param[tuple(batch_indices)])
+        new_dist_params.append(param[tuple(batch_indices)])

     # Create new RV
     new_node = rv_op.make_node(rng, new_size, *new_dist_params)