提交 7a82a3f4 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix advanced indexing in subtensor_rv_lift

Also excludes the following cases: 1. expand_dims via broadcasting 2. multi-dimensional integer indexing (could lead to duplicates which is inconsitent with the lifted RV graph)
上级 9df54a54
from itertools import zip_longest from itertools import chain
from pytensor.compile import optdb from pytensor.compile import optdb
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph import ancestors
from pytensor.graph.op import compute_test_value from pytensor.graph.op import compute_test_value
from pytensor.graph.rewriting.basic import in2out, node_rewriter from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
from pytensor.scalar import integer_types
from pytensor.tensor import NoneConst from pytensor.tensor import NoneConst
from pytensor.tensor.basic import constant, get_vector_length from pytensor.tensor.basic import constant, get_vector_length
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.extra_ops import broadcast_to from pytensor.tensor.extra_ops import broadcast_to
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import broadcast_params from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.shape import Shape, Shape_i, shape_padleft from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
...@@ -18,7 +19,6 @@ from pytensor.tensor.subtensor import ( ...@@ -18,7 +19,6 @@ from pytensor.tensor.subtensor import (
Subtensor, Subtensor,
as_index_variable, as_index_variable,
get_idx_list, get_idx_list,
indexed_result_shape,
) )
from pytensor.tensor.type_other import SliceType from pytensor.tensor.type_other import SliceType
...@@ -207,47 +207,76 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -207,47 +207,76 @@ def local_subtensor_rv_lift(fgraph, node):
``mvnormal(mu, cov, size=(2,))[0, 0]``. ``mvnormal(mu, cov, size=(2,))[0, 0]``.
""" """
st_op = node.op def is_nd_advanced_idx(idx, dtype):
if isinstance(dtype, str):
return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
else:
return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)
if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): subtensor_op = node.op
return False
old_subtensor = node.outputs[0]
rv = node.inputs[0] rv = node.inputs[0]
rv_node = rv.owner rv_node = rv.owner
if not (rv_node and isinstance(rv_node.op, RandomVariable)): if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False return False
shape_feature = getattr(fgraph, "shape_feature", None)
if not shape_feature:
return None
# Use shape_feature to facilitate inferring final shape.
# Check that neither the RV nor the old Subtensor are in the shape graph.
output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
return None
rv_op = rv_node.op rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs rng, size, dtype, *dist_params = rv_node.inputs
# Parse indices # Parse indices
idx_list = getattr(st_op, "idx_list", None) idx_list = getattr(subtensor_op, "idx_list", None)
if idx_list: if idx_list:
cdata = get_idx_list(node.inputs, idx_list) idx_vars = get_idx_list(node.inputs, idx_list)
else: else:
cdata = node.inputs[1:] idx_vars = node.inputs[1:]
st_indices, st_is_bool = zip( indices = tuple(as_index_variable(idx) for idx in idx_vars)
*tuple(
(as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
) # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
) # If we wanted to support that we could rewrite it as subtensor + dimshuffle
# and make use of the dimshuffle lift rewrite
integer_dtypes = {type.dtype for type in integer_types}
if any(
is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
for idx in indices
):
return False
# Check that indexing does not act on support dims # Check that indexing does not act on support dims
batched_ndims = rv.ndim - rv_op.ndim_supp batch_ndims = rv.ndim - rv_op.ndim_supp
if len(st_indices) > batched_ndims: # We decompose the boolean indexes, which makes it clear whether they act on support dims or not
# If the last indexes are just dummy `slice(None)` we discard them non_bool_indices = tuple(
st_is_bool = st_is_bool[:batched_ndims] chain.from_iterable(
st_indices, supp_indices = ( idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,)
st_indices[:batched_ndims], for idx in indices
st_indices[batched_ndims:], )
) )
for index in supp_indices: if len(non_bool_indices) > batch_ndims:
# If the last indexes are just dummy `slice(None)` we discard them instead of quitting
non_bool_indices, supp_indices = (
non_bool_indices[:batch_ndims],
non_bool_indices[batch_ndims:],
)
for idx in supp_indices:
if not ( if not (
isinstance(index.type, SliceType) isinstance(idx.type, SliceType)
and all(NoneConst.equals(i) for i in index.owner.inputs) and all(NoneConst.equals(i) for i in idx.owner.inputs)
): ):
return False return False
n_discarded_idxs = len(supp_indices)
indices = indices[:-n_discarded_idxs]
# If no one else is using the underlying `RandomVariable`, then we can # If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent. # do this; otherwise, the graph would be internally inconsistent.
...@@ -255,50 +284,71 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -255,50 +284,71 @@ def local_subtensor_rv_lift(fgraph, node):
return False return False
# Update the size to reflect the indexed dimensions # Update the size to reflect the indexed dimensions
# TODO: Could use `ShapeFeature` info. We would need to be sure that new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
# `node` isn't in the results, though.
# if hasattr(fgraph, "shape_feature"):
# output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
# else:
output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
new_size_ignoring_boolean = (
output_shape_ignoring_bool
if rv_op.ndim_supp == 0
else output_shape_ignoring_bool[: -rv_op.ndim_supp]
)
# Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used). # Propagate indexing to the parameters' batch dims.
# The `indexed_result_shape` helper does not consider this # We try to avoid broadcasting the parameters together (and with size), by only indexing
if any(st_is_bool): # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
new_size = tuple( # should still correctly broadcast any degenerate parameter dims.
at_sum(idx) if is_bool else s new_dist_params = []
for s, is_bool, idx in zip_longest( for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
batch_param = (
shape_padleft(param, batch_param_dims_missing)
if batch_param_dims_missing
else param
)
# Check which dims are actually broadcasted
bcast_batch_param_dims = tuple(
dim
for dim, (param_dim, output_dim) in enumerate(
zip(batch_param.type.shape, rv.type.shape)
) )
if (param_dim == 1) and (output_dim != 1)
) )
batch_indices = []
curr_dim = 0
for idx in indices:
# Advanced boolean indexing
if is_nd_advanced_idx(idx, "bool"):
# Check if any broadcasted dim overlaps with advanced boolean indexing.
# If not, we use that directly, instead of the more inefficient `nonzero` form
bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
# There's an overlap, we have to decompose the boolean mask as a `nonzero`
if set(bool_dims) & set(bcast_batch_param_dims):
int_indices = list(idx.nonzero())
# Indexing by 0 drops the degenerate dims
for bool_dim in bool_dims:
if bool_dim in bcast_batch_param_dims:
int_indices[bool_dim - curr_dim] = 0
batch_indices.extend(int_indices)
# No overlap, use index as is
else: else:
new_size = new_size_ignoring_boolean batch_indices.append(idx)
curr_dim += len(bool_dims)
# Update the parameters to reflect the indexed dimensions # Basic-indexing (slice or integer)
new_dist_params = []
for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
# Apply indexing on the batched dimensions of the parameter
batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
batched_param = shape_padleft(param, batched_param_dims_missing)
batched_st_indices = []
for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
# If we have a degenerate dimension indexing it should always do the job
if batched_param_shape == 1:
batched_st_indices.append(0)
else: else:
batched_st_indices.append(st_index) # Broadcasted dim
new_dist_params.append(batched_param[tuple(batched_st_indices)]) if curr_dim in bcast_batch_param_dims:
# Slice indexing, keep degenerate dim by none-slicing
if isinstance(idx.type, SliceType):
batch_indices.append(slice(None))
# Integer indexing, drop degenerate dim by 0-indexing
else:
batch_indices.append(0)
# Non-broadcasted dim
else:
# Use index as is
batch_indices.append(idx)
curr_dim += 1
new_dist_params.append(batch_param[tuple(batch_indices)])
# Create new RV # Create new RV
new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params) new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
new_rv = new_node.default_output() new_rv = new_node.default_output()
if config.compute_test_value != "off": copy_stack_trace(rv, new_rv)
compute_test_value(new_node)
return [new_rv] return [new_rv]
...@@ -26,6 +26,7 @@ from pytensor.tensor.random.rewriting import ( ...@@ -26,6 +26,7 @@ from pytensor.tensor.random.rewriting import (
local_rv_size_lift, local_rv_size_lift,
local_subtensor_rv_lift, local_subtensor_rv_lift,
) )
from pytensor.tensor.rewriting.shape import ShapeFeature, ShapeOptimizer
from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from pytensor.tensor.type import iscalar, vector from pytensor.tensor.type import iscalar, vector
...@@ -58,7 +59,9 @@ def apply_local_rewrite_to_rv( ...@@ -58,7 +59,9 @@ def apply_local_rewrite_to_rv(
p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant)) p for p in dist_params_at + size_at if not isinstance(p, (slice, Constant))
] ]
mode = Mode("py", EquilibriumGraphRewriter([rewrite], max_use_ratio=100)) mode = Mode(
"py", EquilibriumGraphRewriter([ShapeOptimizer(), rewrite], max_use_ratio=100)
)
f_rewritten = function( f_rewritten = function(
f_inputs, f_inputs,
...@@ -440,30 +443,48 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -440,30 +443,48 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol) np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol)
def rand_bool_mask(shape, rng=None):
if rng is None:
rng = np.random.default_rng()
return rng.binomial(n=1, p=0.5, size=shape).astype(bool)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"indices, lifted, dist_op, dist_params, size", "indices, lifted, dist_op, dist_params, size",
[ [
# 0
( (
# `size`-less advanced boolean indexing # `size`-less simple integer indexing
(np.r_[True, False, False, True],), (slice(None), 2),
True, True,
uniform, normal,
( (
(0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX), np.arange(30, dtype=config.floatX).reshape(3, 5, 2),
0.1 * np.arange(4).astype(dtype=config.floatX), np.full((1, 5, 1), 1e-6),
), ),
(), (),
), ),
( (
# `size`-only advanced boolean indexing # `size`-only slice
(np.r_[True, False, False, True],), (2, -1),
True, True,
uniform, uniform,
( (
np.array(0.9 - 1e-5, dtype=config.floatX), np.array(0.9 - 1e-5, dtype=config.floatX),
np.array(0.9, dtype=config.floatX), np.array(0.9, dtype=config.floatX),
), ),
(4,), (5, 2),
),
(
# `size`-less slice
(slice(None), slice(4, -6, -1), slice(1, None)),
True,
normal,
(
np.arange(30, dtype=config.floatX).reshape(3, 5, 2),
np.full((1, 5, 1), 1e-6),
),
(),
), ),
( (
# `size`-only slice # `size`-only slice
...@@ -477,8 +498,32 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -477,8 +498,32 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
(5, 2), (5, 2),
), ),
( (
(slice(1, None), [0, 2]), # `size`-less advanced boolean indexing
(np.r_[True, False, False, True],),
True,
uniform,
(
(0.1 - 1e-5) * np.arange(4).astype(dtype=config.floatX),
0.1 * np.arange(4).astype(dtype=config.floatX),
),
(),
),
# 5
(
# `size`-only advanced boolean indexing
(np.r_[True, False, False, True],),
True, True,
uniform,
(
np.array(0.9 - 1e-5, dtype=config.floatX),
np.array(0.9, dtype=config.floatX),
),
(4,),
),
(
# Advanced integer indexing
(slice(1, None), [0, 2]),
False, # Could have duplicates
normal, normal,
( (
np.array([1, 10, 100], dtype=config.floatX), np.array([1, 10, 100], dtype=config.floatX),
...@@ -487,8 +532,9 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -487,8 +532,9 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
(4, 3), (4, 3),
), ),
( (
# Advanced integer indexing
(np.array([1]), 0), (np.array([1]), 0),
True, False, # We don't support expand_dims
normal, normal,
( (
np.array([[-1, 20], [300, -4000]], dtype=config.floatX), np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
...@@ -496,23 +542,39 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -496,23 +542,39 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
), ),
(3, 2, 2), (3, 2, 2),
), ),
# Only one distribution parameter
( (
(0,), # Advanced integer-boolean indexing
(0, np.r_[True, False]),
True, True,
poisson, normal,
(np.array([[1, 2], [3, 4]], dtype=config.floatX),), (
np.array([[1, 2], [3, 4]], dtype=config.floatX),
np.array([1e-6], dtype=config.floatX),
),
(3, 2, 2), (3, 2, 2),
), ),
# Univariate distribution with vector parameters
( (
(np.array([0, 2]),), # Advanced non-consecutive integer-boolean indexing
(slice(None), 0, slice(None), np.r_[True, False]),
True,
normal,
(
np.array([[1, 2], [3, 4]], dtype=config.floatX),
np.array([[1e-6]], dtype=config.floatX),
),
(7, 3, 2, 2),
),
# 10
(
# Univariate distribution with core-vector parameters
(1,),
True, True,
categorical, categorical,
(np.array([0.0, 0.0, 1.0], dtype=config.floatX),), (np.array([0.0, 0.0, 1.0], dtype=config.floatX),),
(4,), (4,),
), ),
( (
# Univariate distribution with core-vector parameters
(np.array([True, False, True, True]),), (np.array([True, False, True, True]),),
True, True,
categorical, categorical,
...@@ -520,6 +582,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -520,6 +582,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
(4,), (4,),
), ),
( (
# Univariate distribution with core-vector parameters
(np.array([True, False, True]),), (np.array([True, False, True]),),
True, True,
categorical, categorical,
...@@ -532,10 +595,8 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -532,10 +595,8 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
(), (),
), ),
( (
( # Univariate distribution with core-vector parameters
slice(None), (slice(None), np.array([True, False, True])),
np.array([True, False, True]),
),
True, True,
categorical, categorical,
( (
...@@ -546,16 +607,18 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -546,16 +607,18 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
), ),
(4, 3), (4, 3),
), ),
# Boolean indexing where output is empty
( (
# Boolean indexing where output is empty
(np.array([False, False]),), (np.array([False, False]),),
True, True,
normal, normal,
(np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),), (np.array([[1.0, 0.0, 0.0]], dtype=config.floatX),),
(2, 3), (2, 3),
), ),
# 15
( (
(np.array([False, False]),), # Boolean indexing where output is empty
(np.array([False, False]), slice(1, None)),
True, True,
categorical, categorical,
( (
...@@ -566,10 +629,107 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -566,10 +629,107 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
), ),
(2, 3), (2, 3),
), ),
# Multivariate cases, indexing only supported if it does not affect core dimensions
( (
# Indexing dips into core dimension # Empty-slice
(np.array([1]), 0), (slice(None), slice(10, None), slice(1, None)),
True,
normal,
(
np.arange(30).reshape(2, 3, 5),
np.full((1, 5), 1e-6),
),
(2, 3, 5),
),
(
# Multidimensional boolean indexing
(rand_bool_mask((5, 3, 2)),),
True,
normal,
(
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
),
(
# Multidimensional boolean indexing
(rand_bool_mask((5, 3)),),
True,
normal,
(
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
),
(
# Multidimensional boolean indexing
(rand_bool_mask((5, 3)), slice(None)),
True,
normal,
(
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
),
# 20
(
# Multidimensional boolean indexing
(slice(None), rand_bool_mask((3, 2))),
True,
normal,
(
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
),
(
# Multidimensional boolean indexing
(rand_bool_mask((5, 3)),),
True,
normal,
(
np.arange(3).reshape(1, 3, 1),
np.full((5, 1, 2), 1e-6),
),
(5, 3, 2),
),
(
# Multidimensional boolean indexing
(
np.array([True, False, True, False, False]),
slice(None),
(np.array([True, True])),
),
True,
normal,
(
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
),
(
# Multidimensional boolean indexing,
# requires runtime broadcasting of the zeros arrays
(
np.array([True, False, True, False, False]), # nonzero().shape == (2,)
slice(None),
(np.array([True, False])), # nonzero().shape == (1,)
),
True,
normal,
(
np.arange(30).reshape(5, 3, 2),
1e-6,
),
(),
),
(
# Multivariate distribution: indexing dips into core dimension
(1, 0),
False, False,
multivariate_normal, multivariate_normal,
( (
...@@ -578,9 +738,22 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -578,9 +738,22 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
), ),
(), (),
), ),
# 25
( (
(np.array([0, 2]),), # Multivariate distribution: indexing dips into core dimension
True, (rand_bool_mask((2, 2)),),
False,
multivariate_normal,
(
np.array([[-1, 20], [300, -4000]], dtype=config.floatX),
np.eye(2).astype(config.floatX) * 1e-6,
),
(),
),
(
# Multivariate distribution: advanced integer indexing
(np.array([0, 0]),),
False, # Could have duplicates (it has in this case)!
multivariate_normal, multivariate_normal,
( (
np.array( np.array(
...@@ -592,6 +765,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -592,6 +765,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
(), (),
), ),
( (
# Multivariate distribution: dummy slice "dips" into core dimension
(np.array([True, False, True]), slice(None)), (np.array([True, False, True]), slice(None)),
True, True,
multivariate_normal, multivariate_normal,
...@@ -603,6 +777,17 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol): ...@@ -603,6 +777,17 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
), ),
(3,), (3,),
), ),
(
# Multivariate distribution
(0, slice(1, None), rand_bool_mask((4, 3))),
True,
multivariate_normal,
(
np.arange(4 * 3 * 2).reshape(4, 3, 2).astype(dtype=config.floatX),
np.eye(2) * 1e-6,
),
(5, 3, 4, 3),
),
], ],
) )
@config.change_flags(compute_test_value_opt="raise", compute_test_value="raise") @config.change_flags(compute_test_value_opt="raise", compute_test_value="raise")
...@@ -650,7 +835,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size): ...@@ -650,7 +835,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
res_base = f_base(*arg_values) res_base = f_base(*arg_values)
res_rewritten = f_rewritten(*arg_values) res_rewritten = f_rewritten(*arg_values)
np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3) np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3, atol=1e-2)
def test_Subtensor_lift_restrictions(): def test_Subtensor_lift_restrictions():
...@@ -664,7 +849,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -664,7 +849,7 @@ def test_Subtensor_lift_restrictions():
# the lift # the lift
z = x - y z = x - y
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z], clone=False, features=[ShapeFeature()])
_ = EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) _ = EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner subtensor_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
...@@ -676,7 +861,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -676,7 +861,7 @@ def test_Subtensor_lift_restrictions():
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles # We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly. # `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False) fg = FunctionGraph([rng], [z, x], clone=False, features=[ShapeFeature()])
EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z assert fg.outputs[0] == z
...@@ -684,7 +869,7 @@ def test_Subtensor_lift_restrictions(): ...@@ -684,7 +869,7 @@ def test_Subtensor_lift_restrictions():
# The non-`Subtensor` client doesn't depend on the RNG state, so we can # The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift # perform the lift
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z], clone=False, features=[ShapeFeature()])
EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) EquilibriumGraphRewriter([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner rv_node = fg.outputs[0].owner.inputs[1].owner.inputs[0].owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论